001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.datasource;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.Provenance;
023import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
024import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
026import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
027import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
028import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
029import com.oracle.labs.mlrg.olcut.util.IOUtil;
030import org.tribuo.ConfigurableDataSource;
031import org.tribuo.Example;
032import org.tribuo.Feature;
033import org.tribuo.Output;
034import org.tribuo.OutputFactory;
035import org.tribuo.impl.ArrayExample;
036import org.tribuo.provenance.DataSourceProvenance;
037
038import java.io.BufferedOutputStream;
039import java.io.DataInputStream;
040import java.io.DataOutputStream;
041import java.io.EOFException;
042import java.io.FileNotFoundException;
043import java.io.FileOutputStream;
044import java.io.IOException;
045import java.io.InputStream;
046import java.io.OutputStream;
047import java.nio.file.Path;
048import java.time.Instant;
049import java.time.OffsetDateTime;
050import java.time.ZoneId;
051import java.util.ArrayList;
052import java.util.Arrays;
053import java.util.HashMap;
054import java.util.Iterator;
055import java.util.Map;
056import java.util.logging.Logger;
057import java.util.zip.GZIPOutputStream;
058
059/**
060 * A DataSource which can read IDX formatted data (i.e., MNIST).
061 * <p>
062 * Transparently reads GZipped files.
063 * <p>
064 * The file format is defined <a href="http://yann.lecun.com/exdb/mnist/">here</a>.
065 */
066public final class IDXDataSource<T extends Output<T>> implements ConfigurableDataSource<T> {
067    private static final Logger logger = Logger.getLogger(IDXDataSource.class.getName());
068
069    /**
070     * The possible IDX input formats.
071     */
072    public enum IDXType {
073        UBYTE((byte) 0x08),
074        BYTE((byte) 0x09),
075        SHORT((byte) 0x0B),
076        INT((byte) 0x0C),
077        FLOAT((byte) 0x0D),
078        DOUBLE((byte) 0x0E);
079
080        /**
081         * The encoded byte value.
082         */
083        public final byte value;
084
085        IDXType(byte value) {
086            this.value = value;
087        }
088
089        /**
090         * Converts the byte into the enum. Throws IllegalArgumentException if it's
091         * not a valid byte.
092         *
093         * @param input The byte to convert.
094         * @return The corresponding enum instance.
095         */
096        public static IDXType convert(byte input) {
097            for (IDXType f : values()) {
098                if (f.value == input) {
099                    return f;
100                }
101            }
102            throw new IllegalArgumentException("Invalid byte found - " + input);
103        }
104    }
105
106    @Config(mandatory = true, description = "Path to load the features from.")
107    private Path featuresPath;
108
109    @Config(mandatory = true, description = "Path to load the features from.")
110    private Path outputPath;
111
112    @Config(mandatory = true, description = "The output factory to use.")
113    private OutputFactory<T> outputFactory;
114
115    private final ArrayList<Example<T>> data = new ArrayList<>();
116
117    private IDXType dataType;
118
119    private IDXDataSourceProvenance provenance;
120
    /**
     * For olcut.
     * <p>
     * Leaves the configured fields unset and does not load any data; loading happens in
     * {@link #postConfig()}. Presumably the configuration system populates the
     * {@code @Config} fields before calling that method.
     */
    private IDXDataSource() {}
125
126    /**
127     * Constructs an IDXDataSource from the supplied paths.
128     *
129     * @param featuresPath  The path to the features file.
130     * @param outputPath    The path to the output file.
131     * @param outputFactory The output factory.
132     * @throws IOException If either file cannot be read.
133     */
134    public IDXDataSource(Path featuresPath, Path outputPath, OutputFactory<T> outputFactory) throws IOException {
135        this.outputFactory = outputFactory;
136        this.featuresPath = featuresPath;
137        this.outputPath = outputPath;
138        read();
139    }
140
    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     * <p>
     * Loads the examples from the configured feature and output paths by calling
     * {@code read()}. Each call appends to the stored examples, so calling it again on an
     * already loaded source duplicates the data.
     *
     * @throws IOException If either file could not be read.
     */
    @Override
    public void postConfig() throws IOException {
        read();
    }
148
149    @Override
150    public String toString() {
151        return "IDXDataSource(featuresPath=" + featuresPath.toString() + ",outputPath=" + outputPath.toString() + ",featureType=" + dataType + ")";
152    }
153
    /**
     * Returns the output factory this data source was configured with.
     *
     * @return The output factory.
     */
    @Override
    public OutputFactory<T> getOutputFactory() {
        return outputFactory;
    }
158
159    @Override
160    public synchronized DataSourceProvenance getProvenance() {
161        if (provenance == null) {
162            provenance = cacheProvenance();
163        }
164        return provenance;
165    }
166
    /**
     * Builds a fresh provenance object describing this data source.
     *
     * @return A new provenance instance.
     */
    private IDXDataSourceProvenance cacheProvenance() {
        return new IDXDataSourceProvenance(this);
    }
170
171    /**
172     * Loads the data.
173     *
174     * @throws IOException If the files could not be read.
175     */
176    private void read() throws IOException {
177        IDXData features = readData(featuresPath);
178        IDXData outputs = readData(outputPath);
179
180        dataType = features.dataType;
181
182        if (features.shape[0] != outputs.shape[0]) {
183            throw new IllegalStateException("Features and outputs have different numbers of examples, feature shape = " + Arrays.toString(features.shape) + ", output shape = " + Arrays.toString(outputs.shape));
184        }
185
186        // Calculate the example size
187        int numFeatures = 1;
188        for (int i = 1; i < features.shape.length; i++) {
189            numFeatures *= features.shape[i];
190        }
191        int numOutputs = 1;
192        for (int i = 1; i < outputs.shape.length; i++) {
193            numOutputs *= outputs.shape[i];
194        }
195
196        String[] featureNames = new String[numFeatures];
197        int width = ("" + numFeatures).length();
198        String formatString = "%0" + width + "d";
199        for (int i = 0; i < numFeatures; i++) {
200            featureNames[i] = String.format(formatString, i);
201        }
202
203        ArrayList<Feature> buffer = new ArrayList<>();
204        int featureCounter = 0;
205        int outputCounter = 0;
206        StringBuilder outputBuilder = new StringBuilder();
207        for (int i = 0; i < features.data.length; i++) {
208            double curValue = features.data[i];
209            if (curValue != 0.0) {
210                // Tribuo is sparse, so only create non-zero features
211                buffer.add(new Feature(featureNames[featureCounter], curValue));
212            }
213            featureCounter++;
214            if (featureCounter == numFeatures) {
215                // fabricate output. Multidimensional outputs expect a comma separated string.
216                outputBuilder.setLength(0);
217                for (int j = 0; j < numOutputs; j++) {
218                    if (j != 0) {
219                        outputBuilder.append(',');
220                    }
221                    // If necessary cast to int to ensure we get a integer out for use as a class label
222                    // No-one wants to have MNIST digits with labels "0.0", "1.0" etc.
223                    switch (outputs.dataType) {
224                        case BYTE:
225                        case UBYTE:
226                        case SHORT:
227                        case INT:
228                            outputBuilder.append((int) outputs.data[j + outputCounter]);
229                            break;
230                        case FLOAT:
231                        case DOUBLE:
232                            outputBuilder.append(outputs.data[j + outputCounter]);
233                            break;
234                    }
235                }
236                outputCounter += numOutputs;
237                T output = outputFactory.generateOutput(outputBuilder.toString());
238
239                // create example
240                Example<T> example = new ArrayExample<T>(output);
241                example.addAll(buffer);
242                data.add(example);
243
244                // Clean up
245                buffer.clear();
246                featureCounter = 0;
247            }
248        }
249
250        if (featureCounter != 0) {
251            throw new IllegalStateException("Failed to process all the features, missing " + (numFeatures - featureCounter) + " values");
252        }
253    }
254
255    /**
256     * Reads a single IDX format file.
257     *
258     * @param path The path to read.
259     * @return The IDXData from the file.
260     * @throws IOException If the file could not be read.
261     */
262    static IDXData readData(Path path) throws IOException {
263        InputStream inputStream = IOUtil.getInputStreamForLocation(path.toString());
264        if (inputStream == null) {
265            throw new FileNotFoundException("Failed to load from path - " + path);
266        }
267        // DataInputStream.close implicitly closes the InputStream
268        try (DataInputStream stream = new DataInputStream(inputStream)) {
269            short magicNumber = stream.readShort();
270            if (magicNumber != 0) {
271                throw new IllegalStateException("Invalid IDX file, magic number was not zero. Found " + magicNumber);
272            }
273            final byte dataTypeByte = stream.readByte();
274            final IDXType dataType = IDXType.convert(dataTypeByte);
275            final byte numDimensions = stream.readByte();
276            if (numDimensions < 1) {
277                throw new IllegalStateException("Invalid number of dimensions, found " + numDimensions);
278            }
279            final int[] shape = new int[numDimensions];
280            int size = 1;
281            for (int i = 0; i < numDimensions; i++) {
282                shape[i] = stream.readInt();
283                if (shape[i] < 1) {
284                    throw new IllegalStateException("Invalid shape, found " + Arrays.toString(shape));
285                }
286                size *= shape[i];
287            }
288            double[] data = new double[size];
289            try {
290                for (int i = 0; i < size; i++) {
291                    switch (dataType) {
292                        case BYTE:
293                            data[i] = stream.readByte();
294                            break;
295                        case UBYTE:
296                            data[i] = stream.readUnsignedByte();
297                            break;
298                        case SHORT:
299                            data[i] = stream.readShort();
300                            break;
301                        case INT:
302                            data[i] = stream.readInt();
303                            break;
304                        case FLOAT:
305                            data[i] = stream.readFloat();
306                            break;
307                        case DOUBLE:
308                            data[i] = stream.readDouble();
309                            break;
310                    }
311                }
312            } catch (EOFException e) {
313                throw new IllegalStateException("Too little data in the file, expected to find " + size + " elements");
314            }
315            try {
316                byte unexpectedByte = stream.readByte();
317                throw new IllegalStateException("Too much data in the file");
318            } catch (EOFException e) {
319                //pass as the stream is exhausted
320            }
321            return new IDXData(dataType, shape, data);
322        }
323    }
324
    /**
     * The number of examples loaded.
     * <p>
     * This is the current length of the internal example list.
     *
     * @return The number of examples.
     */
    public int size() {
        return data.size();
    }
333
    /**
     * The type of the features that were loaded in.
     * <p>
     * This is the element type declared by the features file; the element type of the
     * output file is not reported here.
     *
     * @return The feature type.
     */
    public IDXType getDataType() {
        return dataType;
    }
342
    /**
     * Iterates over the loaded examples in the order they were read.
     * <p>
     * The iterator comes directly from the internal list, it is not a copy.
     *
     * @return An iterator over the examples.
     */
    @Override
    public Iterator<Example<T>> iterator() {
        return data.iterator();
    }
347
348    /**
349     * Java side representation for an IDX file.
350     */
351    public static class IDXData {
352        final IDXType dataType;
353        final int[] shape;
354        final double[] data;
355
356        /**
357         * Constructor, does not validate or copy inputs.
358         * Use the factory method.
359         * @param dataType The data type.
360         * @param shape    The tensor shape.
361         * @param data     The data to write.
362         */
363        IDXData(IDXType dataType, int[] shape, double[] data) {
364            this.dataType = dataType;
365            this.shape = shape;
366            this.data = data;
367        }
368
369        /**
370         * Constructs an IDXData, validating the input and defensively copying it.
371         *
372         * @param dataType The data type.
373         * @param shape    The tensor shape.
374         * @param data     The data to write.
375         * @return An IDXData.
376         */
377        public static IDXData createIDXData(IDXType dataType, int[] shape, double[] data) {
378            int[] shapeCopy = Arrays.copyOf(shape, shape.length);
379            double[] dataCopy = Arrays.copyOf(data, data.length);
380            if (shape.length > 128) {
381                throw new IllegalArgumentException("Must have fewer than 128 dimensions");
382            }
383            int numElements = 1;
384            for (int i = 0; i < shapeCopy.length; i++) {
385                numElements *= shapeCopy[i];
386                if (shapeCopy[i] < 1) {
387                    throw new IllegalArgumentException("Invalid shape, all elements must be positive, found " + Arrays.toString(shapeCopy));
388                }
389            }
390            if (numElements != dataCopy.length) {
391                throw new IllegalArgumentException("Incorrect number of elements, expected " + numElements + ", found " + dataCopy.length);
392            }
393
394            if (dataType != IDXType.DOUBLE) {
395                for (int i = 0; i < dataCopy.length; i++) {
396                    switch (dataType) {
397                        case UBYTE:
398                            int tmpU = 0xFF & (int) dataCopy[i];
399                            if (dataCopy[i] != tmpU) {
400                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to unsigned byte");
401                            }
402                            break;
403                        case BYTE:
404                            byte tmpB = (byte) dataCopy[i];
405                            if (dataCopy[i] != tmpB) {
406                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to byte");
407                            }
408                            break;
409                        case SHORT:
410                            short tmpS = (short) dataCopy[i];
411                            if (dataCopy[i] != tmpS) {
412                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to short");
413                            }
414                            break;
415                        case INT:
416                            int tmpI = (int) dataCopy[i];
417                            if (dataCopy[i] != tmpI) {
418                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to int");
419                            }
420                            break;
421                        case FLOAT:
422                            float tmpF = (float) dataCopy[i];
423                            if (dataCopy[i] != tmpF) {
424                                throw new IllegalArgumentException("Invalid value at idx " + i + ", could not be converted to float");
425                            }
426                            break;
427                    }
428                }
429            }
430
431            return new IDXData(dataType, shape, data);
432        }
433
434        /**
435         * Writes out this IDXData to the specified path.
436         *
437         * @param outputPath The path to write to.
438         * @param gzip       If true, gzip the output.
439         * @throws IOException If the write failed.
440         */
441        public void save(Path outputPath, boolean gzip) throws IOException {
442            try (DataOutputStream ds = makeStream(outputPath, gzip)) {
443                // Magic number
444                ds.writeShort(0);
445                // Data type
446                ds.writeByte(dataType.value);
447                // Num dimensions
448                ds.writeByte(shape.length);
449
450                for (int i = 0; i < shape.length; i++) {
451                    ds.writeInt(shape[i]);
452                }
453
454                for (int i = 0; i < data.length; i++) {
455                    switch (dataType) {
456                        case UBYTE:
457                            ds.writeByte(0xFF & (int) data[i]);
458                            break;
459                        case BYTE:
460                            ds.writeByte((byte) data[i]);
461                            break;
462                        case SHORT:
463                            ds.writeShort((short) data[i]);
464                            break;
465                        case INT:
466                            ds.writeInt((int) data[i]);
467                            break;
468                        case FLOAT:
469                            ds.writeFloat((float) data[i]);
470                            break;
471                        case DOUBLE:
472                            ds.writeDouble(data[i]);
473                            break;
474                    }
475                }
476            }
477        }
478
479        private static DataOutputStream makeStream(Path outputPath, boolean gzip) throws IOException {
480            OutputStream stream;
481            if (gzip) {
482                stream = new GZIPOutputStream(new FileOutputStream(outputPath.toFile()));
483            } else {
484                stream = new FileOutputStream(outputPath.toFile());
485            }
486            return new DataOutputStream(new BufferedOutputStream(stream));
487        }
488    }
489
490    /**
491     * Provenance class for {@link IDXDataSource}.
492     */
493    public static final class IDXDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements DataSourceProvenance {
494        private static final long serialVersionUID = 1L;
495
496        public static final String OUTPUT_FILE_MODIFIED_TIME = "output-file-modified-time";
497        public static final String FEATURES_FILE_MODIFIED_TIME = "features-file-modified-time";
498        public static final String FEATURES_RESOURCE_HASH = "features-resource-hash";
499        public static final String OUTPUT_RESOURCE_HASH = "output-resource-hash";
500        public static final String FEATURE_TYPE = "idx-feature-type";
501
502        private final DateTimeProvenance featuresFileModifiedTime;
503        private final DateTimeProvenance outputFileModifiedTime;
504        private final DateTimeProvenance dataSourceCreationTime;
505        private final HashProvenance featuresSHA256Hash;
506        private final HashProvenance outputSHA256Hash;
507        private final EnumProvenance<IDXType> featureType;
508
509        <T extends Output<T>> IDXDataSourceProvenance(IDXDataSource<T> host) {
510            super(host, "DataSource");
511            this.outputFileModifiedTime = new DateTimeProvenance(OUTPUT_FILE_MODIFIED_TIME, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.outputPath.toFile().lastModified()), ZoneId.systemDefault()));
512            this.featuresFileModifiedTime = new DateTimeProvenance(FEATURES_FILE_MODIFIED_TIME, OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.featuresPath.toFile().lastModified()), ZoneId.systemDefault()));
513            this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME, OffsetDateTime.now());
514            this.featuresSHA256Hash = new HashProvenance(DEFAULT_HASH_TYPE, FEATURES_RESOURCE_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, host.featuresPath));
515            this.outputSHA256Hash = new HashProvenance(DEFAULT_HASH_TYPE, OUTPUT_RESOURCE_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, host.outputPath));
516            this.featureType = new EnumProvenance<>(FEATURE_TYPE, host.dataType);
517        }
518
519        public IDXDataSourceProvenance(Map<String, Provenance> map) {
520            this(extractProvenanceInfo(map));
521        }
522
523        // Suppressed due to enum provenance cast
524        @SuppressWarnings("unchecked")
525        private IDXDataSourceProvenance(ExtractedInfo info) {
526            super(info);
527            this.featuresFileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FEATURES_FILE_MODIFIED_TIME);
528            this.outputFileModifiedTime = (DateTimeProvenance) info.instanceValues.get(OUTPUT_FILE_MODIFIED_TIME);
529            this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME);
530            this.featuresSHA256Hash = (HashProvenance) info.instanceValues.get(FEATURES_RESOURCE_HASH);
531            this.outputSHA256Hash = (HashProvenance) info.instanceValues.get(OUTPUT_RESOURCE_HASH);
532            this.featureType = (EnumProvenance<IDXType>) info.instanceValues.get(FEATURE_TYPE);
533        }
534
535        protected static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
536            Map<String, Provenance> configuredParameters = new HashMap<>(map);
537            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, CLASS_NAME, StringProvenance.class, IDXDataSourceProvenance.class.getSimpleName()).getValue();
538            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, IDXDataSourceProvenance.class.getSimpleName()).getValue();
539
540            Map<String, PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
541            instanceParameters.put(FEATURES_FILE_MODIFIED_TIME, ObjectProvenance.checkAndExtractProvenance(configuredParameters, FEATURES_FILE_MODIFIED_TIME, DateTimeProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
542            instanceParameters.put(OUTPUT_FILE_MODIFIED_TIME, ObjectProvenance.checkAndExtractProvenance(configuredParameters, OUTPUT_FILE_MODIFIED_TIME, DateTimeProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
543            instanceParameters.put(DATASOURCE_CREATION_TIME, ObjectProvenance.checkAndExtractProvenance(configuredParameters, DATASOURCE_CREATION_TIME, DateTimeProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
544            instanceParameters.put(FEATURES_RESOURCE_HASH, ObjectProvenance.checkAndExtractProvenance(configuredParameters, FEATURES_RESOURCE_HASH, HashProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
545            instanceParameters.put(OUTPUT_RESOURCE_HASH, ObjectProvenance.checkAndExtractProvenance(configuredParameters, OUTPUT_RESOURCE_HASH, HashProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
546            instanceParameters.put(FEATURE_TYPE, ObjectProvenance.checkAndExtractProvenance(configuredParameters, FEATURE_TYPE, EnumProvenance.class, IDXDataSourceProvenance.class.getSimpleName()));
547
548            return new ExtractedInfo(className, hostTypeStringName, configuredParameters, instanceParameters);
549        }
550
551        @Override
552        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
553            Map<String, PrimitiveProvenance<?>> map = super.getInstanceValues();
554
555            map.put(featuresFileModifiedTime.getKey(), featuresFileModifiedTime);
556            map.put(outputFileModifiedTime.getKey(), outputFileModifiedTime);
557            map.put(dataSourceCreationTime.getKey(), dataSourceCreationTime);
558            map.put(featuresSHA256Hash.getKey(), featuresSHA256Hash);
559            map.put(outputSHA256Hash.getKey(), outputSHA256Hash);
560            map.put(featureType.getKey(), featureType);
561
562            return map;
563        }
564    }
565}