/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.datasource;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.IOUtil;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DataSourceProvenance;

import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.logging.Logger;
import java.util.zip.GZIPOutputStream;

/**
 * A DataSource which can read IDX formatted data (e.g., MNIST).
 * <p>
 * Transparently reads GZipped files.
 * <p>
 * The file format is defined <a href="http://yann.lecun.com/exdb/mnist/">here</a>.
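 * <p>
 * A minimal usage sketch (assuming the standard MNIST files have been downloaded locally, and
 * that {@code Label} and {@code LabelFactory} from the classification module are on the classpath):
 * <pre>{@code
 * IDXDataSource<Label> mnist = new IDXDataSource<>(
 *         Paths.get("train-images-idx3-ubyte.gz"),
 *         Paths.get("train-labels-idx1-ubyte.gz"),
 *         new LabelFactory());
 * MutableDataset<Label> train = new MutableDataset<>(mnist);
 * }</pre>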
 */
public final class IDXDataSource<T extends Output<T>> implements ConfigurableDataSource<T> {
    private static final Logger logger = Logger.getLogger(IDXDataSource.class.getName());

    /**
     * The possible IDX input formats.
     */
    public enum IDXType {
        UBYTE((byte) 0x08),
        BYTE((byte) 0x09),
        SHORT((byte) 0x0B),
        INT((byte) 0x0C),
        FLOAT((byte) 0x0D),
        DOUBLE((byte) 0x0E);

        /**
         * The encoded byte value.
         */
        public final byte value;

        IDXType(byte value) {
            this.value = value;
        }

        /**
         * Converts the byte into the enum. Throws IllegalArgumentException if it's
         * not a valid byte.
         *
         * @param input The byte to convert.
         * @return The corresponding enum instance.
         */
        public static IDXType convert(byte input) {
            for (IDXType f : values()) {
                if (f.value == input) {
                    return f;
                }
            }
            throw new IllegalArgumentException("Invalid byte found - " + input);
        }
    }

    @Config(mandatory = true, description = "Path to load the features from.")
    private Path featuresPath;

    @Config(mandatory = true, description = "Path to load the outputs from.")
    private Path outputPath;

    @Config(mandatory = true, description = "The output factory to use.")
    private OutputFactory<T> outputFactory;

    private final ArrayList<Example<T>> data = new ArrayList<>();

    private IDXType dataType;

    private IDXDataSourceProvenance provenance;

    /**
     * For olcut.
     */
    private IDXDataSource() {}

    /**
     * Constructs an IDXDataSource from the supplied paths.
     *
     * @param featuresPath The path to the features file.
     * @param outputPath The path to the output file.
     * @param outputFactory The output factory.
     * @throws IOException If either file cannot be read.
     */
    public IDXDataSource(Path featuresPath, Path outputPath, OutputFactory<T> outputFactory) throws IOException {
        this.outputFactory = outputFactory;
        this.featuresPath = featuresPath;
        this.outputPath = outputPath;
        read();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     *
     * @throws IOException If the files could not be read.
     */
    @Override
    public void postConfig() throws IOException {
        read();
    }

    @Override
    public String toString() {
        return "IDXDataSource(featuresPath=" + featuresPath.toString() + ",outputPath=" + outputPath.toString() + ",featureType=" + dataType + ")";
    }

    @Override
    public OutputFactory<T> getOutputFactory() {
        return outputFactory;
    }

    @Override
    public synchronized DataSourceProvenance getProvenance() {
        if (provenance == null) {
            provenance = cacheProvenance();
        }
        return provenance;
    }

    private IDXDataSourceProvenance cacheProvenance() {
        return new IDXDataSourceProvenance(this);
    }

    /**
     * Loads the data.
     *
     * @throws IOException If the files could not be read.
     */
    private void read() throws IOException {
        IDXData features = readData(featuresPath);
        IDXData outputs = readData(outputPath);

        dataType = features.dataType;

        if (features.shape[0] != outputs.shape[0]) {
            throw new IllegalStateException("Features and outputs have different numbers of examples, feature shape = " + Arrays.toString(features.shape) + ", output shape = " + Arrays.toString(outputs.shape));
        }

        // Calculate the example size
        int numFeatures = 1;
        for (int i = 1; i < features.shape.length; i++) {
            numFeatures *= features.shape[i];
        }
        int numOutputs = 1;
        for (int i = 1; i < outputs.shape.length; i++) {
            numOutputs *= outputs.shape[i];
        }

        String[] featureNames = new String[numFeatures];
        int width = ("" + numFeatures).length();
        String formatString = "%0" + width + "d";
        for (int i = 0; i < numFeatures; i++) {
            featureNames[i] = String.format(formatString, i);
        }

        ArrayList<Feature> buffer = new ArrayList<>();
        int featureCounter = 0;
        int outputCounter = 0;
        StringBuilder outputBuilder = new StringBuilder();
        for (int i = 0; i < features.data.length; i++) {
            double curValue = features.data[i];
            if (curValue != 0.0) {
                // Tribuo is sparse, so only create non-zero features
                buffer.add(new Feature(featureNames[featureCounter], curValue));
            }
            featureCounter++;
            if (featureCounter == numFeatures) {
                // Fabricate the output. Multidimensional outputs expect a comma separated string.
                outputBuilder.setLength(0);
                for (int j = 0; j < numOutputs; j++) {
                    if (j != 0) {
                        outputBuilder.append(',');
                    }
                    // If necessary cast to int to ensure we get an integer out for use as a class label.
                    // No-one wants to have MNIST digits with labels "0.0", "1.0" etc.
                    switch (outputs.dataType) {
                        case BYTE:
                        case UBYTE:
                        case SHORT:
                        case INT:
                            outputBuilder.append((int) outputs.data[j + outputCounter]);
                            break;
                        case FLOAT:
                        case DOUBLE:
                            outputBuilder.append(outputs.data[j + outputCounter]);
                            break;
                    }
                }
                outputCounter += numOutputs;
                T output = outputFactory.generateOutput(outputBuilder.toString());

                // Create the example
                Example<T> example = new ArrayExample<>(output);
                example.addAll(buffer);
                data.add(example);

                // Clean up
                buffer.clear();
                featureCounter = 0;
            }
        }

        if (featureCounter != 0) {
            throw new IllegalStateException("Failed to process all the features, missing " + (numFeatures - featureCounter) + " values");
        }
    }
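
    /*
     * The on-disk IDX layout parsed by readData below (see the format description linked in the
     * class javadoc):
     *   bytes 0-1 : the magic number, always zero
     *   byte 2    : the data type code (see IDXType)
     *   byte 3    : the number of dimensions, n
     *   n ints    : the size of each dimension, big-endian
     *   remainder : the elements in row-major order, big-endian
     */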
    /**
     * Reads a single IDX format file.
     *
     * @param path The path to read.
     * @return The IDXData from the file.
     * @throws IOException If the file could not be read.
     */
    static IDXData readData(Path path) throws IOException {
        InputStream inputStream = IOUtil.getInputStreamForLocation(path.toString());
        if (inputStream == null) {
            throw new FileNotFoundException("Failed to load from path - " + path);
        }
        // DataInputStream.close implicitly closes the InputStream
        try (DataInputStream stream = new DataInputStream(inputStream)) {
            short magicNumber = stream.readShort();
            if (magicNumber != 0) {
                throw new IllegalStateException("Invalid IDX file, magic number was not zero. Found " + magicNumber);
            }
            final byte dataTypeByte = stream.readByte();
            final IDXType dataType = IDXType.convert(dataTypeByte);
            final byte numDimensions = stream.readByte();
            if (numDimensions < 1) {
                throw new IllegalStateException("Invalid number of dimensions, found " + numDimensions);
            }
            final int[] shape = new int[numDimensions];
            int size = 1;
            for (int i = 0; i < numDimensions; i++) {
                shape[i] = stream.readInt();
                if (shape[i] < 1) {
                    throw new IllegalStateException("Invalid shape, found " + Arrays.toString(shape));
                }
                size *= shape[i];
            }
            double[] data = new double[size];
            try {
                for (int i = 0; i < size; i++) {
                    switch (dataType) {
                        case BYTE:
                            data[i] = stream.readByte();
                            break;
                        case UBYTE:
                            data[i] = stream.readUnsignedByte();
                            break;
                        case SHORT:
                            data[i] = stream.readShort();
                            break;
                        case INT:
                            data[i] = stream.readInt();
                            break;
                        case FLOAT:
                            data[i] = stream.readFloat();
                            break;
                        case DOUBLE:
                            data[i] = stream.readDouble();
                            break;
                    }
                }
            } catch (EOFException e) {
                throw new IllegalStateException("Too little data in the file, expected to find " + size + " elements");
            }
            try {
                byte unexpectedByte = stream.readByte();
                throw new IllegalStateException("Too much data in the file");
            } catch (EOFException e) {
                //pass as the stream is exhausted
            }
            return new IDXData(dataType, shape, data);
        }
    }

    /**
     * The number of examples loaded.
     *
     * @return The number of examples.
     */
    public int size() {
        return data.size();
    }

    /**
     * The type of the features that were loaded in.
     *
     * @return The feature type.
     */
    public IDXType getDataType() {
        return dataType;
    }

    @Override
    public Iterator<Example<T>> iterator() {
        return data.iterator();
    }

    /**
     * Java side representation for an IDX file.
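     * <p>
     * A short sketch of building and writing one via the validating factory method (the shape,
     * values and file name here are purely illustrative):
     * <pre>{@code
     * int[] shape = new int[]{2, 3};
     * double[] values = new double[]{1, 2, 3, 4, 5, 6};
     * IDXData idx = IDXData.createIDXData(IDXType.UBYTE, shape, values);
     * idx.save(Paths.get("example-ubyte.idx.gz"), true);
     * }</pre>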
     */
    public static class IDXData {
        final IDXType dataType;
        final int[] shape;
        final double[] data;

        /**
         * Constructor, does not validate or copy inputs.
         * Use the factory method.
         * @param dataType The data type.
         * @param shape The tensor shape.
         * @param data The data to write.
         */
        IDXData(IDXType dataType, int[] shape, double[] data) {
            this.dataType = dataType;
            this.shape = shape;
            this.data = data;
        }

        /**
         * Constructs an IDXData, validating the input and defensively copying it.
         *
         * @param dataType The data type.
         * @param shape The tensor shape.
         * @param data The data to write.
         * @return An IDXData.
         */
        public static IDXData createIDXData(IDXType dataType, int[] shape, double[] data) {
            int[] shapeCopy = Arrays.copyOf(shape, shape.length);
            double[] dataCopy = Arrays.copyOf(data, data.length);
            if (shapeCopy.length > 128) {
                throw new IllegalArgumentException("Must have fewer than 128 dimensions");
            }
            int numElements = 1;
            for (int i = 0; i < shapeCopy.length; i++) {
                numElements *= shapeCopy[i];
                if (shapeCopy[i] < 1) {
                    throw new IllegalArgumentException("Invalid shape, all elements must be positive, found " + Arrays.toString(shapeCopy));
                }
            }
            if (numElements != dataCopy.length) {
                throw new IllegalArgumentException("Incorrect number of elements, expected " + numElements + ", found " + dataCopy.length);
            }

            if (dataType != IDXType.DOUBLE) {
                for (int i = 0; i < dataCopy.length; i++) {
                    switch (dataType) {
                        case UBYTE:
                            int tmpU = 0xFF & (int) dataCopy[i];
                            if (dataCopy[i] != tmpU) {
                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to unsigned byte");
                            }
                            break;
                        case BYTE:
                            byte tmpB = (byte) dataCopy[i];
                            if (dataCopy[i] != tmpB) {
                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to byte");
                            }
                            break;
                        case SHORT:
                            short tmpS = (short) dataCopy[i];
                            if (dataCopy[i] != tmpS) {
                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to short");
                            }
                            break;
                        case INT:
                            int tmpI = (int) dataCopy[i];
                            if (dataCopy[i] != tmpI) {
                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to int");
                            }
                            break;
                        case FLOAT:
                            float tmpF = (float) dataCopy[i];
                            if (dataCopy[i] != tmpF) {
                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to float");
                            }
                            break;
                    }
                }
            }

            // Return the validated defensive copies rather than the caller's arrays.
            return new IDXData(dataType, shapeCopy, dataCopy);
        }

        /**
         * Writes out this IDXData to the specified path.
         *
         * @param outputPath The path to write to.
         * @param gzip If true, gzip the output.
         * @throws IOException If the write failed.
         */
        public void save(Path outputPath, boolean gzip) throws IOException {
            try (DataOutputStream ds = makeStream(outputPath, gzip)) {
                // Magic number
                ds.writeShort(0);
                // Data type
                ds.writeByte(dataType.value);
                // Num dimensions
                ds.writeByte(shape.length);

                for (int i = 0; i < shape.length; i++) {
                    ds.writeInt(shape[i]);
                }

                for (int i = 0; i < data.length; i++) {
                    switch (dataType) {
                        case UBYTE:
                            ds.writeByte(0xFF & (int) data[i]);
                            break;
                        case BYTE:
                            ds.writeByte((byte) data[i]);
                            break;
                        case SHORT:
                            ds.writeShort((short) data[i]);
                            break;
                        case INT:
                            ds.writeInt((int) data[i]);
                            break;
                        case FLOAT:
                            ds.writeFloat((float) data[i]);
                            break;
                        case DOUBLE:
                            ds.writeDouble(data[i]);
                            break;
                    }
                }
            }
        }

        private static DataOutputStream makeStream(Path outputPath, boolean gzip) throws IOException {
            OutputStream stream;
            if (gzip) {
                stream = new GZIPOutputStream(new FileOutputStream(outputPath.toFile()));
            } else {
                stream = new FileOutputStream(outputPath.toFile());
            }
            return new DataOutputStream(new BufferedOutputStream(stream));
        }
    }

    /**
     * Provenance class for {@link IDXDataSource}.
     */
    public static final class IDXDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements DataSourceProvenance {
        private static final long serialVersionUID = 1L;

        public static final String OUTPUT_FILE_MODIFIED_TIME = "output-file-modified-time";
        public static final String FEATURES_FILE_MODIFIED_TIME = "features-file-modified-time";
        public static final String FEATURES_RESOURCE_HASH = "features-resource-hash";
        public static final String OUTPUT_RESOURCE_HASH = "output-resource-hash";
        public static final String FEATURE_TYPE = "idx-feature-type";

        private final DateTimeProvenance featuresFileModifiedTime;
        private final DateTimeProvenance outputFileModifiedTime;
        private final DateTimeProvenance dataSourceCreationTime;
        private final HashProvenance featuresSHA256Hash;
        private final HashProvenance outputSHA256Hash;
        private final EnumProvenance<IDXType> featureType;

        <T extends Output<T>> IDXDataSourceProvenance(IDXDataSource<T> host) {
            super(host, "DataSource");
            this.outputFileModifiedTime = new DateTimeProvenance(OUTPUT_FILE_MODIFIED_TIME, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.outputPath.toFile().lastModified()), ZoneId.systemDefault()));
            this.featuresFileModifiedTime = new DateTimeProvenance(FEATURES_FILE_MODIFIED_TIME, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.featuresPath.toFile().lastModified()), ZoneId.systemDefault()));
            this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME, OffsetDateTime.now());
            this.featuresSHA256Hash = new HashProvenance(DEFAULT_HASH_TYPE, FEATURES_RESOURCE_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, host.featuresPath));
            this.outputSHA256Hash = new HashProvenance(DEFAULT_HASH_TYPE, OUTPUT_RESOURCE_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, host.outputPath));
            this.featureType = new EnumProvenance<>(FEATURE_TYPE, host.dataType);
        }

        /**
         * Constructs a provenance from its marshalled form, used during deserialization.
         *
         * @param map The map of provenance values.
         */
        public IDXDataSourceProvenance(Map<String, Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        // Suppressed due to enum provenance cast
        @SuppressWarnings("unchecked")
        private IDXDataSourceProvenance(ExtractedInfo info) {
            super(info);
            this.featuresFileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FEATURES_FILE_MODIFIED_TIME);
            this.outputFileModifiedTime = (DateTimeProvenance) info.instanceValues.get(OUTPUT_FILE_MODIFIED_TIME);
            this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME);
            this.featuresSHA256Hash = (HashProvenance) info.instanceValues.get(FEATURES_RESOURCE_HASH);
            this.outputSHA256Hash = (HashProvenance) info.instanceValues.get(OUTPUT_RESOURCE_HASH);
            this.featureType = (EnumProvenance<IDXType>) info.instanceValues.get(FEATURE_TYPE);
        }

        protected static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            Map<String, Provenance> configuredParameters = new HashMap<>(map);
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, CLASS_NAME, StringProvenance.class, IDXDataSourceProvenance.class.getSimpleName()).getValue();
            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, IDXDataSourceProvenance.class.getSimpleName()).getValue();

            Map<String, PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
            instanceParameters.put(FEATURES_FILE_MODIFIED_TIME, ObjectProvenance.checkAndExtractProvenance(configuredParameters, FEATURES_FILE_MODIFIED_TIME, DateTimeProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
            instanceParameters.put(OUTPUT_FILE_MODIFIED_TIME, ObjectProvenance.checkAndExtractProvenance(configuredParameters, OUTPUT_FILE_MODIFIED_TIME, DateTimeProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
            instanceParameters.put(DATASOURCE_CREATION_TIME, ObjectProvenance.checkAndExtractProvenance(configuredParameters, DATASOURCE_CREATION_TIME, DateTimeProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
            instanceParameters.put(FEATURES_RESOURCE_HASH, ObjectProvenance.checkAndExtractProvenance(configuredParameters, FEATURES_RESOURCE_HASH, HashProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
            instanceParameters.put(OUTPUT_RESOURCE_HASH, ObjectProvenance.checkAndExtractProvenance(configuredParameters, OUTPUT_RESOURCE_HASH, HashProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
            instanceParameters.put(FEATURE_TYPE, ObjectProvenance.checkAndExtractProvenance(configuredParameters, FEATURE_TYPE, EnumProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));

            return new ExtractedInfo(className, hostTypeStringName, configuredParameters, instanceParameters);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();

            map.put(featuresFileModifiedTime.getKey(), featuresFileModifiedTime);
            map.put(outputFileModifiedTime.getKey(), outputFileModifiedTime);
            map.put(dataSourceCreationTime.getKey(), dataSourceCreationTime);
            map.put(featuresSHA256Hash.getKey(), featuresSHA256Hash);
            map.put(outputSHA256Hash.getKey(), outputSHA256Hash);
            map.put(featureType.getKey(), featureType);

            return map;
        }
    }
}