/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo;

import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.Transformer;
import org.tribuo.transform.TransformerMap;
import org.tribuo.util.Util;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
 * A class for sets of data, which are used to train and evaluate models.
 * <p>
 * Subclass {@link MutableDataset} rather than this class.
 * <p>
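 * For illustration, a sketch of typical usage with a {@link MutableDataset} (assuming an
 * existing {@code DataSource<Label>} named {@code source}; the variable names here are
 * placeholders, not part of this API):
 * <pre>{@code
 * // Wrap the data source in a mutable dataset, populating the feature and output domains.
 * MutableDataset<Label> train = new MutableDataset<>(source);
 * System.out.println("Examples: " + train.size());
 * System.out.println("Features: " + train.getFeatureMap().size());
 * System.out.println("Outputs:  " + train.getOutputs());
 * }</pre>
 *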
 * @param <T> the type of the outputs in the data set.
 */
public abstract class Dataset<T extends Output<T>> implements Iterable<Example<T>>, Provenancable<DatasetProvenance>, Serializable {
    private static final long serialVersionUID = 2L;

    private static final Logger logger = Logger.getLogger(Dataset.class.getName());

    /**
     * Users of this RNG should synchronize on the Dataset to prevent replicability issues.
     */
    private static final SplittableRandom rng = new SplittableRandom(Trainer.DEFAULT_SEED);

    /**
     * The data in this data set.
     */
    protected final List<Example<T>> data = new ArrayList<>();

    /**
     * The provenance of the data source, extracted on construction.
     */
    protected final DataProvenance sourceProvenance;

    /**
     * A factory for making {@link OutputInfo} and {@link Output} of the appropriate type.
     */
    protected final OutputFactory<T> outputFactory;

    /**
     * The indices of the shuffled order.
     */
    protected int[] indices = null;

    /**
     * Creates a dataset.
     * @param provenance A description of the data, including preprocessing steps.
     * @param outputFactory The output factory.
     */
    protected Dataset(DataProvenance provenance, OutputFactory<T> outputFactory) {
        this.sourceProvenance = provenance;
        this.outputFactory = outputFactory;
    }

    /**
     * Creates a dataset.
     * @param dataSource the DataSource to use.
     */
    protected Dataset(DataSource<T> dataSource) {
        this(dataSource.getProvenance(),dataSource.getOutputFactory());
    }

    /**
     * A String description of this dataset.
     * @return The description
     */
    public String getSourceDescription() {
        return "Dataset(source="+ sourceProvenance.toString() +")";
    }

    /**
     * The provenance of the data this Dataset contains.
     * @return The data provenance.
     */
    public DataProvenance getSourceProvenance() {
        return sourceProvenance;
    }

    /**
     * Gets the examples as an unmodifiable list. This list will throw an UnsupportedOperationException if any elements
     * are added to it.
     * <p>
     * In other words, using the following to add additional examples to this dataset will throw an exception:
     *
     * {@code dataset.getData().add(example)}
     *
     * Instead, use {@link MutableDataset#add(Example)}.
     *
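     * For example (assuming {@code dataset} is a {@code MutableDataset<Label>} and {@code example}
     * is a previously constructed {@code Example<Label>}; both names are placeholders):
     *
     * <pre>{@code
     * List<Example<Label>> view = dataset.getData(); // read-only view of the examples
     * Example<Label> first = view.get(0);            // reading is fine
     * dataset.add(example);                          // mutate via the dataset itself
     * }</pre>
     *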
     * @return The unmodifiable example list.
     */
    public List<Example<T>> getData() {
        return Collections.unmodifiableList(data);
    }

    /**
     * Gets the output factory this dataset contains.
     * @return The output factory.
     */
    public OutputFactory<T> getOutputFactory() {
        return outputFactory;
    }

    /**
     * Gets the set of outputs that occur in the examples in this dataset.
     *
     * @return the set of outputs that occur in the examples in this dataset.
     */
    public abstract Set<T> getOutputs();

    /**
     * Gets the example at the supplied index.
     * <p>
     * Throws IllegalArgumentException if the index is invalid or outside the bounds.
     * @param index The index of the example.
     * @return The example.
     */
    public Example<T> getExample(int index) {
        if ((index < 0) || (index >= size())) {
            throw new IllegalArgumentException("Example index " + index + " is out of bounds.");
        }
        return data.get(index);
    }

    /**
     * Gets the size of the data set.
     *
     * @return the size of the data set.
     */
    public int size() {
        return data.size();
    }

    /**
     * Shuffles the indices, or stops shuffling them.
     * <p>
     * The shuffle only affects the iterator; it does not affect
     * {@link Dataset#getExample}.
     * <p>
     * Multiple calls with the argument true will shuffle the dataset multiple times.
     * The RNG is shared across all Dataset instances, so methods which access it are synchronized.
     * <p>
     * Using this method will prevent the provenance system from tracking the exact state of the dataset,
     * which may be important for trainers that depend on the example order, such as those
     * using stochastic gradient descent.
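     * <p>
     * A minimal sketch of the intended usage ({@code dataset} is assumed to be an existing
     * {@code Dataset<Label>}; the names are placeholders):
     * <pre>{@code
     * dataset.shuffle(true);              // iteration now follows a freshly shuffled order
     * for (Example<Label> e : dataset) {
     *     // examples arrive in the shuffled order
     * }
     * Example<Label> first = dataset.getExample(0); // unaffected by the shuffle
     * dataset.shuffle(false);             // iteration reverts to the stored order
     * }</pre>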
     * @param shuffle If true shuffle the data.
     */
    public synchronized void shuffle(boolean shuffle) {
        if (shuffle) {
            indices = Util.randperm(data.size(), rng);
        } else {
            indices = null;
        }
    }

    /**
     * Returns or generates an {@link ImmutableOutputInfo}.
     * @return An immutable output info.
     */
    public abstract ImmutableOutputInfo<T> getOutputIDInfo();

    /**
     * Returns this dataset's {@link OutputInfo}.
     * @return The output info.
     */
    public abstract OutputInfo<T> getOutputInfo();

    /**
     * Returns or generates an {@link ImmutableFeatureMap}.
     * @return An immutable feature map with id numbers.
     */
    public abstract ImmutableFeatureMap getFeatureIDMap();

    /**
     * Returns this dataset's {@link FeatureMap}.
     * @return The feature map from this dataset.
     */
    public abstract FeatureMap getFeatureMap();

    @Override
    public synchronized Iterator<Example<T>> iterator() {
        if (indices == null) {
            return data.iterator();
        } else {
            return new ShuffleIterator<>(this,indices);
        }
    }

    @Override
    public String toString() {
        return "Dataset(source="+ sourceProvenance +")";
    }

    /**
     * Takes a {@link TransformationMap} and converts it into a {@link TransformerMap} by
     * observing all the values in this dataset.
     * <p>
     * Does not mutate the dataset. If you wish to apply the TransformerMap, use
     * {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
     * <p>
     * Currently TransformationMaps and TransformerMaps only operate on feature values
     * which are present; sparse values are ignored and not transformed. If the zeros
     * should be transformed, call {@link MutableDataset#densify} on the dataset first.
     * <p>
     * Throws {@link IllegalArgumentException} if the TransformationMap object has
     * regexes which apply to multiple features.
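     * <p>
     * A sketch of typical usage, assuming a scaling transformation such as
     * {@code org.tribuo.transform.transformations.LinearScalingTransformation} and two existing
     * datasets named {@code trainData} and {@code testData} (the names are placeholders):
     * <pre>{@code
     * TransformationMap transformations = new TransformationMap(
     *         Collections.singletonList(new LinearScalingTransformation(0.0, 1.0)));
     * TransformerMap transformers = trainData.createTransformers(transformations);
     * // The transformers fitted on the training data can then be applied to both splits.
     * Dataset<Label> scaledTrain = transformers.transformDataset(trainData);
     * Dataset<Label> scaledTest = transformers.transformDataset(testData);
     * }</pre>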
     * @param transformations The transformations to fit.
     * @return A TransformerMap which can apply the transformations to a dataset.
     */
    public TransformerMap createTransformers(TransformationMap transformations) {
        ArrayList<String> featureNames = new ArrayList<>(getFeatureMap().keySet());

        // Validate map by checking no regex applies to multiple features.
        Map<String,List<Transformation>> featureTransformations = new HashMap<>();
        for (Map.Entry<String,List<Transformation>> entry : transformations.getFeatureTransformations().entrySet()) {
            // Compile the regex.
            Pattern pattern = Pattern.compile(entry.getKey());
            // Check all the feature names
            for (String name : featureNames) {
                // If the regex matches
                if (pattern.matcher(name).matches()) {
                    List<Transformation> oldTransformations = featureTransformations.put(name,entry.getValue());
                    // See if there is already a transformation list for that name.
                    if (oldTransformations != null) {
                        throw new IllegalArgumentException("Feature name '"
                                + name + "' matches multiple regexes, at least one of which was '"
                                + entry.getKey() + "'.");
                    }
                }
            }
        }

        // Populate the feature transforms map.
        Map<String,Queue<TransformStatistics>> featureStats = new HashMap<>();
        // sparseCount tracks how many times a feature was not observed
        Map<String,MutableLong> sparseCount = new HashMap<>();
        for (Map.Entry<String,List<Transformation>> entry : featureTransformations.entrySet()) {
            // Create the queue of feature transformations for this feature
            Queue<TransformStatistics> l = new LinkedList<>();
            for (Transformation t : entry.getValue()) {
                l.add(t.createStats());
            }
            // Add the queue to the map for that feature
            featureStats.put(entry.getKey(),l);
            sparseCount.put(entry.getKey(), new MutableLong(data.size()));
        }
        if (!transformations.getGlobalTransformations().isEmpty()) {
            // Append all the global transformations
            for (String v : featureNames) {
                // Create the queue of feature transformations for this feature
                Queue<TransformStatistics> l = featureStats.computeIfAbsent(v, (k) -> new LinkedList<>());
                for (Transformation t : transformations.getGlobalTransformations()) {
                    l.add(t.createStats());
                }
                // Add the queue to the map for that feature
                featureStats.put(v, l);
                // Generate the sparse count initialised to the number of examples.
                sparseCount.putIfAbsent(v, new MutableLong(data.size()));
            }
        }

        Map<String,List<Transformer>> output = new LinkedHashMap<>();
        Set<String> removeSet = new LinkedHashSet<>();
        // Sparse counts are only gathered on the first pass over the data, as they do not change afterwards.
        boolean initialisedSparseCounts = false;
        // Iterate through the dataset max(transformations.length) times.
        while (!featureStats.isEmpty()) {
            for (Example<T> example : data) {
                for (Feature f : example) {
                    if (featureStats.containsKey(f.getName())) {
                        if (!initialisedSparseCounts) {
                            sparseCount.get(f.getName()).decrement();
                        }
                        List<Transformer> curTransformers = output.get(f.getName());
                        // Apply the transformations fitted in earlier passes (none on the first pass)
                        double fValue = TransformerMap.applyTransformerList(f.getValue(), curTransformers);
                        // Observe the transformed value
                        featureStats.get(f.getName()).peek().observeValue(fValue);
                    }
                }
            }
            // The sparse counts are now initialised (setting this every pass is harmless).
            initialisedSparseCounts = true;

            removeSet.clear();
            // Emit the new transformers onto the end of the list in the output map.
            for (Map.Entry<String,Queue<TransformStatistics>> entry : featureStats.entrySet()) {
                // Observe all the sparse feature values
                int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
                TransformStatistics currentStats = entry.getValue().poll();
                currentStats.observeSparse(unobservedFeatures);
                // Get the transformer list for that feature, creating it if absent.
                List<Transformer> l = output.computeIfAbsent(entry.getKey(), (k) -> new ArrayList<>());
                // Generate the transformer and add it to the appropriate list.
                l.add(currentStats.generateTransformer());
                // If the queue is empty, remove that feature, ensuring that featureStats is eventually empty.
                if (entry.getValue().isEmpty()) {
                    removeSet.add(entry.getKey());
                }
            }
            // Remove the features with empty queues.
            for (String s : removeSet) {
                featureStats.remove(s);
            }
        }

        return new TransformerMap(output,getProvenance(),transformations.getProvenance());
    }

    private static class ShuffleIterator<T extends Output<T>> implements Iterator<Example<T>> {
        private final Dataset<T> data;
        private final int[] indices;
        private int index;

        public ShuffleIterator(Dataset<T> data, int[] indices) {
            this.data = data;
            this.indices = indices;
            this.index = 0;
        }

        @Override
        public boolean hasNext() {
            return index < indices.length;
        }

        @Override
        public Example<T> next() {
            Example<T> e = data.getExample(indices[index]);
            index++;
            return e;
        }
    }
}