/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.datasource;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DataSourceProvenance;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
 * A DataSource which can read LibSVM formatted data.
 * <p>
 * It also provides a static save method which writes LibSVM format data.
 * <p>
 * This class can read LibSVM files which are zero-indexed or one-indexed, and the
 * parsed result is available after construction. When loading testing data it's
 * best to use the maxFeatureID from the training data (or the number of features
 * in the model) to ensure that the feature names are formatted with the appropriate
 * number of leading zeros.
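 * <p>
 * For example, a classification dataset might be loaded like this (a minimal sketch,
 * assuming a file called {@code train.libsvm} exists and {@code LabelFactory} from the
 * classification package is on the classpath):
 * <pre>{@code
 *     // Construct the source, inferring the indexing and maximum feature id from the data
 *     LibSVMDataSource<Label> source = new LibSVMDataSource<>(Paths.get("train.libsvm"), new LabelFactory());
 *     // Wrap it in a dataset to build the feature and output domains
 *     MutableDataset<Label> train = new MutableDataset<>(source);
 * }</pre>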
 */
public final class LibSVMDataSource<T extends Output<T>> implements ConfigurableDataSource<T> {
    private static final Logger logger = Logger.getLogger(LibSVMDataSource.class.getName());

    private static final Pattern splitPattern = Pattern.compile("\\s+");

    // url is the store of record.
    @Config(description="URL to load the data from. Either this or path must be set.")
    private URL url;

    @Config(description="Path to load the data from. Either this or url must be set.")
    private Path path;

    @Config(mandatory = true, description="The output factory to use.")
    private OutputFactory<T> outputFactory;

    @Config(description="Set to true if the features are zero indexed.")
    private boolean zeroIndexed;

    @Config(description="Sets the maximum feature id to load from the file.")
    private int maxFeatureID = Integer.MIN_VALUE;

    private boolean rangeSet;
    private int minFeatureID = Integer.MAX_VALUE;

    private final ArrayList<Example<T>> data = new ArrayList<>();

    private LibSVMDataSourceProvenance provenance;

    /**
     * For olcut.
     */
    private LibSVMDataSource() {}

    /**
     * Constructs a LibSVMDataSource from the supplied path and output factory.
     * @param path The path to load.
     * @param outputFactory The output factory to use.
     * @throws IOException If the file could not be read or is an invalid format.
     */
    public LibSVMDataSource(Path path, OutputFactory<T> outputFactory) throws IOException {
        this(path,path.toUri().toURL(),outputFactory,false,false,0);
    }

    /**
     * Constructs a LibSVMDataSource from the supplied path and output factory.
     * <p>
     * Also allows control over the maximum feature id and if the file is zero indexed.
     * The maximum feature id is used as part of the padding calculation converting the
     * integer feature numbers into Tribuo's String feature names and is important
     * to set when loading test data to ensure that the names line up with the training
     * names. For example if there are 110 features, but the test dataset only has features
     * 0-90, then without setting {@code maxFeatureID = 110} all the features will be named
     * "00" through "90", rather than the expected "000" - "090", leading to a mismatch.
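     * <p>
     * A common pattern is to size the test source from the training source (a sketch with
     * hypothetical file names, assuming a classification task):
     * <pre>{@code
     *     LibSVMDataSource<Label> train = new LibSVMDataSource<>(Paths.get("train.libsvm"), new LabelFactory());
     *     LibSVMDataSource<Label> test = new LibSVMDataSource<>(Paths.get("test.libsvm"), new LabelFactory(),
     *                                                           train.isZeroIndexed(), train.getMaxFeatureID());
     * }</pre>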
     * @param path The path to load.
     * @param outputFactory The output factory to use.
     * @param zeroIndexed Are the features in this file indexed from zero?
     * @param maxFeatureID The maximum feature ID allowed.
     * @throws IOException If the file could not be read or is an invalid format.
     */
    public LibSVMDataSource(Path path, OutputFactory<T> outputFactory, boolean zeroIndexed, int maxFeatureID) throws IOException {
        this(path,path.toUri().toURL(),outputFactory,true,zeroIndexed,maxFeatureID);
    }

    /**
     * Constructs a LibSVMDataSource from the supplied URL and output factory.
     * @param url The url to load.
     * @param outputFactory The output factory to use.
     * @throws IOException If the url could not load or is in an invalid format.
     */
    public LibSVMDataSource(URL url, OutputFactory<T> outputFactory) throws IOException {
        this(null,url,outputFactory,false,false,0);
    }

    /**
     * Constructs a LibSVMDataSource from the supplied URL and output factory.
     * <p>
     * Also allows control over the maximum feature id and if the file is zero indexed.
     * The maximum feature id is used as part of the padding calculation converting the
     * integer feature numbers into Tribuo's String feature names and is important
     * to set when loading test data to ensure that the names line up with the training
     * names. For example if there are 110 features, but the test dataset only has features
     * 0-90, then without setting {@code maxFeatureID = 110} all the features will be named
     * "00" through "90", rather than the expected "000" - "090", leading to a mismatch.
     * @param url The url to load.
     * @param outputFactory The output factory to use.
     * @param zeroIndexed Are the features in this file indexed from zero?
     * @param maxFeatureID The maximum feature ID allowed.
     * @throws IOException If the url could not load or is in an invalid format.
     */
    public LibSVMDataSource(URL url, OutputFactory<T> outputFactory, boolean zeroIndexed, int maxFeatureID) throws IOException {
        this(null,url,outputFactory,true,zeroIndexed,maxFeatureID);
    }

    /**
     * Constructs a LibSVMDataSource from the supplied url or path and output factory.
     * <p>
     * One of the url or path must be null.
     * <p>
     * Also allows control over the maximum feature id and if the file is zero indexed.
     * The maximum feature id is used as part of the padding calculation converting the
     * integer feature numbers into Tribuo's String feature names and is important
     * to set when loading test data to ensure that the names line up with the training
     * names. For example if there are 110 features, but the test dataset only has features
     * 0-90, then without setting {@code maxFeatureID = 110} all the features will be named
     * "00" through "90", rather than the expected "000" - "090", leading to a mismatch.
     * @param path The path to load.
     * @param url The url to load.
     * @param outputFactory The output factory to use.
     * @param rangeSet Has the maximum feature ID been set?
     * @param zeroIndexed Are the features in this file indexed from zero?
     * @param maxFeatureID The maximum feature ID allowed.
     * @throws IOException If the url could not load or is in an invalid format.
     */
    private LibSVMDataSource(Path path, URL url, OutputFactory<T> outputFactory, boolean rangeSet, boolean zeroIndexed, int maxFeatureID) throws IOException {
        if (url == null && path == null) {
            throw new IllegalArgumentException("Must supply a non-null path or url.");
        }
        this.path = path;
        this.url = url;
        if (outputFactory == null) {
            throw new IllegalArgumentException("outputFactory must not be null");
        }
        this.outputFactory = outputFactory;
        this.rangeSet = rangeSet;
        if (rangeSet) {
            this.zeroIndexed = zeroIndexed;
            this.minFeatureID = zeroIndexed ? 0 : 1;
            if (maxFeatureID < minFeatureID + 1) {
                throw new IllegalArgumentException("maxFeatureID must be positive, found " + maxFeatureID);
            }
            this.maxFeatureID = maxFeatureID;
        }
        read();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     * @throws IOException If the file could not be read or is an invalid format.
     */
    @Override
    public void postConfig() throws IOException {
        if (maxFeatureID != Integer.MIN_VALUE) {
            rangeSet = true;
            minFeatureID = zeroIndexed ? 0 : 1;
            if (maxFeatureID < minFeatureID + 1) {
                throw new IllegalArgumentException("maxFeatureID must be positive, found " + maxFeatureID);
            }
        }
        if ((url == null) && (path == null)) {
            throw new PropertyException("","path","One of url and path must be set.");
        } else if ((url != null) && (path != null) && !path.toUri().toURL().equals(url)) {
            throw new PropertyException("","path","At most one of url and path must be set.");
        } else if (path != null) {
            // url is the store of record.
            try {
                url = path.toUri().toURL();
            } catch (MalformedURLException e) {
                throw new PropertyException(e,"","path","Path was not a valid URL");
            }
        }
        read();
    }

    /**
     * Returns true if this dataset is zero indexed, false otherwise (i.e., it starts from 1).
     * @return True if zero indexed.
     */
    public boolean isZeroIndexed() {
        return minFeatureID == 0;
    }

    /**
     * Gets the maximum feature ID found.
     * @return The maximum feature id.
     */
    public int getMaxFeatureID() {
        return maxFeatureID;
    }

    @Override
    public String toString() {
        if (path != null) {
            return "LibSVMDataSource(path=" + path.toString() + ",zeroIndexed="+zeroIndexed+",minFeatureID=" + minFeatureID + ",maxFeatureID=" + maxFeatureID + ")";
        } else {
            return "LibSVMDataSource(url=" + url.toString() + ",zeroIndexed="+zeroIndexed+",minFeatureID=" + minFeatureID + ",maxFeatureID=" + maxFeatureID + ")";
        }
    }

    @Override
    public OutputFactory<T> getOutputFactory() {
        return outputFactory;
    }

    @Override
    public synchronized DataSourceProvenance getProvenance() {
        if (provenance == null) {
            provenance = cacheProvenance();
        }
        return provenance;
    }

    private LibSVMDataSourceProvenance cacheProvenance() {
        return new LibSVMDataSourceProvenance(this);
    }

    private void read() throws IOException {
        int pos = 0;
        ArrayList<HashMap<Integer,Double>> processedData = new ArrayList<>();
        ArrayList<String> labels = new ArrayList<>();

        // Idiom copied from Files.readAllLines,
        // but this doesn't require keeping the whole file in RAM.
        String line;
        // Parse the libsvm file, throwing IOException on malformed lines.
        // Each line is "<output> <index>:<value> <index>:<value> ...".
        try (BufferedReader r = new BufferedReader(new InputStreamReader(url.openStream(),StandardCharsets.UTF_8))) {
            for (;;) {
                line = r.readLine();
                if (line == null) {
                    break;
                }
                pos++;
                String[] fields = splitPattern.split(line);
                try {
                    boolean valid = true;
                    HashMap<Integer, Double> features = new HashMap<>();
                    for (int i = 1; i < fields.length && valid; i++) {
                        int ind = fields[i].indexOf(':');
                        if (ind < 0) {
                            logger.warning(String.format("Weird line at %d", pos));
                            valid = false;
                            break;
                        }
                        String ids = fields[i].substring(0, ind);
                        int id = Integer.parseInt(ids);
                        if ((!rangeSet) && (maxFeatureID < id)) {
                            maxFeatureID = id;
                        }
                        if ((!rangeSet) && (minFeatureID > id)) {
                            minFeatureID = id;
                        }
                        double val = Double.parseDouble(fields[i].substring(ind + 1));
                        Double value = features.put(id, val);
                        if (value != null) {
                            logger.warning(String.format("Repeated features at line %d", pos));
                            valid = false;
                        }
                    }
                    if (valid) {
                        // Store the label
                        labels.add(fields[0]);
                        // Store the features
                        processedData.add(features);
                    } else {
                        throw new IOException("Invalid LibSVM format file");
                    }
                } catch (NumberFormatException ex) {
                    logger.warning(String.format("Weird line at %d", pos));
                    throw new IOException("Invalid LibSVM format file", ex);
                }
            }
        }

        // Calculate the string width
        int width = (""+maxFeatureID).length();
        String formatString = "%0"+width+"d";
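        // e.g. maxFeatureID = 110 gives width 3, so feature 7 becomes "007" and feature 90 becomes "090",
        // which is why test data should be loaded with the training data's maxFeatureID.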

        // Check to see if it's zero indexed or one indexed, if we didn't observe the zero feature
        // we assume it's one indexed.
        int maxID = maxFeatureID;
        if (minFeatureID != 0) {
            minFeatureID = 1;
            zeroIndexed = false;
        } else {
            maxID++;
            zeroIndexed = true;
        }

        String[] featureNames = new String[maxID];
        for (int i = 0; i < maxID; i++) {
            featureNames[i] = String.format(formatString,i);
        }

        // Generate examples from the processed data
        ArrayList<Feature> buffer = new ArrayList<>();
        for (int i = 0; i < processedData.size(); i++) {
            String labelStr = labels.get(i);
            HashMap<Integer,Double> features = processedData.get(i);
            try {
                T curLabel = outputFactory.generateOutput(labelStr);
                ArrayExample<T> example = new ArrayExample<>(curLabel);
                buffer.clear();
                for (Map.Entry<Integer, Double> e : features.entrySet()) {
                    // Range check to drop out of range feature indices from test data, if rangeSet was true
                    int id = e.getKey() - minFeatureID;
                    if (id < maxID) {
                        double value = e.getValue();
                        Feature f = new Feature(featureNames[id], value);
                        buffer.add(f);
                    }
                }
                example.addAll(buffer);
                data.add(example);
            } catch (NumberFormatException e) {
                // If the output isn't a valid number for regression tasks.
                // Features are checked in the input loop above.
                logger.warning(String.format("Failed to parse example %d",i));
                throw new IOException("Invalid LibSVM format file");
            }
        }
    }

    /**
     * The number of examples.
     * @return The number of examples.
     */
    public int size() {
        return data.size();
    }

    @Override
    public Iterator<Example<T>> iterator() {
        return data.iterator();
    }

    /**
     * Writes out a dataset in LibSVM format.
     * <p>
     * Can write either zero indexed or one indexed.
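     * <p>
     * For example, a classification dataset could be written using its output id mapping as the
     * label transformation (a sketch, assuming {@code Label} outputs and an already constructed
     * {@code PrintStream} called {@code stream}):
     * <pre>{@code
     *     ImmutableOutputInfo<Label> outputInfo = dataset.getOutputIDInfo();
     *     // Write one-indexed features, mapping each Label to its integer id
     *     LibSVMDataSource.writeLibSVMFormat(dataset, stream, false, outputInfo::getID);
     * }</pre>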
     *
     * @param dataset The dataset to write out.
     * @param out A stream to write it to.
     * @param zeroIndexed If true start the feature numbers from zero, otherwise start from one.
     * @param transformationFunc A function which transforms an {@link Output} into a number.
     * @param <T> The type of the Output.
     */
    public static <T extends Output<T>> void writeLibSVMFormat(Dataset<T> dataset, PrintStream out, boolean zeroIndexed, Function<T,Number> transformationFunc) {
        int modifier = zeroIndexed ? 0 : 1;
        ImmutableFeatureMap featureMap = dataset.getFeatureIDMap();
        for (Example<T> example : dataset) {
            out.print(transformationFunc.apply(example.getOutput()));
            out.print(' ');
            for (Feature feature : example) {
                out.print(featureMap.get(feature.getName()).getID() + modifier);
                out.print(':');
                out.print(feature.getValue());
                out.print(' ');
            }
            out.print('\n');
        }
    }

    /**
     * The provenance for a {@link LibSVMDataSource}.
     */
    public static final class LibSVMDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements DataSourceProvenance {
        private static final long serialVersionUID = 1L;

        private final DateTimeProvenance fileModifiedTime;
        private final DateTimeProvenance dataSourceCreationTime;
        private final HashProvenance sha256Hash;

        /**
         * Constructs a provenance from the host object's information.
         * @param host The host LibSVMDataSource.
         * @param <T> The output type.
         */
        <T extends Output<T>> LibSVMDataSourceProvenance(LibSVMDataSource<T> host) {
            super(host,"DataSource");
            Optional<OffsetDateTime> time = ProvenanceUtil.getModifiedTime(host.url);
            this.fileModifiedTime = time.map(offsetDateTime -> new DateTimeProvenance(FILE_MODIFIED_TIME, offsetDateTime)).orElseGet(() -> new DateTimeProvenance(FILE_MODIFIED_TIME, OffsetDateTime.MIN));
            this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now());
            this.sha256Hash = new HashProvenance(DEFAULT_HASH_TYPE,RESOURCE_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.url));
        }

        /**
         * Constructs a provenance during unmarshalling.
         * @param map The map of unmarshalled provenances.
         */
        public LibSVMDataSourceProvenance(Map<String,Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        private LibSVMDataSourceProvenance(ExtractedInfo info) {
            super(info);
            this.fileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FILE_MODIFIED_TIME);
            this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME);
            this.sha256Hash = (HashProvenance) info.instanceValues.get(RESOURCE_HASH);
        }

        /**
         * Splits the supplied provenance map into the configured parameters and the instance
         * values (class name, host short name, timestamps and resource hash).
         * @param map The map of unmarshalled provenances.
         * @return The extracted provenance information.
         */
        protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
            Map<String,Provenance> configuredParameters = new HashMap<>(map);
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()).getValue();
            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()).getValue();

            Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
            instanceParameters.put(FILE_MODIFIED_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,FILE_MODIFIED_TIME,DateTimeProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()));
            instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()));
            instanceParameters.put(RESOURCE_HASH,ObjectProvenance.checkAndExtractProvenance(configuredParameters,RESOURCE_HASH,HashProvenance.class, LibSVMDataSourceProvenance.class.getSimpleName()));

            return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof LibSVMDataSourceProvenance)) return false;
            if (!super.equals(o)) return false;
            LibSVMDataSourceProvenance pairs = (LibSVMDataSourceProvenance) o;
            return fileModifiedTime.equals(pairs.fileModifiedTime) &&
                    dataSourceCreationTime.equals(pairs.dataSourceCreationTime) &&
                    sha256Hash.equals(pairs.sha256Hash);
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), fileModifiedTime, dataSourceCreationTime, sha256Hash);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String,PrimitiveProvenance<?>> map = super.getInstanceValues();

            map.put(FILE_MODIFIED_TIME,fileModifiedTime);
            map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime);
            map.put(RESOURCE_HASH,sha256Hash);

            return map;
        }
    }

}