001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.evaluation;
018
019import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
023import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
026import com.oracle.labs.mlrg.olcut.util.Pair;
027import org.tribuo.DataSource;
028import org.tribuo.Example;
029import org.tribuo.Output;
030import org.tribuo.datasource.ListDataSource;
031import org.tribuo.provenance.DataSourceProvenance;
032
033import java.util.ArrayList;
034import java.util.Collections;
035import java.util.Iterator;
036import java.util.List;
037import java.util.Map;
038import java.util.Objects;
039import java.util.Random;
040
041/**
042 * Splits data into training and testing sets. Note that this doesn't
043 * operate on {@link org.tribuo.Dataset}, but rather on {@link DataSource}.
044 * @param <T> The output type of the examples in the datasource.
045 */
046public class TrainTestSplitter<T extends Output<T>> {
047
048    private final DataSource<T> train;
049    
050    private final DataSource<T> test;
051
052    private final DataSourceProvenance originalProvenance;
053
054    private final long seed;
055
056    private final double trainProportion;
057
058    private final int size;
059
060    /**
061     * Creates a splitter that splits a dataset 70/30 train and test using a default seed.
062     * @param data The data to split.
063     */
064    public TrainTestSplitter(DataSource<T> data) {
065        this(data,1);
066    }
067
068    /**
069     * Creates a splitter that splits a dataset 70/30 train and test.
070     * @param data The data to split.
071     * @param seed The seed for the RNG.
072     */
073    public TrainTestSplitter(DataSource<T> data, long seed) {
074        this(data, 0.7, seed);
075    }
076    
077    /**
078     * Creates a splitter that will split the given data set into 
079     * a training and testing set. The give proportion of the data will be
080     * randomly selected for the training set. The remainder will be in the 
081     * test set.
082     * 
083     * @param data the data that we want to split.
084     * @param trainProportion the proportion of the data to select for training. 
085     * This should be a number between 0 and 1. For example, a value of 0.7 means
086     * that 70% of the data should be selected for the training set.
087     * @param seed The seed for the RNG.
088     */
089    public TrainTestSplitter(DataSource<T> data, double trainProportion, long seed) {
090        this.seed = seed;
091        this.trainProportion = trainProportion;
092        this.originalProvenance = data.getProvenance();
093        List<Example<T>> l = new ArrayList<>();
094        for(Example<T> ex : data) {
095            l.add(ex);
096        }
097        this.size = l.size();
098        Random rng = new Random(seed);
099        Collections.shuffle(l,rng);
100        int n = (int) (trainProportion * l.size());
101        train = new ListDataSource<>(l.subList(0, n),data.getOutputFactory(),new SplitDataSourceProvenance(this,true));
102        test = new ListDataSource<>(l.subList(n, l.size()),data.getOutputFactory(),new SplitDataSourceProvenance(this,false));
103    }
104
105    /**
106     * The total amount of data in train and test combined.
107     * @return The number of examples.
108     */
109    public int totalSize() {
110        return size;
111    }
112
113    /**
114     * Gets the training data source.
115     * @return The training data.
116     */
117    public DataSource<T> getTrain() {
118        return train;
119    }
120
121    /**
122     * Gets the testing datasource.
123     * @return The testing data.
124     */
125    public DataSource<T> getTest() {
126        return test;
127    }
128
129    /**
130     * Provenance for a split data source.
131     */
132    public static class SplitDataSourceProvenance implements DataSourceProvenance {
133        private static final long serialVersionUID = 1L;
134
135        private static final String SOURCE = "source";
136        private static final String TRAIN_PROPORTION = "train-proportion";
137        private static final String SEED = "seed";
138        private static final String SIZE = "size";
139        private static final String IS_TRAIN = "is-train";
140
141        private final StringProvenance className;
142        private final DataSourceProvenance innerSourceProvenance;
143        private final DoubleProvenance trainProportion;
144        private final LongProvenance seed;
145        private final IntProvenance size;
146        private final BooleanProvenance isTrain;
147
148        <T extends Output<T>> SplitDataSourceProvenance(TrainTestSplitter<T> host, boolean isTrain) {
149            this.className = new StringProvenance(CLASS_NAME,host.getClass().getName());
150            this.innerSourceProvenance = host.originalProvenance;
151            this.trainProportion = new DoubleProvenance(TRAIN_PROPORTION,host.trainProportion);
152            this.seed = new LongProvenance(SEED,host.seed);
153            this.size = new IntProvenance(SIZE,host.size);
154            this.isTrain = new BooleanProvenance(IS_TRAIN,isTrain);
155        }
156
157        public SplitDataSourceProvenance(Map<String, Provenance> map) {
158            this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME,StringProvenance.class,SplitDataSourceProvenance.class.getSimpleName());
159            this.innerSourceProvenance = ObjectProvenance.checkAndExtractProvenance(map,SOURCE,DataSourceProvenance.class,SplitDataSourceProvenance.class.getSimpleName());
160            this.trainProportion = ObjectProvenance.checkAndExtractProvenance(map,TRAIN_PROPORTION,DoubleProvenance.class,SplitDataSourceProvenance.class.getSimpleName());
161            this.seed = ObjectProvenance.checkAndExtractProvenance(map,SEED,LongProvenance.class,SplitDataSourceProvenance.class.getSimpleName());
162            this.size = ObjectProvenance.checkAndExtractProvenance(map,SIZE,IntProvenance.class,SplitDataSourceProvenance.class.getSimpleName());
163            this.isTrain = ObjectProvenance.checkAndExtractProvenance(map,IS_TRAIN,BooleanProvenance.class,SplitDataSourceProvenance.class.getSimpleName());
164        }
165
166        @Override
167        public String getClassName() {
168            return className.getValue();
169        }
170
171        @Override
172        public Iterator<Pair<String, Provenance>> iterator() {
173            ArrayList<Pair<String,Provenance>> list = new ArrayList<>();
174
175            list.add(new Pair<>(CLASS_NAME,className));
176            list.add(new Pair<>(SOURCE,innerSourceProvenance));
177            list.add(new Pair<>(TRAIN_PROPORTION,trainProportion));
178            list.add(new Pair<>(SEED,seed));
179            list.add(new Pair<>(SIZE,size));
180            list.add(new Pair<>(IS_TRAIN,isTrain));
181
182            return list.iterator();
183        }
184
185        @Override
186        public boolean equals(Object o) {
187            if (this == o) return true;
188            if (!(o instanceof SplitDataSourceProvenance)) return false;
189            SplitDataSourceProvenance pairs = (SplitDataSourceProvenance) o;
190            return className.equals(pairs.className) &&
191                    innerSourceProvenance.equals(pairs.innerSourceProvenance) &&
192                    trainProportion.equals(pairs.trainProportion) &&
193                    seed.equals(pairs.seed) &&
194                    size.equals(pairs.size) &&
195                    isTrain.equals(pairs.isTrain);
196        }
197
198        @Override
199        public int hashCode() {
200            return Objects.hash(className, innerSourceProvenance, trainProportion, seed, size, isTrain);
201        }
202
203        @Override
204        public String toString() {
205            return "SplitDataSourceProvenance(" +
206                    "className=" + className +
207                    ",innerSourceProvenance=" + innerSourceProvenance +
208                    ",trainProportion=" + trainProportion +
209                    ",seed=" + seed +
210                    ",size=" + size +
211                    ",isTrain=" + isTrain +
212                    ')';
213        }
214    }
215}