001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.datasource;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.config.PropertyException;
021import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
023import com.oracle.labs.mlrg.olcut.provenance.Provenance;
024import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
025import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
026import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
027import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
028import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
029import org.tribuo.ConfigurableDataSource;
030import org.tribuo.Dataset;
031import org.tribuo.Example;
032import org.tribuo.Feature;
033import org.tribuo.ImmutableFeatureMap;
034import org.tribuo.Output;
035import org.tribuo.OutputFactory;
036import org.tribuo.impl.ArrayExample;
037import org.tribuo.provenance.DataSourceProvenance;
038
039import java.io.BufferedReader;
040import java.io.IOException;
041import java.io.InputStreamReader;
042import java.io.PrintStream;
043import java.net.MalformedURLException;
044import java.net.URL;
045import java.nio.charset.StandardCharsets;
046import java.nio.file.Path;
047import java.time.OffsetDateTime;
048import java.util.ArrayList;
049import java.util.HashMap;
050import java.util.Iterator;
051import java.util.Map;
052import java.util.Objects;
053import java.util.Optional;
054import java.util.function.Function;
055import java.util.logging.Logger;
056import java.util.regex.Pattern;
057
058/**
059 * A DataSource which can read LibSVM formatted data.
060 * <p>
061 * It also provides a static save method which writes LibSVM format data.
062 * <p>
063 * This class can read libsvm files which are zero-indexed or one-indexed, and the
064 * parsed result is available after construction. When loading testing data it's
065 * best to use the maxFeatureID from the training data (or the number of features
066 * in the model) to ensure that the feature names are formatted with the appropriate
067 * number of leading zeros.
068 */
069public final class LibSVMDataSource<T extends Output<T>> implements ConfigurableDataSource<T> {
070    private static final Logger logger = Logger.getLogger(LibSVMDataSource.class.getName());
071
072    private static final Pattern splitPattern = Pattern.compile("\\s+");
073
074    // url is the store of record.
075    @Config(description="URL to load the data from. Either this or path must be set.")
076    private URL url;
077
078    @Config(description="Path to load the data from. Either this or url must be set.")
079    private Path path;
080
081    @Config(mandatory = true, description="The output factory to use.")
082    private OutputFactory<T> outputFactory;
083
084    @Config(description="Set to true if the features are zero indexed.")
085    private boolean zeroIndexed;
086
087    @Config(description="Sets the maximum feature id to load from the file.")
088    private int maxFeatureID = Integer.MIN_VALUE;
089
090    private boolean rangeSet;
091    private int minFeatureID = Integer.MAX_VALUE;
092
093    private final ArrayList<Example<T>> data = new ArrayList<>();
094
095    private LibSVMDataSourceProvenance provenance;
096
097    /**
098     * For olcut.
099     */
100    private LibSVMDataSource() {}
101
102    /**
103     * Constructs a LibSVMDataSource from the supplied path and output factory.
104     * @param path The path to load.
105     * @param outputFactory The output factory to use.
106     * @throws IOException If the file could not be read or is an invalid format.
107     */
108    public LibSVMDataSource(Path path, OutputFactory<T> outputFactory) throws IOException {
109        this(path,path.toUri().toURL(),outputFactory,false,false,0);
110    }
111
112    /**
113     * Constructs a LibSVMDataSource from the supplied path and output factory.
114     * <p>
115     * Also allows control over the maximum feature id and if the file is zero indexed.
116     * The maximum feature id is used as part of the padding calculation converting the
117     * integer feature numbers into Tribuo's String feature names and is important
118     * to set when loading test data to ensure that the names line up with the training
119     * names. For example if there are 110 features, but the test dataset only has features
120     * 0-90, then without setting {@code maxFeatureID = 110} all the features will be named
121     * "00" through "90", rather than the expected "000" - "090", leading to a mismatch.
122     * @param path The path to load.
123     * @param outputFactory The output factory to use.
124     * @param zeroIndexed Are the features in this file indexed from zero?
125     * @param maxFeatureID The maximum feature ID allowed.
126     * @throws IOException If the file could not be read or is an invalid format.
127     */
128    public LibSVMDataSource(Path path, OutputFactory<T> outputFactory, boolean zeroIndexed, int maxFeatureID) throws IOException {
129        this(path,path.toUri().toURL(),outputFactory,true,zeroIndexed,maxFeatureID);
130    }
131
132    /**
133     * Constructs a LibSVMDataSource from the supplied URL and output factory.
134     * @param url The url to load.
135     * @param outputFactory The output factory to use.
136     * @throws IOException If the url could not load or is in an invalid format.
137     */
138    public LibSVMDataSource(URL url, OutputFactory<T> outputFactory) throws IOException {
139        this(null,url,outputFactory,false,false,0);
140    }
141
142    /**
143     * Constructs a LibSVMDataSource from the supplied URL and output factory.
144     * <p>
145     * Also allows control over the maximum feature id and if the file is zero indexed.
146     * The maximum feature id is used as part of the padding calculation converting the
147     * integer feature numbers into Tribuo's String feature names and is important
148     * to set when loading test data to ensure that the names line up with the training
149     * names. For example if there are 110 features, but the test dataset only has features
150     * 0-90, then without setting {@code maxFeatureID = 110} all the features will be named
151     * "00" through "90", rather than the expected "000" - "090", leading to a mismatch.
152     * @param url The url to load.
153     * @param outputFactory The output factory to use.
154     * @param zeroIndexed Are the features in this file indexed from zero?
155     * @param maxFeatureID The maximum feature ID allowed.
156     * @throws IOException If the url could not load or is in an invalid format.
157     */
158    public LibSVMDataSource(URL url, OutputFactory<T> outputFactory, boolean zeroIndexed, int maxFeatureID) throws IOException {
159        this(null,url,outputFactory,true,zeroIndexed,maxFeatureID);
160    }
161
162    /**
163     * Constructs a LibSVMDataSource from the supplied url or path and output factory.
164     * <p>
165     * One of the url or path must be null.
166     * <p>
167     * Also allows control over the maximum feature id and if the file is zero indexed.
168     * The maximum feature id is used as part of the padding calculation converting the
169     * integer feature numbers into Tribuo's String feature names and is important
170     * to set when loading test data to ensure that the names line up with the training
171     * names. For example if there are 110 features, but the test dataset only has features
172     * 0-90, then without setting {@code maxFeatureID = 110} all the features will be named
173     * "00" through "90", rather than the expected "000" - "090", leading to a mismatch.
174     * @param url The url to load.
175     * @param outputFactory The output factory to use.
176     * @param zeroIndexed Are the features in this file indexed from zero?
177     * @param maxFeatureID The maximum feature ID allowed.
178     * @throws IOException If the url could not load or is in an invalid format.
179     */
180    private LibSVMDataSource(Path path, URL url, OutputFactory<T> outputFactory, boolean rangeSet, boolean zeroIndexed, int maxFeatureID) throws IOException {
181        if (url == null && path == null) {
182            throw new IllegalArgumentException("Must supply a non-null path or url.");
183        }
184        this.path = path;
185        this.url = url;
186        if (outputFactory == null) {
187            throw new IllegalArgumentException("outputFactory must not be null");
188        }
189        this.outputFactory = outputFactory;
190        this.rangeSet = rangeSet;
191        if (rangeSet) {
192            this.zeroIndexed = zeroIndexed;
193            this.minFeatureID = zeroIndexed ? 0 : 1;
194            if (maxFeatureID < minFeatureID + 1) {
195                throw new IllegalArgumentException("maxFeatureID must be positive, found " + maxFeatureID);
196            }
197            this.maxFeatureID = maxFeatureID;
198        }
199        read();
200    }
201
202    /**
203     * Used by the OLCUT configuration system, and should not be called by external code.
204     */
205    @Override
206    public void postConfig() throws IOException {
207        if (maxFeatureID != Integer.MIN_VALUE) {
208            rangeSet = true;
209            minFeatureID = zeroIndexed ? 0 : 1;
210            if (maxFeatureID < minFeatureID + 1) {
211                throw new IllegalArgumentException("maxFeatureID must be positive, found " + maxFeatureID);
212            }
213        }
214        if ((url == null) && (path == null)) {
215            throw new PropertyException("","path","At most one of url and path must be set.");
216        } else if ((url != null) && (path != null) && !path.toUri().toURL().equals(url)) {
217            throw new PropertyException("","path","At most one of url and path must be set");
218        } else if (path != null) {
219            // url is the store of record.
220            try {
221                url = path.toUri().toURL();
222            } catch (MalformedURLException e) {
223                throw new PropertyException(e,"","path","Path was not a valid URL");
224            }
225        }
226        read();
227    }
228
229    /**
230     * Returns true if this dataset is zero indexed, false otherwise (i.e., it starts from 1).
231     * @return True if zero indexed.
232     */
233    public boolean isZeroIndexed() {
234        return minFeatureID == 0;
235    }
236
237    /**
238     * Gets the maximum feature ID found.
239     * @return The maximum feature id.
240     */
241    public int getMaxFeatureID() {
242        return maxFeatureID;
243    }
244
245    @Override
246    public String toString() {
247        if (path != null) {
248            return "LibSVMDataSource(path=" + path.toString() + ",zeroIndexed="+zeroIndexed+",minFeatureID=" + minFeatureID + ",maxFeatureID=" + maxFeatureID + ")";
249        } else {
250            return "LibSVMDataSource(url=" + url.toString() + ",zeroIndexed="+zeroIndexed+",minFeatureID=" + minFeatureID + ",maxFeatureID=" + maxFeatureID + ")";
251        }
252    }
253
254    @Override
255    public OutputFactory<T> getOutputFactory() {
256        return outputFactory;
257    }
258
259    @Override
260    public synchronized DataSourceProvenance getProvenance() {
261        if (provenance == null) {
262            provenance = cacheProvenance();
263        }
264        return provenance;
265    }
266
267    private LibSVMDataSourceProvenance cacheProvenance() {
268        return new LibSVMDataSourceProvenance(this);
269    }
270
271    private void read() throws IOException {
272        int pos = 0;
273        ArrayList<HashMap<Integer,Double>> processedData = new ArrayList<>();
274        ArrayList<String> labels = new ArrayList<>();
275
276        // Idiom copied from Files.readAllLines,
277        // but this doesn't require keeping the whole file in RAM.
278        String line;
279        // Parse the libsvm file, ignoring malformed lines.
280        try (BufferedReader r = new BufferedReader(new InputStreamReader(url.openStream(),StandardCharsets.UTF_8))) {
281            for (;;) {
282                line = r.readLine();
283                if (line == null) {
284                    break;
285                }
286                pos++;
287                String[] fields = splitPattern.split(line);
288                try {
289                    boolean valid = true;
290                    HashMap<Integer, Double> features = new HashMap<>();
291                    for (int i = 1; i < fields.length && valid; i++) {
292                        int ind = fields[i].indexOf(':');
293                        if (ind < 0) {
294                            logger.warning(String.format("Weird line at %d", pos));
295                            valid = false;
296                        }
297                        String ids = fields[i].substring(0, ind);
298                        int id = Integer.parseInt(ids);
299                        if ((!rangeSet) && (maxFeatureID < id)) {
300                            maxFeatureID = id;
301                        }
302                        if ((!rangeSet) && (minFeatureID > id)) {
303                            minFeatureID = id;
304                        }
305                        double val = Double.parseDouble(fields[i].substring(ind + 1));
306                        Double value = features.put(id, val);
307                        if (value != null) {
308                            logger.warning(String.format("Repeated features at line %d", pos));
309                            valid = false;
310                        }
311                    }
312                    if (valid) {
313                        // Store the label
314                        labels.add(fields[0]);
315                        // Store the features
316                        processedData.add(features);
317                    } else {
318                        throw new IOException("Invalid LibSVM format file");
319                    }
320                } catch (NumberFormatException ex) {
321                    logger.warning(String.format("Weird line at %d", pos));
322                    throw new IOException("Invalid LibSVM format file", ex);
323                }
324            }
325        }
326
327        // Calculate the string width
328        int width = (""+maxFeatureID).length();
329        String formatString = "%0"+width+"d";
330
331        // Check to see if it's zero indexed or one indexed, if we didn't observe the zero feature
332        // we assume it's one indexed.
333        int maxID = maxFeatureID;
334        if (minFeatureID != 0) {
335            minFeatureID = 1;
336            zeroIndexed = false;
337        } else {
338            maxID++;
339            zeroIndexed = true;
340        }
341
342        String[] featureNames = new String[maxID];
343        for (int i = 0; i < maxID; i++) {
344            featureNames[i] = String.format(formatString,i);
345        }
346
347        // Generate examples from the processed data
348        ArrayList<Feature> buffer = new ArrayList<>();
349        for (int i = 0; i < processedData.size(); i++) {
350            String labelStr = labels.get(i);
351            HashMap<Integer,Double> features = processedData.get(i);
352            try {
353                T curLabel = outputFactory.generateOutput(labelStr);
354                ArrayExample<T> example = new ArrayExample<>(curLabel);
355                buffer.clear();
356                for (Map.Entry<Integer, Double> e : features.entrySet()) {
357                    // Null check to remove out of range feature indices from test data, if rangeSet was true
358                    int id = e.getKey() - minFeatureID;
359                    if (id < maxID)  {
360                        double value = e.getValue();
361                        Feature f = new Feature(featureNames[id], value);
362                        buffer.add(f);
363                    }
364                }
365                example.addAll(buffer);
366                data.add(example);
367            } catch (NumberFormatException e) {
368                // If the output isn't a valid number for regression tasks.
369                // Features are checked in the input loop above.
370                logger.warning(String.format("Failed to parse example %d",i));
371                throw new IOException("Invalid LibSVM format file");
372            }
373        }
374    }
375
376    /**
377     * The number of examples.
378     * @return The number of examples.
379     */
380    public int size() {
381        return data.size();
382    }
383
384    @Override
385    public Iterator<Example<T>> iterator() {
386        return data.iterator();
387    }
388
389    /**
390     * Writes out a dataset in LibSVM format.
391     * <p>
392     * Can write either zero indexed or one indexed.
393     *
394     * @param dataset The dataset to write out.
395     * @param out A stream to write it to.
396     * @param zeroIndexed If true start the feature numbers from zero, otherwise start from one.
397     * @param transformationFunc A function which transforms an {@link Output} into a number.
398     * @param <T> The type of the Output.
399     */
400    public static <T extends Output<T>> void writeLibSVMFormat(Dataset<T> dataset, PrintStream out, boolean zeroIndexed, Function<T,Number> transformationFunc) {
401        int modifier = zeroIndexed ? 0 : 1;
402        ImmutableFeatureMap featureMap = dataset.getFeatureIDMap();
403        for (Example<T> example : dataset) {
404            out.print(transformationFunc.apply(example.getOutput()));
405            out.print(' ');
406            for (Feature feature : example) {
407                out.print(featureMap.get(feature.getName()).getID() + modifier);
408                out.print(':');
409                out.print(feature.getValue());
410                out.print(' ');
411            }
412            out.print('\n');
413        }
414    }
415
416    /**
417     * The provenance for a {@link LibSVMDataSource}.
418     */
419    public static final class LibSVMDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements DataSourceProvenance {
420        private static final long serialVersionUID = 1L;
421
422        private final DateTimeProvenance fileModifiedTime;
423        private final DateTimeProvenance dataSourceCreationTime;
424        private final HashProvenance sha256Hash;
425
426        /**
427         * Constructs a provenance from the host object's information.
428         * @param host The host LibSVMDataSource.
429         * @param <T> The output type.
430         */
431        <T extends Output<T>> LibSVMDataSourceProvenance(LibSVMDataSource<T> host) {
432            super(host,"DataSource");
433            Optional<OffsetDateTime> time = ProvenanceUtil.getModifiedTime(host.url);
434            this.fileModifiedTime = time.map(offsetDateTime -> new DateTimeProvenance(FILE_MODIFIED_TIME, offsetDateTime)).orElseGet(() -> new DateTimeProvenance(FILE_MODIFIED_TIME, OffsetDateTime.MIN));
435            this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now());
436            this.sha256Hash = new HashProvenance(DEFAULT_HASH_TYPE,RESOURCE_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.url));
437        }
438
439        /**
440         * Constructs a provenance during unmarshalling.
441         * @param map The map of unmarshalled provenances.
442         */
443        public LibSVMDataSourceProvenance(Map<String,Provenance> map) {
444            this(extractProvenanceInfo(map));
445        }
446
447        private LibSVMDataSourceProvenance(ExtractedInfo info) {
448            super(info);
449            this.fileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FILE_MODIFIED_TIME);
450            this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME);
451            this.sha256Hash = (HashProvenance) info.instanceValues.get(RESOURCE_HASH);
452        }
453
454        protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
455            Map<String,Provenance> configuredParameters = new HashMap<>(map);
456            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()).getValue();
457            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()).getValue();
458
459            Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
460            instanceParameters.put(FILE_MODIFIED_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,FILE_MODIFIED_TIME,DateTimeProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()));
461            instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()));
462            instanceParameters.put(RESOURCE_HASH,ObjectProvenance.checkAndExtractProvenance(configuredParameters,RESOURCE_HASH,HashProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()));
463
464            return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters);
465        }
466
467        @Override
468        public boolean equals(Object o) {
469            if (this == o) return true;
470            if (!(o instanceof LibSVMDataSourceProvenance)) return false;
471            if (!super.equals(o)) return false;
472            LibSVMDataSourceProvenance pairs = (LibSVMDataSourceProvenance) o;
473            return fileModifiedTime.equals(pairs.fileModifiedTime) &&
474                    dataSourceCreationTime.equals(pairs.dataSourceCreationTime) &&
475                    sha256Hash.equals(pairs.sha256Hash);
476        }
477
478        @Override
479        public int hashCode() {
480            return Objects.hash(super.hashCode(), fileModifiedTime, dataSourceCreationTime, sha256Hash);
481        }
482
483        @Override
484        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
485            Map<String,PrimitiveProvenance<?>> map = super.getInstanceValues();
486
487            map.put(FILE_MODIFIED_TIME,fileModifiedTime);
488            map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime);
489            map.put(RESOURCE_HASH,sha256Hash);
490
491            return map;
492        }
493    }
494
495}