001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.evaluation; 018 019import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.Provenance; 021import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance; 022import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance; 023import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance; 024import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance; 025import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 026import com.oracle.labs.mlrg.olcut.util.Pair; 027import org.tribuo.DataSource; 028import org.tribuo.Example; 029import org.tribuo.Output; 030import org.tribuo.datasource.ListDataSource; 031import org.tribuo.provenance.DataSourceProvenance; 032 033import java.util.ArrayList; 034import java.util.Collections; 035import java.util.Iterator; 036import java.util.List; 037import java.util.Map; 038import java.util.Objects; 039import java.util.Random; 040 041/** 042 * Splits data into training and testing sets. Note that this doesn't 043 * operate on {@link org.tribuo.Dataset}, but rather on {@link DataSource}. 044 * @param <T> The output type of the examples in the datasource. 045 */ 046public class TrainTestSplitter<T extends Output<T>> { 047 048 private final DataSource<T> train; 049 050 private final DataSource<T> test; 051 052 private final DataSourceProvenance originalProvenance; 053 054 private final long seed; 055 056 private final double trainProportion; 057 058 private final int size; 059 060 /** 061 * Creates a splitter that splits a dataset 70/30 train and test using a default seed. 062 * @param data The data to split. 063 */ 064 public TrainTestSplitter(DataSource<T> data) { 065 this(data,1); 066 } 067 068 /** 069 * Creates a splitter that splits a dataset 70/30 train and test. 070 * @param data The data to split. 071 * @param seed The seed for the RNG. 072 */ 073 public TrainTestSplitter(DataSource<T> data, long seed) { 074 this(data, 0.7, seed); 075 } 076 077 /** 078 * Creates a splitter that will split the given data set into 079 * a training and testing set. The give proportion of the data will be 080 * randomly selected for the training set. The remainder will be in the 081 * test set. 082 * 083 * @param data the data that we want to split. 084 * @param trainProportion the proportion of the data to select for training. 085 * This should be a number between 0 and 1. For example, a value of 0.7 means 086 * that 70% of the data should be selected for the training set. 087 * @param seed The seed for the RNG. 088 */ 089 public TrainTestSplitter(DataSource<T> data, double trainProportion, long seed) { 090 this.seed = seed; 091 this.trainProportion = trainProportion; 092 this.originalProvenance = data.getProvenance(); 093 List<Example<T>> l = new ArrayList<>(); 094 for(Example<T> ex : data) { 095 l.add(ex); 096 } 097 this.size = l.size(); 098 Random rng = new Random(seed); 099 Collections.shuffle(l,rng); 100 int n = (int) (trainProportion * l.size()); 101 train = new ListDataSource<>(l.subList(0, n),data.getOutputFactory(),new SplitDataSourceProvenance(this,true)); 102 test = new ListDataSource<>(l.subList(n, l.size()),data.getOutputFactory(),new SplitDataSourceProvenance(this,false)); 103 } 104 105 /** 106 * The total amount of data in train and test combined. 107 * @return The number of examples. 108 */ 109 public int totalSize() { 110 return size; 111 } 112 113 /** 114 * Gets the training data source. 115 * @return The training data. 116 */ 117 public DataSource<T> getTrain() { 118 return train; 119 } 120 121 /** 122 * Gets the testing datasource. 123 * @return The testing data. 124 */ 125 public DataSource<T> getTest() { 126 return test; 127 } 128 129 /** 130 * Provenance for a split data source. 131 */ 132 public static class SplitDataSourceProvenance implements DataSourceProvenance { 133 private static final long serialVersionUID = 1L; 134 135 private static final String SOURCE = "source"; 136 private static final String TRAIN_PROPORTION = "train-proportion"; 137 private static final String SEED = "seed"; 138 private static final String SIZE = "size"; 139 private static final String IS_TRAIN = "is-train"; 140 141 private final StringProvenance className; 142 private final DataSourceProvenance innerSourceProvenance; 143 private final DoubleProvenance trainProportion; 144 private final LongProvenance seed; 145 private final IntProvenance size; 146 private final BooleanProvenance isTrain; 147 148 <T extends Output<T>> SplitDataSourceProvenance(TrainTestSplitter<T> host, boolean isTrain) { 149 this.className = new StringProvenance(CLASS_NAME,host.getClass().getName()); 150 this.innerSourceProvenance = host.originalProvenance; 151 this.trainProportion = new DoubleProvenance(TRAIN_PROPORTION,host.trainProportion); 152 this.seed = new LongProvenance(SEED,host.seed); 153 this.size = new IntProvenance(SIZE,host.size); 154 this.isTrain = new BooleanProvenance(IS_TRAIN,isTrain); 155 } 156 157 public SplitDataSourceProvenance(Map<String, Provenance> map) { 158 this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME,StringProvenance.class,SplitDataSourceProvenance.class.getSimpleName()); 159 this.innerSourceProvenance = ObjectProvenance.checkAndExtractProvenance(map,SOURCE,DataSourceProvenance.class,SplitDataSourceProvenance.class.getSimpleName()); 160 this.trainProportion = ObjectProvenance.checkAndExtractProvenance(map,TRAIN_PROPORTION,DoubleProvenance.class,SplitDataSourceProvenance.class.getSimpleName()); 161 this.seed = ObjectProvenance.checkAndExtractProvenance(map,SEED,LongProvenance.class,SplitDataSourceProvenance.class.getSimpleName()); 162 this.size = ObjectProvenance.checkAndExtractProvenance(map,SIZE,IntProvenance.class,SplitDataSourceProvenance.class.getSimpleName()); 163 this.isTrain = ObjectProvenance.checkAndExtractProvenance(map,IS_TRAIN,BooleanProvenance.class,SplitDataSourceProvenance.class.getSimpleName()); 164 } 165 166 @Override 167 public String getClassName() { 168 return className.getValue(); 169 } 170 171 @Override 172 public Iterator<Pair<String, Provenance>> iterator() { 173 ArrayList<Pair<String,Provenance>> list = new ArrayList<>(); 174 175 list.add(new Pair<>(CLASS_NAME,className)); 176 list.add(new Pair<>(SOURCE,innerSourceProvenance)); 177 list.add(new Pair<>(TRAIN_PROPORTION,trainProportion)); 178 list.add(new Pair<>(SEED,seed)); 179 list.add(new Pair<>(SIZE,size)); 180 list.add(new Pair<>(IS_TRAIN,isTrain)); 181 182 return list.iterator(); 183 } 184 185 @Override 186 public boolean equals(Object o) { 187 if (this == o) return true; 188 if (!(o instanceof SplitDataSourceProvenance)) return false; 189 SplitDataSourceProvenance pairs = (SplitDataSourceProvenance) o; 190 return className.equals(pairs.className) && 191 innerSourceProvenance.equals(pairs.innerSourceProvenance) && 192 trainProportion.equals(pairs.trainProportion) && 193 seed.equals(pairs.seed) && 194 size.equals(pairs.size) && 195 isTrain.equals(pairs.isTrain); 196 } 197 198 @Override 199 public int hashCode() { 200 return Objects.hash(className, innerSourceProvenance, trainProportion, seed, size, isTrain); 201 } 202 203 @Override 204 public String toString() { 205 return "SplitDataSourceProvenance(" + 206 "className=" + className + 207 ",innerSourceProvenance=" + innerSourceProvenance + 208 ",trainProportion=" + trainProportion + 209 ",seed=" + seed + 210 ",size=" + size + 211 ",isTrain=" + isTrain + 212 ')'; 213 } 214 } 215}