001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.dataset; 018 019import com.oracle.labs.mlrg.olcut.provenance.ListProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException; 023import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil; 024import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance; 025import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance; 026import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance; 027import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 028import com.oracle.labs.mlrg.olcut.util.Pair; 029import org.tribuo.Dataset; 030import org.tribuo.Example; 031import org.tribuo.ImmutableDataset; 032import org.tribuo.ImmutableFeatureMap; 033import org.tribuo.ImmutableOutputInfo; 034import org.tribuo.Output; 035import org.tribuo.provenance.DatasetProvenance; 036import org.tribuo.util.Util; 037 038import java.util.ArrayList; 039import java.util.Arrays; 040import java.util.Collections; 041import java.util.Iterator; 042import java.util.List; 043import java.util.Map; 044import java.util.Objects; 045import java.util.Set; 046import java.util.SplittableRandom; 047import java.util.function.Predicate; 048 049/** 050 * DatasetView provides an immutable view on another {@link Dataset} that only exposes selected examples. 051 * Does not copy the examples. 052 * 053 * @param <T> The output type of this dataset. 054 */ 055public final class DatasetView<T extends Output<T>> extends ImmutableDataset<T> { 056 private static final long serialVersionUID = 1L; 057 058 private final Dataset<T> innerDataset; 059 060 private final int size; 061 062 private final int[] exampleIndices; 063 064 private final long seed; 065 066 private final String tag; 067 068 private final boolean sampled; 069 070 private final boolean weighted; 071 072 private boolean storeIndices = false; 073 074 /** 075 * Creates a DatasetView which includes the supplied indices from the dataset. 076 * <p> 077 * It uses the feature and output infos from the wrapped dataset. 078 * 079 * @param dataset The dataset to wrap. 080 * @param exampleIndices The indices to present. 081 * @param tag A tag for the view. 082 */ 083 public DatasetView(Dataset<T> dataset, int[] exampleIndices, String tag) { 084 this(dataset,exampleIndices,dataset.getFeatureIDMap(),dataset.getOutputIDInfo(), tag); 085 } 086 087 /** 088 * Creates a DatasetView which includes the supplied indices from the dataset. 089 * <p> 090 * This takes the ImmutableFeatureMap and ImmutableOutputInfo parameters to save them being 091 * regenerated (e.g., in BaggingTrainer). 092 * 093 * @param dataset The dataset to sample from. 094 * @param exampleIndices The indices of this view in the wrapped dataset. 095 * @param featureIDs The featureIDs to use for this dataset. 096 * @param labelIDs The labelIDs to use for this dataset. 097 * @param tag A tag for the view. 098 */ 099 public DatasetView(Dataset<T> dataset, int[] exampleIndices, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, String tag) { 100 super(dataset.getProvenance(),dataset.getOutputFactory(),featureIDs,labelIDs); 101 if (!validateIndices(dataset.size(),exampleIndices)) { 102 throw new IllegalArgumentException("Invalid indices supplied, dataset.size() = " + dataset.size() + ", but found a negative index or a value greater than or equal to size."); 103 } 104 this.innerDataset = dataset; 105 this.size = exampleIndices.length; 106 this.exampleIndices = exampleIndices; 107 this.seed = -1; 108 this.tag = tag; 109 this.storeIndices = true; 110 this.sampled = false; 111 this.weighted = false; 112 } 113 114 /** 115 * Constructor used by the sampling factory methods. 116 * @param dataset The dataset to create the view over. 117 * @param exampleIndices The indices to use. 118 * @param seed The seed for the RNG. 119 * @param featureIDs The feature IDs to use. 120 * @param outputIDs The output IDs to use. 121 * @param weighted Is it a weighted sample? (Weighted samples store the indices in the provenance by default). 122 */ 123 private DatasetView(Dataset<T> dataset, int[] exampleIndices, long seed, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs, boolean weighted) { 124 super(dataset.getProvenance(),dataset.getOutputFactory(),featureIDs,outputIDs); 125 this.innerDataset = dataset; 126 this.size = exampleIndices.length; 127 this.exampleIndices = exampleIndices; 128 this.tag = ""; 129 this.seed = seed; 130 this.sampled = true; 131 this.weighted = weighted; 132 this.storeIndices = weighted; 133 } 134 135 /** 136 * Creates a view from the supplied dataset, using the specified predicate to 137 * test if each example should be in this view. 138 * @param dataset The dataset to create a view over. 139 * @param predicate The predicate which determines if an example is in this view. 140 * @param tag A tag denoting what the predicate does. 141 * @param <T> The type of the Output in the dataset. 142 * @return A dataset view containing each example where the predicate is true. 143 */ 144 public static <T extends Output<T>> DatasetView<T> createView(Dataset<T> dataset, Predicate<Example<T>> predicate, String tag) { 145 List<Integer> selectedIndices = new ArrayList<>(); 146 147 int i = 0; 148 for (Example<T> e : dataset) { 149 if (predicate.test(e)) { 150 selectedIndices.add(i); 151 } 152 i++; 153 } 154 155 int[] exampleIndices = Util.toPrimitiveInt(selectedIndices); 156 return new DatasetView<>(dataset,exampleIndices,tag); 157 } 158 159 /** 160 * Generates a DatasetView bootstrapped from the supplied Dataset. 161 * 162 * @param dataset The dataset to sample from. 163 * @param size The size of the sample. 164 * @param seed A seed for the RNG. 165 * @param <T> The type of the Output in the dataset. 166 * @return A dataset view containing a bootstrap sample of the supplied dataset. 167 */ 168 public static <T extends Output<T>> DatasetView<T> createBootstrapView(Dataset<T> dataset, int size, long seed) { 169 return createBootstrapView(dataset,size,seed,dataset.getFeatureIDMap(),dataset.getOutputIDInfo()); 170 } 171 172 /** 173 * Generates a DatasetView bootstrapped from the supplied Dataset. 174 * <p> 175 * This takes the ImmutableFeatureMap and ImmutableOutputInfo parameters to save them being 176 * regenerated. 177 * 178 * @param dataset The dataset to sample from. 179 * @param size The size of the sample. 180 * @param seed A seed for the RNG. 181 * @param featureIDs The featureIDs to use for this dataset. 182 * @param outputIDs The output info to use for this dataset. 183 * @param <T> The type of the Output in the dataset. 184 * @return A dataset view containing a bootstrap sample of the supplied dataset. 185 */ 186 public static <T extends Output<T>> DatasetView<T> createBootstrapView(Dataset<T> dataset, int size, long seed, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs) { 187 int[] bootstrapIndices = Util.generateBootstrapIndices(size, new SplittableRandom(seed)); 188 return new DatasetView<>(dataset, bootstrapIndices, seed, featureIDs, outputIDs, false); 189 } 190 191 /** 192 * Generates a DatasetView bootstrapped from the supplied Dataset using the supplied 193 * example weights. 194 * 195 * @param dataset The dataset to sample from. 196 * @param size The size of the sample. 197 * @param seed A seed for the RNG. 198 * @param exampleWeights The sampling weights for each example, must be in the range 0,1. 199 * @param <T> The type of the Output in the dataset. 200 * @return A dataset view containing a weighted bootstrap sample of the supplied dataset. 201 */ 202 public static <T extends Output<T>> DatasetView<T> createWeightedBootstrapView(Dataset<T> dataset, int size, long seed, float[] exampleWeights) { 203 return createWeightedBootstrapView(dataset,size,seed,exampleWeights,dataset.getFeatureIDMap(),dataset.getOutputIDInfo()); 204 } 205 206 /** 207 * Generates a DatasetView bootstrapped from the supplied Dataset using the supplied 208 * example weights. 209 * <p> 210 * This takes the ImmutableFeatureMap and ImmutableOutputInfo parameters to save them being 211 * regenerated. 212 * 213 * @param dataset The dataset to sample from. 214 * @param size The size of the sample. 215 * @param seed A seed for the RNG. 216 * @param exampleWeights The sampling weights for each example, must be in the range 0,1. 217 * @param featureIDs The featureIDs to use for this dataset. 218 * @param outputIDs The output info to use for this dataset. 219 * @param <T> The type of the Output in the dataset. 220 * @return A dataset view containing a weighted bootstrap sample of the supplied dataset. 221 */ 222 public static <T extends Output<T>> DatasetView<T> createWeightedBootstrapView(Dataset<T> dataset, int size, long seed, float[] exampleWeights, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> outputIDs) { 223 if (dataset.size() != exampleWeights.length) { 224 throw new IllegalArgumentException("There must be a weight for each example, dataset.size()="+dataset.size()+", exampleWeights.length="+exampleWeights.length); 225 } 226 int[] bootstrapIndices = Util.generateWeightedIndicesSample(size,exampleWeights,new SplittableRandom(seed)); 227 return new DatasetView<>(dataset, bootstrapIndices, seed, featureIDs, outputIDs,true); 228 } 229 230 /** 231 * Are the indices stored in the provenance system. 232 * @return True if the indices will be stored in the provenance of this view. 233 */ 234 public boolean storeIndicesInProvenance() { 235 return storeIndices; 236 } 237 238 /** 239 * Set to true to store the indices in the provenance system. 240 * @param storeIndices True if the indices should be stored in the provenance of this view. 241 */ 242 public void setStoreIndices(boolean storeIndices) { 243 this.storeIndices = storeIndices; 244 } 245 246 @Override 247 public String toString() { 248 StringBuilder buffer = new StringBuilder(); 249 250 buffer.append("DatasetView(innerDataset="); 251 buffer.append(innerDataset.getSourceDescription()); 252 buffer.append(",size="); 253 buffer.append(size); 254 buffer.append(",seed="); 255 buffer.append(seed); 256 buffer.append(",tag="); 257 buffer.append(tag); 258 buffer.append(")"); 259 260 return buffer.toString(); 261 } 262 263 /** 264 * Gets the set of outputs that occur in the examples in this dataset. 265 * 266 * @return the set of outputs that occur in the examples in this dataset. 267 */ 268 @Override 269 public Set<T> getOutputs() { 270 return innerDataset.getOutputs(); 271 } 272 273 /** 274 * Gets the size of the data set. 275 * 276 * @return the size of the data set. 277 */ 278 @Override 279 public int size() { 280 return size; 281 } 282 283 @Override 284 public ImmutableFeatureMap getFeatureMap() { 285 return featureIDMap; 286 } 287 288 @Override 289 public ImmutableOutputInfo<T> getOutputInfo() { 290 return outputIDInfo; 291 } 292 293 @Override 294 public Iterator<Example<T>> iterator() { 295 return new ViewIterator<>(this); 296 } 297 298 @Override 299 public List<Example<T>> getData() { 300 ArrayList<Example<T>> data = new ArrayList<>(); 301 for (int index : exampleIndices) { 302 data.add(innerDataset.getExample(index)); 303 } 304 return Collections.unmodifiableList(data); 305 } 306 307 @Override 308 public Example<T> getExample(int index) { 309 if ((index < 0) || (index >= size())) { 310 throw new IllegalArgumentException("Example index " + index + " is out of bounds."); 311 } 312 return innerDataset.getExample(exampleIndices[index]); 313 } 314 315 @Override 316 public DatasetViewProvenance getProvenance() { 317 return new DatasetViewProvenance(this,storeIndices); 318 } 319 320 /** 321 * Returns a copy of the indicies used in this view. 322 * @return The indices. 323 */ 324 public int[] getExampleIndices() { 325 return Arrays.copyOf(exampleIndices,exampleIndices.length); 326 } 327 328 /** 329 * Checks that all the indices are non-negative and less than size. 330 * @param size The maximum size. 331 * @param indices The indices to check. 332 * @return True if the indices are valid for the given size, false otherwise. 333 */ 334 private static boolean validateIndices(int size, int[] indices) { 335 boolean valid = true; 336 337 for (int i = 0; i < indices.length; i++) { 338 int idx = indices[i]; 339 valid &= idx < size && idx > -1; 340 } 341 342 return valid; 343 } 344 345 private static class ViewIterator<T extends Output<T>> implements Iterator<Example<T>> { 346 347 private int counter = 0; 348 private final DatasetView<T> dataset; 349 350 ViewIterator(DatasetView<T> dataset) { 351 this.dataset = dataset; 352 } 353 354 @Override 355 public boolean hasNext() { 356 return counter < dataset.size(); 357 } 358 359 @Override 360 public Example<T> next() { 361 Example<T> example = dataset.getExample(counter); 362 counter++; 363 return example; 364 } 365 366 } 367 368 /** 369 * Provenance for the {@link DatasetView}. 370 */ 371 public static final class DatasetViewProvenance extends DatasetProvenance { 372 private static final long serialVersionUID = 1L; 373 374 private static final String SIZE = "size"; 375 private static final String SEED = "seed"; 376 private static final String TAG = "tag"; 377 private static final String SAMPLED = "sampled"; 378 private static final String WEIGHTED = "weighted"; 379 private static final String INDICES = "indices"; 380 381 private final IntProvenance size; 382 private final LongProvenance seed; 383 private final StringProvenance tag; 384 private final BooleanProvenance weighted; 385 private final BooleanProvenance sampled; 386 private final int[] indices; 387 388 <T extends Output<T>> DatasetViewProvenance(DatasetView<T> dataset, boolean storeIndices) { 389 super(dataset.sourceProvenance, new ListProvenance<>(), dataset); 390 this.size = new IntProvenance(SIZE,dataset.size); 391 this.seed = new LongProvenance(SEED,dataset.seed); 392 this.weighted = new BooleanProvenance(WEIGHTED,dataset.weighted); 393 this.sampled = new BooleanProvenance(SAMPLED,dataset.sampled); 394 this.tag = new StringProvenance(TAG,dataset.tag); 395 this.indices = storeIndices ? dataset.indices : new int[0]; 396 } 397 398 public DatasetViewProvenance(Map<String,Provenance> map) { 399 super(map); 400 this.size = ObjectProvenance.checkAndExtractProvenance(map,SIZE,IntProvenance.class, DatasetViewProvenance.class.getSimpleName()); 401 this.seed = ObjectProvenance.checkAndExtractProvenance(map,SEED,LongProvenance.class, DatasetViewProvenance.class.getSimpleName()); 402 this.tag = ObjectProvenance.checkAndExtractProvenance(map,TAG,StringProvenance.class, DatasetViewProvenance.class.getSimpleName()); 403 this.weighted = ObjectProvenance.checkAndExtractProvenance(map,WEIGHTED,BooleanProvenance.class, DatasetViewProvenance.class.getSimpleName()); 404 this.sampled = ObjectProvenance.checkAndExtractProvenance(map,SAMPLED,BooleanProvenance.class, DatasetViewProvenance.class.getSimpleName()); 405 @SuppressWarnings("unchecked") // List provenance cast 406 ListProvenance<IntProvenance> listIndices = ObjectProvenance.checkAndExtractProvenance(map,INDICES,ListProvenance.class, DatasetViewProvenance.class.getSimpleName()); 407 if (listIndices.getList().size() > 0) { 408 try { 409 IntProvenance i = listIndices.getList().get(0); 410 } catch (ClassCastException e) { 411 throw new ProvenanceException("Loaded another class when expecting an ListProvenance<IntProvenance>",e); 412 } 413 } 414 this.indices = Util.toPrimitiveInt(ProvenanceUtil.unwrap(listIndices)); 415 } 416 417 /** 418 * Generates the indices from this DatasetViewProvenance 419 * by rerunning the bootstrap sample. 420 * 421 * Note these indices are invalid if the view is a weighted sample, or 422 * not sampled. 423 * @return The bootstrap indices. 424 */ 425 public int[] generateBootstrap() { 426 return Util.generateBootstrapIndices(size.getValue(), new SplittableRandom(seed.getValue())); 427 } 428 429 /** 430 * Is this view from a bootstrap sample. 431 * @return True if it's a bootstrap sample. 432 */ 433 public boolean isSampled() { 434 return sampled.getValue(); 435 } 436 437 /** 438 * Is this view a weighted bootstrap sample. 439 * @return True if it's a weighted bootstrap sample. 440 */ 441 public boolean isWeighted() { 442 return weighted.getValue(); 443 } 444 445 @Override 446 public boolean equals(Object o) { 447 if (this == o) return true; 448 if (!(o instanceof DatasetView.DatasetViewProvenance)) return false; 449 if (!super.equals(o)) return false; 450 DatasetViewProvenance pairs = (DatasetViewProvenance) o; 451 return size.equals(pairs.size) && seed.equals(pairs.seed) && 452 tag.equals(pairs.tag); 453 } 454 455 @Override 456 public int hashCode() { 457 return Objects.hash(super.hashCode(), size, seed, tag); 458 } 459 460 @Override 461 protected List<Pair<String, Provenance>> allProvenances() { 462 List<Pair<String,Provenance>> provenances = super.allProvenances(); 463 provenances.add(new Pair<>(SIZE,size)); 464 provenances.add(new Pair<>(SEED,seed)); 465 provenances.add(new Pair<>(TAG,tag)); 466 provenances.add(new Pair<>(WEIGHTED,weighted)); 467 provenances.add(new Pair<>(SAMPLED,sampled)); 468 provenances.add(new Pair<>(INDICES,boxArray())); 469 return provenances; 470 } 471 472 private ListProvenance<IntProvenance> boxArray() { 473 List<IntProvenance> list = new ArrayList<>(); 474 475 for (int i = 0; i < indices.length; i++) { 476 list.add(new IntProvenance("indices",indices[i])); 477 } 478 479 return new ListProvenance<>(list); 480 } 481 482 /** 483 * This toString doesn't put the indices in the string, as it's likely 484 * to be huge. 485 * @return A string describing this provenance. 486 */ 487 @Override 488 public String toString() { 489 List<Pair<String,Provenance>> provenances = super.allProvenances(); 490 provenances.add(new Pair<>(SIZE,size)); 491 provenances.add(new Pair<>(SEED,seed)); 492 provenances.add(new Pair<>(TAG,tag)); 493 provenances.add(new Pair<>(WEIGHTED,weighted)); 494 provenances.add(new Pair<>(SAMPLED,sampled)); 495 provenances.add(new Pair<>(INDICES,new ListProvenance<>())); 496 497 StringBuilder sb = new StringBuilder(); 498 499 sb.append("DatasetView("); 500 for (Pair<String,Provenance> p : provenances) { 501 sb.append(p.getA()); 502 sb.append('='); 503 sb.append(p.getB().toString()); 504 sb.append(','); 505 } 506 sb.replace(sb.length()-1,sb.length(),")"); 507 508 return sb.toString(); 509 } 510 } 511}