/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo;

import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.Transformer;
import org.tribuo.transform.TransformerMap;
import org.tribuo.util.Util;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
 * A class for sets of data, which are used to train and evaluate classifiers.
 * <p>
 * Subclass {@link MutableDataset} rather than this class.
 * <p>
 * @param <T> the type of the outputs in the data set.
 */
public abstract class Dataset<T extends Output<T>> implements Iterable<Example<T>>, Provenancable<DatasetProvenance>, Serializable {
    private static final long serialVersionUID = 2L;

    private static final Logger logger = Logger.getLogger(Dataset.class.getName());

    /**
     * Users of this RNG should synchronize on the Dataset to prevent replicability issues.
     */
    private static final SplittableRandom rng = new SplittableRandom(Trainer.DEFAULT_SEED);

    /**
     * The data in this data set.
     */
    protected final List<Example<T>> data = new ArrayList<>();

    /**
     * The provenance of the data source, extracted on construction.
     */
    protected final DataProvenance sourceProvenance;

    /**
     * A factory for making {@link OutputInfo} and {@link Output} of the appropriate type.
     */
    protected final OutputFactory<T> outputFactory;

    /**
     * The indices of the shuffled order.
     */
    protected int[] indices = null;

    /**
     * Creates a dataset.
     * @param provenance A description of the data, including preprocessing steps.
     * @param outputFactory The output factory.
     */
    protected Dataset(DataProvenance provenance, OutputFactory<T> outputFactory) {
        this.sourceProvenance = provenance;
        this.outputFactory = outputFactory;
    }

    /**
     * Creates a dataset.
     * @param dataSource The DataSource to use.
     */
    protected Dataset(DataSource<T> dataSource) {
        this(dataSource.getProvenance(),dataSource.getOutputFactory());
    }
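
    // Usage sketch (not part of the class): Dataset is abstract, so instances are normally
    // built through a concrete subclass such as MutableDataset. Assuming a hypothetical
    // DataSource<Label> named trainSource (Label being the classification output type),
    // construction would look roughly like:
    //
    //   MutableDataset<Label> train = new MutableDataset<>(trainSource);
    //   System.out.println("Loaded " + train.size() + " examples");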

    /**
     * A String description of this dataset.
     * @return The description.
     */
    public String getSourceDescription() {
        return "Dataset(source="+ sourceProvenance.toString() +")";
    }

    /**
     * The provenance of the data this Dataset contains.
     * @return The data provenance.
     */
    public DataProvenance getSourceProvenance() {
        return sourceProvenance;
    }

    /**
     * Gets the examples as an unmodifiable list. This list will throw an UnsupportedOperationException if any elements
     * are added to it.
     * <p>
     * In other words, using the following to add additional examples to this dataset will throw an exception:
     *
     * {@code dataset.getData().add(example)}
     *
     * Instead, use {@link MutableDataset#add(Example)}.
     *
     * @return The unmodifiable example list.
     */
    public List<Example<T>> getData() {
        return Collections.unmodifiableList(data);
    }

    /**
     * Gets the output factory this dataset contains.
     * @return The output factory.
     */
    public OutputFactory<T> getOutputFactory() {
        return outputFactory;
    }

    /**
     * Gets the set of outputs that occur in the examples in this dataset.
     *
     * @return The set of outputs that occur in the examples in this dataset.
     */
    public abstract Set<T> getOutputs();

    /**
     * Gets the example at the supplied index.
     * <p>
     * Throws IllegalArgumentException if the index is invalid or outside the bounds.
     * @param index The index of the example.
     * @return The example.
     */
    public Example<T> getExample(int index) {
        if ((index < 0) || (index >= size())) {
            throw new IllegalArgumentException("Example index " + index + " is out of bounds.");
        }
        return data.get(index);
    }

    /**
     * Gets the size of the data set.
     *
     * @return The size of the data set.
     */
    public int size() {
        return data.size();
    }

    /**
     * Shuffles the indices, or stops shuffling them.
     * <p>
     * The shuffle only affects the iterator, it does not affect
     * {@link Dataset#getExample}.
     * <p>
     * Multiple calls with the argument true will shuffle the dataset multiple times.
     * The RNG is shared across all Dataset instances, so methods which access it are synchronized.
     * <p>
     * Using this method will prevent the provenance system from tracking the exact state of the dataset,
     * which may be important for trainers which depend on the example order, like those
     * using stochastic gradient descent.
     * @param shuffle If true shuffle the data.
     */
    public synchronized void shuffle(boolean shuffle) {
        if (shuffle) {
            indices = Util.randperm(data.size(), rng);
        } else {
            indices = null;
        }
    }
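
    // Usage sketch: shuffling only changes the iteration order returned by iterator(),
    // not index-based access through getExample. Assuming a hypothetical Dataset<Label>
    // named dataset:
    //
    //   dataset.shuffle(true);                  // iterate in a fresh random permutation
    //   for (Example<Label> example : dataset) { /* ... */ }
    //   dataset.shuffle(false);                 // restore insertion order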

    /**
     * Returns or generates an {@link ImmutableOutputInfo}.
     * @return An immutable output info.
     */
    public abstract ImmutableOutputInfo<T> getOutputIDInfo();

    /**
     * Returns this dataset's {@link OutputInfo}.
     * @return The output info.
     */
    public abstract OutputInfo<T> getOutputInfo();

    /**
     * Returns or generates an {@link ImmutableFeatureMap}.
     * @return An immutable feature map with id numbers.
     */
    public abstract ImmutableFeatureMap getFeatureIDMap();

    /**
     * Returns this dataset's {@link FeatureMap}.
     * @return The feature map from this dataset.
     */
    public abstract FeatureMap getFeatureMap();

    @Override
    public synchronized Iterator<Example<T>> iterator() {
        if (indices == null) {
            return data.iterator();
        } else {
            return new ShuffleIterator<>(this,indices);
        }
    }

    @Override
    public String toString(){
        return "Dataset(source="+ sourceProvenance +")";
    }

    /**
     * Takes a {@link TransformationMap} and converts it into a {@link TransformerMap} by
     * observing all the values in this dataset.
     * <p>
     * Does not mutate the dataset; if you wish to apply the TransformerMap, use
     * {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
     * <p>
     * Currently TransformationMaps and TransformerMaps only operate on feature values
     * which are present; sparse values are ignored and not transformed. If the zeros
     * should be transformed, call {@link MutableDataset#densify} on the datasets.
     * <p>
     * Throws {@link IllegalArgumentException} if the TransformationMap object has
     * regexes which apply to multiple features.
     * @param transformations The transformations to fit.
     * @return A TransformerMap which can apply the transformations to a dataset.
     */
    public TransformerMap createTransformers(TransformationMap transformations) {
        ArrayList<String> featureNames = new ArrayList<>(getFeatureMap().keySet());

        // Validate map by checking no regex applies to multiple features.
        Map<String,List<Transformation>> featureTransformations = new HashMap<>();
        for (Map.Entry<String,List<Transformation>> entry : transformations.getFeatureTransformations().entrySet()) {
            // Compile the regex.
            Pattern pattern = Pattern.compile(entry.getKey());
            // Check all the feature names.
            for (String name : featureNames) {
                // If the regex matches.
                if (pattern.matcher(name).matches()) {
                    List<Transformation> oldTransformations = featureTransformations.put(name,entry.getValue());
                    // See if there is already a transformation list for that name.
                    if (oldTransformations != null) {
                        throw new IllegalArgumentException("Feature name '"
                                + name + "' matches multiple regexes, at least one of which was '"
                                + entry.getKey() + "'.");
                    }
                }
            }
        }

        // Populate the feature transforms map.
        Map<String,Queue<TransformStatistics>> featureStats = new HashMap<>();
        // sparseCount tracks how many times a feature was not observed.
        Map<String,MutableLong> sparseCount = new HashMap<>();
        for (Map.Entry<String,List<Transformation>> entry : featureTransformations.entrySet()) {
            // Create the queue of feature transformations for this feature.
            Queue<TransformStatistics> l = new LinkedList<>();
            for (Transformation t : entry.getValue()) {
                l.add(t.createStats());
            }
            // Add the queue to the map for that feature.
            featureStats.put(entry.getKey(),l);
            sparseCount.put(entry.getKey(), new MutableLong(data.size()));
        }
        if (!transformations.getGlobalTransformations().isEmpty()) {
            // Append all the global transformations.
            for (String v : featureNames) {
                // Create the queue of feature transformations for this feature.
                Queue<TransformStatistics> l = featureStats.computeIfAbsent(v, (k) -> new LinkedList<>());
                for (Transformation t : transformations.getGlobalTransformations()) {
                    l.add(t.createStats());
                }
                // Add the queue to the map for that feature.
                featureStats.put(v, l);
                // Generate the sparse count initialised to the number of examples.
                sparseCount.putIfAbsent(v, new MutableLong(data.size()));
            }
        }

        Map<String,List<Transformer>> output = new LinkedHashMap<>();
        Set<String> removeSet = new LinkedHashSet<>();
        boolean initialisedSparseCounts = false;
        // Iterate through the dataset max(transformations.length) times.
        while (!featureStats.isEmpty()) {
            for (Example<T> example : data) {
                for (Feature f : example) {
                    if (featureStats.containsKey(f.getName())) {
                        if (!initialisedSparseCounts) {
                            sparseCount.get(f.getName()).decrement();
                        }
                        List<Transformer> curTransformers = output.get(f.getName());
                        // Apply all current transformations.
                        double fValue = TransformerMap.applyTransformerList(f.getValue(), curTransformers);
                        // Observe the transformed value.
                        featureStats.get(f.getName()).peek().observeValue(fValue);
                    }
                }
            }
            // Sparse counts are now initialised (this could be protected by an if statement).
            initialisedSparseCounts = true;

            removeSet.clear();
            // Emit the new transformers onto the end of the list in the output map.
            for (Map.Entry<String,Queue<TransformStatistics>> entry : featureStats.entrySet()) {
                // Observe all the sparse feature values.
                int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
                TransformStatistics currentStats = entry.getValue().poll();
                currentStats.observeSparse(unobservedFeatures);
                // Get the transformer list for that feature (creating it if absent).
                List<Transformer> l = output.computeIfAbsent(entry.getKey(), (k) -> new ArrayList<>());
                // Generate the transformer and add it to the appropriate list.
                l.add(currentStats.generateTransformer());
                // If the queue is empty, remove that feature, ensuring that featureStats is eventually empty.
                if (entry.getValue().isEmpty()) {
                    removeSet.add(entry.getKey());
                }
            }
            // Remove the features with empty queues.
            for (String s : removeSet) {
                featureStats.remove(s);
            }
        }

        return new TransformerMap(output,getProvenance(),transformations.getProvenance());
    }
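
    // Usage sketch: fit transformations on this dataset, then apply the fitted TransformerMap.
    // Assumes Tribuo's org.tribuo.transform.transformations.LinearScalingTransformation and a
    // hypothetical MutableDataset<Label> named dataset; any Transformation can be substituted.
    //
    //   TransformationMap map = new TransformationMap(
    //       Collections.singletonList(new LinearScalingTransformation(0.0, 1.0)));
    //   TransformerMap fitted = dataset.createTransformers(map);   // fits, does not mutate
    //   Dataset<Label> rescaled = fitted.transformDataset(dataset);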

    private static class ShuffleIterator<T extends Output<T>> implements Iterator<Example<T>> {
        private final Dataset<T> data;
        private final int[] indices;
        private int index;

        public ShuffleIterator(Dataset<T> data, int[] indices) {
            this.data = data;
            this.indices = indices;
            this.index = 0;
        }

        @Override
        public boolean hasNext() {
            return index < indices.length;
        }

        @Override
        public Example<T> next() {
            Example<T> e = data.getExample(indices[index]);
            index++;
            return e;
        }
    }
}