001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.dataset;
018
019import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
023import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
026import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
027import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
028import com.oracle.labs.mlrg.olcut.util.Pair;
029import org.tribuo.Dataset;
030import org.tribuo.Example;
031import org.tribuo.ImmutableDataset;
032import org.tribuo.ImmutableFeatureMap;
033import org.tribuo.ImmutableOutputInfo;
034import org.tribuo.Output;
035import org.tribuo.provenance.DatasetProvenance;
036import org.tribuo.util.Util;
037
038import java.util.ArrayList;
039import java.util.Arrays;
040import java.util.Collections;
041import java.util.Iterator;
042import java.util.List;
043import java.util.Map;
044import java.util.Objects;
045import java.util.Set;
046import java.util.SplittableRandom;
047import java.util.function.Predicate;
048
049/**
050 * DatasetView provides an immutable view on another {@link Dataset} that only exposes selected examples.
051 * Does not copy the examples.
052 *
053 * @param <T> The output type of this dataset.
054 */
055public final class DatasetView<T extends Output<T>> extends ImmutableDataset<T> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 1L;

    /** The wrapped dataset; examples are fetched from it on demand rather than copied. */
    private final Dataset<T> innerDataset;
    
    /** The number of examples in this view; the constructors set it to {@code exampleIndices.length}. */
    private final int size;

    /** Maps a position in this view to the index of the example in {@link #innerDataset}. */
    private final int[] exampleIndices;

    /** The RNG seed used by the sampling factories; the public index constructor sets it to -1. */
    private final long seed;

    /** A tag for the view; the sampling constructor sets it to the empty string. */
    private final String tag;

    /** True when the view was built by the sampling constructor. */
    private final boolean sampled;

    /** True when the view was built as a weighted sample. */
    private final boolean weighted;

    /**
     * Controls whether {@link #getProvenance()} includes the example indices.
     * The public index constructor sets it to true, the sampling constructor sets it to
     * the {@code weighted} flag, and it can be changed later via {@link #setStoreIndices(boolean)}.
     */
    private boolean storeIndices = false;
073
074    /**
075     * Creates a DatasetView which includes the supplied indices from the dataset.
076     * <p>
077     * It uses the feature and output infos from the wrapped dataset.
078     *
079     * @param dataset The dataset to wrap.
080     * @param exampleIndices The indices to present.
081     * @param tag A tag for the view.
082     */
083    public DatasetView(Dataset<T> dataset, int[] exampleIndices, String tag) {
084        this(dataset,exampleIndices,dataset.getFeatureIDMap(),dataset.getOutputIDInfo(), tag);
085    }
086
087    /**
088     * Creates a DatasetView which includes the supplied indices from the dataset.
089     * <p>
090     * This takes the ImmutableFeatureMap and ImmutableOutputInfo parameters to save them being
091     * regenerated (e.g., in BaggingTrainer).
092     *
093     * @param dataset The dataset to sample from.
094     * @param exampleIndices The indices of this view in the wrapped dataset.
095     * @param featureIDs The featureIDs to use for this dataset.
096     * @param labelIDs The labelIDs to use for this dataset.
097     * @param tag A tag for the view.
098     */
099    public DatasetView(Dataset<T> dataset, int[] exampleIndices, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, String tag) {
100        super(dataset.getProvenance(),dataset.getOutputFactory(),featureIDs,labelIDs);
101        if (!validateIndices(dataset.size(),exampleIndices)) {
102            throw new IllegalArgumentException("Invalid indices supplied, dataset.size() = " + dataset.size() + ", but found a negative index or a value greater than or equal to size.");
103        }
104        this.innerDataset = dataset;
105        this.size = exampleIndices.length;
106        this.exampleIndices = exampleIndices;
107        this.seed = -1;
108        this.tag = tag;
109        this.storeIndices = true;
110        this.sampled = false;
111        this.weighted = false;
112    }
113
114    /**
115     * Constructor used by the sampling factory methods.
116     * @param dataset The dataset to create the view over.
117     * @param exampleIndices The indices to use.
118     * @param seed The seed for the RNG.
119     * @param featureIDs The feature IDs to use.
120     * @param outputIDs The output IDs to use.
121     * @param weighted Is it a weighted sample? (Weighted samples store the indices in the provenance by default).
122     */
123    private DatasetView(Dataset<T> dataset, int[] exampleIndices, long seed, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs, boolean weighted) {
124        super(dataset.getProvenance(),dataset.getOutputFactory(),featureIDs,outputIDs);
125        this.innerDataset = dataset;
126        this.size = exampleIndices.length;
127        this.exampleIndices = exampleIndices;
128        this.tag = "";
129        this.seed = seed;
130        this.sampled = true;
131        this.weighted = weighted;
132        this.storeIndices = weighted;
133    }
134
135    /**
136     * Creates a view from the supplied dataset, using the specified predicate to
137     * test if each example should be in this view.
138     * @param dataset The dataset to create a view over.
139     * @param predicate The predicate which determines if an example is in this view.
140     * @param tag A tag denoting what the predicate does.
141     * @param <T> The type of the Output in the dataset.
142     * @return A dataset view containing each example where the predicate is true.
143     */
144    public static <T extends Output<T>> DatasetView<T> createView(Dataset<T> dataset, Predicate<Example<T>> predicate, String tag) {
145        List<Integer> selectedIndices = new ArrayList<>();
146
147        int i = 0;
148        for (Example<T> e : dataset) {
149            if (predicate.test(e)) {
150                selectedIndices.add(i);
151            }
152            i++;
153        }
154
155        int[] exampleIndices = Util.toPrimitiveInt(selectedIndices);
156        return new DatasetView<>(dataset,exampleIndices,tag);
157    }
158
159    /**
160     * Generates a DatasetView bootstrapped from the supplied Dataset.
161     *
162     * @param dataset The dataset to sample from.
163     * @param size The size of the sample.
164     * @param seed A seed for the RNG.
165     * @param <T> The type of the Output in the dataset.
166     * @return A dataset view containing a bootstrap sample of the supplied dataset.
167     */
168    public static <T extends Output<T>> DatasetView<T> createBootstrapView(Dataset<T> dataset, int size, long seed) {
169        return createBootstrapView(dataset,size,seed,dataset.getFeatureIDMap(),dataset.getOutputIDInfo());
170    }
171
172    /**
173     * Generates a DatasetView bootstrapped from the supplied Dataset.
174     * <p>
175     * This takes the ImmutableFeatureMap and ImmutableOutputInfo parameters to save them being
176     * regenerated.
177     *
178     * @param dataset The dataset to sample from.
179     * @param size The size of the sample.
180     * @param seed A seed for the RNG.
181     * @param featureIDs The featureIDs to use for this dataset.
182     * @param outputIDs The output info to use for this dataset.
183     * @param <T> The type of the Output in the dataset.
184     * @return A dataset view containing a bootstrap sample of the supplied dataset.
185     */
186    public static <T extends Output<T>> DatasetView<T> createBootstrapView(Dataset<T> dataset, int size, long seed, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs) {
187        int[] bootstrapIndices = Util.generateBootstrapIndices(size, new SplittableRandom(seed));
188        return new DatasetView<>(dataset, bootstrapIndices, seed, featureIDs, outputIDs, false);
189    }
190
191    /**
192     * Generates a DatasetView bootstrapped from the supplied Dataset using the supplied
193     * example weights.
194     *
195     * @param dataset The dataset to sample from.
196     * @param size The size of the sample.
197     * @param seed A seed for the RNG.
198     * @param exampleWeights The sampling weights for each example, must be in the range 0,1.
199     * @param <T> The type of the Output in the dataset.
200     * @return A dataset view containing a weighted bootstrap sample of the supplied dataset.
201     */
202    public static <T extends Output<T>> DatasetView<T> createWeightedBootstrapView(Dataset<T> dataset, int size, long seed, float[] exampleWeights) {
203        return createWeightedBootstrapView(dataset,size,seed,exampleWeights,dataset.getFeatureIDMap(),dataset.getOutputIDInfo());
204    }
205
206    /**
207     * Generates a DatasetView bootstrapped from the supplied Dataset using the supplied
208     * example weights.
209     * <p>
210     * This takes the ImmutableFeatureMap and ImmutableOutputInfo parameters to save them being
211     * regenerated.
212     *
213     * @param dataset The dataset to sample from.
214     * @param size The size of the sample.
215     * @param seed A seed for the RNG.
216     * @param exampleWeights The sampling weights for each example, must be in the range 0,1.
217     * @param featureIDs The featureIDs to use for this dataset.
218     * @param outputIDs The output info to use for this dataset.
219     * @param <T> The type of the Output in the dataset.
220     * @return A dataset view containing a weighted bootstrap sample of the supplied dataset.
221     */
222    public static <T extends Output<T>> DatasetView<T> createWeightedBootstrapView(Dataset<T> dataset, int size, long seed, float[] exampleWeights, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs) {
223        if (dataset.size() != exampleWeights.length) {
224            throw new IllegalArgumentException("There must be a weight for each example, dataset.size()="+dataset.size()+", exampleWeights.length="+exampleWeights.length);
225        }
226        int[] bootstrapIndices = Util.generateWeightedIndicesSample(size,exampleWeights,new SplittableRandom(seed));
227        return new DatasetView<>(dataset, bootstrapIndices, seed, featureIDs, outputIDs,true);
228    }
229
230    /**
231     * Are the indices stored in the provenance system.
232     * @return True if the indices will be stored in the provenance of this view.
233     */
234    public boolean storeIndicesInProvenance() {
235        return storeIndices;
236    }
237
238    /**
239     * Set to true to store the indices in the provenance system.
240     * @param storeIndices True if the indices should be stored in the provenance of this view.
241     */
242    public void setStoreIndices(boolean storeIndices) {
243        this.storeIndices = storeIndices;
244    }
245
246    @Override
247    public String toString() {
248        StringBuilder buffer = new StringBuilder();
249
250        buffer.append("DatasetView(innerDataset=");
251        buffer.append(innerDataset.getSourceDescription());
252        buffer.append(",size=");
253        buffer.append(size);
254        buffer.append(",seed=");
255        buffer.append(seed);
256        buffer.append(",tag=");
257        buffer.append(tag);
258        buffer.append(")");
259        
260        return buffer.toString();
261    }
262
263    /**
264     * Gets the set of outputs that occur in the examples in this dataset.
265     *
266     * @return the set of outputs that occur in the examples in this dataset.
267     */
268    @Override
269    public Set<T> getOutputs() {
270        return innerDataset.getOutputs();
271    }
272
273    /**
274     * Gets the size of the data set.
275     *
276     * @return the size of the data set.
277     */
278    @Override
279    public int size() {
280        return size;
281    }
282
283    @Override
284    public ImmutableFeatureMap getFeatureMap() {
285        return featureIDMap;
286    }
287
288    @Override
289    public ImmutableOutputInfo<T> getOutputInfo() {
290        return outputIDInfo;
291    }
292
293    @Override
294    public Iterator<Example<T>> iterator() {
295        return new ViewIterator<>(this);
296    }
297
298    @Override
299    public List<Example<T>> getData() {
300        ArrayList<Example<T>> data = new ArrayList<>();
301        for (int index : exampleIndices) {
302            data.add(innerDataset.getExample(index));
303        }
304        return Collections.unmodifiableList(data);
305    }
306
307    @Override
308    public Example<T> getExample(int index) {
309        if ((index < 0) || (index >= size())) {
310            throw new IllegalArgumentException("Example index " + index + " is out of bounds.");  
311        }
312        return innerDataset.getExample(exampleIndices[index]);
313    }
314
315    @Override
316    public DatasetViewProvenance getProvenance() {
317        return new DatasetViewProvenance(this,storeIndices);
318    }
319
320    /**
321     * Returns a copy of the indicies used in this view.
322     * @return The indices.
323     */
324    public int[] getExampleIndices() {
325        return Arrays.copyOf(exampleIndices,exampleIndices.length);
326    }
327
328    /**
329     * Checks that all the indices are non-negative and less than size.
330     * @param size The maximum size.
331     * @param indices The indices to check.
332     * @return True if the indices are valid for the given size, false otherwise.
333     */
334    private static boolean validateIndices(int size, int[] indices) {
335        boolean valid = true;
336
337        for (int i = 0; i < indices.length; i++) {
338            int idx = indices[i];
339            valid &= idx < size && idx > -1;
340        }
341
342        return valid;
343    }
344
345    private static class ViewIterator<T extends Output<T>> implements Iterator<Example<T>> {
346
347        private int counter = 0;
348        private final DatasetView<T> dataset;
349
350        ViewIterator(DatasetView<T> dataset) {
351            this.dataset = dataset;
352        }
353
354        @Override
355        public boolean hasNext() {
356            return counter < dataset.size();
357        }
358
359        @Override
360        public Example<T> next() {
361            Example<T> example = dataset.getExample(counter);
362            counter++;
363            return example;
364        }
365
366    }
367
368    /**
369     * Provenance for the {@link DatasetView}.
370     */
371    public static final class DatasetViewProvenance extends DatasetProvenance {
372        private static final long serialVersionUID = 1L;
373
374        private static final String SIZE = "size";
375        private static final String SEED = "seed";
376        private static final String TAG = "tag";
377        private static final String SAMPLED = "sampled";
378        private static final String WEIGHTED = "weighted";
379        private static final String INDICES = "indices";
380
381        private final IntProvenance size;
382        private final LongProvenance seed;
383        private final StringProvenance tag;
384        private final BooleanProvenance weighted;
385        private final BooleanProvenance sampled;
386        private final int[] indices;
387
388        <T extends Output<T>> DatasetViewProvenance(DatasetView<T> dataset, boolean storeIndices) {
389            super(dataset.sourceProvenance, new ListProvenance<>(), dataset);
390            this.size = new IntProvenance(SIZE,dataset.size);
391            this.seed = new LongProvenance(SEED,dataset.seed);
392            this.weighted = new BooleanProvenance(WEIGHTED,dataset.weighted);
393            this.sampled = new BooleanProvenance(SAMPLED,dataset.sampled);
394            this.tag = new StringProvenance(TAG,dataset.tag);
395            this.indices = storeIndices ? dataset.indices : new int[0];
396        }
397
398        public DatasetViewProvenance(Map<String,Provenance> map) {
399            super(map);
400            this.size = ObjectProvenance.checkAndExtractProvenance(map,SIZE,IntProvenance.class, DatasetViewProvenance.class.getSimpleName());
401            this.seed = ObjectProvenance.checkAndExtractProvenance(map,SEED,LongProvenance.class, DatasetViewProvenance.class.getSimpleName());
402            this.tag = ObjectProvenance.checkAndExtractProvenance(map,TAG,StringProvenance.class, DatasetViewProvenance.class.getSimpleName());
403            this.weighted = ObjectProvenance.checkAndExtractProvenance(map,WEIGHTED,BooleanProvenance.class, DatasetViewProvenance.class.getSimpleName());
404            this.sampled = ObjectProvenance.checkAndExtractProvenance(map,SAMPLED,BooleanProvenance.class, DatasetViewProvenance.class.getSimpleName());
405            @SuppressWarnings("unchecked") // List provenance cast
406            ListProvenance<IntProvenance> listIndices = ObjectProvenance.checkAndExtractProvenance(map,INDICES,ListProvenance.class, DatasetViewProvenance.class.getSimpleName());
407            if (listIndices.getList().size() > 0) {
408                try {
409                    IntProvenance i = listIndices.getList().get(0);
410                } catch (ClassCastException e) {
411                    throw new ProvenanceException("Loaded another class when expecting an ListProvenance<IntProvenance>",e);
412                }
413            }
414            this.indices = Util.toPrimitiveInt(ProvenanceUtil.unwrap(listIndices));
415        }
416
417        /**
418         * Generates the indices from this DatasetViewProvenance
419         * by rerunning the bootstrap sample.
420         *
421         * Note these indices are invalid if the view is a weighted sample, or
422         * not sampled.
423         * @return The bootstrap indices.
424         */
425        public int[] generateBootstrap() {
426            return Util.generateBootstrapIndices(size.getValue(), new SplittableRandom(seed.getValue()));
427        }
428
429        /**
430         * Is this view from a bootstrap sample.
431         * @return True if it's a bootstrap sample.
432         */
433        public boolean isSampled() {
434            return sampled.getValue();
435        }
436
437        /**
438         * Is this view a weighted bootstrap sample.
439         * @return True if it's a weighted bootstrap sample.
440         */
441        public boolean isWeighted() {
442            return weighted.getValue();
443        }
444
445        @Override
446        public boolean equals(Object o) {
447            if (this == o) return true;
448            if (!(o instanceof DatasetView.DatasetViewProvenance)) return false;
449            if (!super.equals(o)) return false;
450            DatasetViewProvenance pairs = (DatasetViewProvenance) o;
451            return size.equals(pairs.size) && seed.equals(pairs.seed) &&
452                    tag.equals(pairs.tag);
453        }
454
455        @Override
456        public int hashCode() {
457            return Objects.hash(super.hashCode(), size, seed, tag);
458        }
459
460        @Override
461        protected List<Pair<String, Provenance>> allProvenances() {
462            List<Pair<String,Provenance>> provenances = super.allProvenances();
463            provenances.add(new Pair<>(SIZE,size));
464            provenances.add(new Pair<>(SEED,seed));
465            provenances.add(new Pair<>(TAG,tag));
466            provenances.add(new Pair<>(WEIGHTED,weighted));
467            provenances.add(new Pair<>(SAMPLED,sampled));
468            provenances.add(new Pair<>(INDICES,boxArray()));
469            return provenances;
470        }
471
472        private ListProvenance<IntProvenance> boxArray() {
473            List<IntProvenance> list = new ArrayList<>();
474
475            for (int i = 0; i < indices.length; i++) {
476                list.add(new IntProvenance("indices",indices[i]));
477            }
478
479            return new ListProvenance<>(list);
480        }
481
482        /**
483         * This toString doesn't put the indices in the string, as it's likely
484         * to be huge.
485         * @return A string describing this provenance.
486         */
487        @Override
488        public String toString() {
489            List<Pair<String,Provenance>> provenances = super.allProvenances();
490            provenances.add(new Pair<>(SIZE,size));
491            provenances.add(new Pair<>(SEED,seed));
492            provenances.add(new Pair<>(TAG,tag));
493            provenances.add(new Pair<>(WEIGHTED,weighted));
494            provenances.add(new Pair<>(SAMPLED,sampled));
495            provenances.add(new Pair<>(INDICES,new ListProvenance<>()));
496
497            StringBuilder sb = new StringBuilder();
498
499            sb.append("DatasetView(");
500            for (Pair<String,Provenance> p : provenances) {
501                sb.append(p.getA());
502                sb.append('=');
503                sb.append(p.getB().toString());
504                sb.append(',');
505            }
506            sb.replace(sb.length()-1,sb.length(),")");
507
508            return sb.toString();
509        }
510    }
511}