001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.data.text.impl; 018 019import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil; 023import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance; 024import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance; 025import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance; 026import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; 027import org.tribuo.Example; 028import org.tribuo.Output; 029import org.tribuo.OutputFactory; 030import org.tribuo.data.text.TextDataSource; 031import org.tribuo.data.text.TextFeatureExtractor; 032import org.tribuo.provenance.ConfiguredDataSourceProvenance; 033 034import java.io.File; 035import java.io.IOException; 036import java.nio.charset.StandardCharsets; 037import java.nio.file.Files; 038import java.nio.file.Path; 039import java.time.Instant; 040import java.time.OffsetDateTime; 041import java.time.ZoneId; 042import java.util.HashMap; 043import java.util.Map; 044import java.util.Objects; 045import java.util.Optional; 046import java.util.logging.Logger; 047import java.util.regex.Pattern; 048 049/** 050 * A dataset for a simple data format for text classification experiments. A line 051 * in the file looks like: 052 * 053 * <pre> 054 * OUTPUT##Document text 055 * </pre> 056 * 057 * Each line in the file specifies a single output and document pair. Leading and 058 * trailing spaces will be trimmed from outputs and documents. Outputs will be 059 * converted to upper case. 060 * 061 * <p> 062 * 063 * As with all of our text data, the file should be in UTF-8. 064 */ 065public class SimpleTextDataSource<T extends Output<T>> extends TextDataSource<T> { 066 067 private static final Logger logger = Logger.getLogger(SimpleTextDataSource.class.getName()); 068 069 private static final Pattern splitPattern = Pattern.compile("##"); 070 071 protected ConfiguredDataSourceProvenance provenance; 072 073 /** 074 * for olcut 075 */ 076 protected SimpleTextDataSource() {} 077 078 public SimpleTextDataSource(Path path, OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor) throws IOException { 079 super(path, outputFactory, extractor); 080 postConfig(); 081 } 082 083 public SimpleTextDataSource(File file, OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor) throws IOException { 084 super(file, outputFactory, extractor); 085 postConfig(); 086 } 087 088 protected SimpleTextDataSource(OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor) { 089 super((Path)null,outputFactory,extractor); 090 } 091 092 /** 093 * Used by the OLCUT configuration system, and should not be called by external code. 094 */ 095 @Override 096 public void postConfig() throws IOException { 097 read(); 098 provenance = cacheProvenance(); 099 } 100 101 protected Optional<Example<T>> parseLine(String line, int n) { 102 line = line.trim(); 103 if(line.isEmpty()) { 104 return Optional.empty(); 105 } 106 String[] fields = splitPattern.split(line); 107 if(fields.length != 2) { 108 logger.warning(String.format("Bad line in %s at %d: %s", 109 path, n, line.substring(Math.min(50, line.length())))); 110 return Optional.empty(); 111 } 112 T label = outputFactory.generateOutput(fields[0].trim().toUpperCase()); 113 return Optional.of(extractor.extract(label, handleDoc(fields[1].trim()))); 114 } 115 116 @Override 117 protected void read() throws IOException { 118 int n = 0; 119 for (String line : Files.readAllLines(path, StandardCharsets.UTF_8)) { 120 n++; 121 Optional<Example<T>> example = parseLine(line, n); 122 if (example.isPresent()) { 123 Example<T> e = example.get(); 124 if (e.validateExample()) { 125 data.add(e); 126 } else { 127 logger.warning("Invalid example found after parsing line " + n); 128 } 129 } 130 } 131 } 132 133 @Override 134 public ConfiguredDataSourceProvenance getProvenance() { 135 return provenance; 136 } 137 138 protected ConfiguredDataSourceProvenance cacheProvenance() { 139 return new SimpleTextDataSourceProvenance(this); 140 } 141 142 /** 143 * Provenance for {@link SimpleTextDataSource}. 144 */ 145 public static class SimpleTextDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance { 146 private static final long serialVersionUID = 1L; 147 148 private final DateTimeProvenance fileModifiedTime; 149 private final DateTimeProvenance dataSourceCreationTime; 150 private final HashProvenance sha256Hash; 151 152 <T extends Output<T>> SimpleTextDataSourceProvenance(SimpleTextDataSource<T> host) { 153 super(host,"DataSource"); 154 this.fileModifiedTime = new DateTimeProvenance(FILE_MODIFIED_TIME,OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.path.toFile().lastModified()), ZoneId.systemDefault())); 155 this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now()); 156 this.sha256Hash = new HashProvenance(DEFAULT_HASH_TYPE,RESOURCE_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.path)); 157 } 158 159 public SimpleTextDataSourceProvenance(Map<String,Provenance> map) { 160 this(extractProvenanceInfo(map)); 161 } 162 163 private SimpleTextDataSourceProvenance(ExtractedInfo info) { 164 super(info); 165 this.fileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FILE_MODIFIED_TIME); 166 this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME); 167 this.sha256Hash = (HashProvenance) info.instanceValues.get(RESOURCE_HASH); 168 } 169 170 protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) { 171 Map<String,Provenance> configuredParameters = new HashMap<>(map); 172 String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName()).getValue(); 173 String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName()).getValue(); 174 175 Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>(); 176 instanceParameters.put(FILE_MODIFIED_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,FILE_MODIFIED_TIME,DateTimeProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName())); 177 instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName())); 178 instanceParameters.put(RESOURCE_HASH,ObjectProvenance.checkAndExtractProvenance(configuredParameters,RESOURCE_HASH,HashProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName())); 179 180 return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters); 181 } 182 183 @Override 184 public boolean equals(Object o) { 185 if (this == o) return true; 186 if (!(o instanceof SimpleTextDataSourceProvenance)) return false; 187 if (!super.equals(o)) return false; 188 SimpleTextDataSourceProvenance pairs = (SimpleTextDataSourceProvenance) o; 189 return fileModifiedTime.equals(pairs.fileModifiedTime) && 190 dataSourceCreationTime.equals(pairs.dataSourceCreationTime) && 191 sha256Hash.equals(pairs.sha256Hash); 192 } 193 194 @Override 195 public int hashCode() { 196 return Objects.hash(super.hashCode(), fileModifiedTime, dataSourceCreationTime, sha256Hash); 197 } 198 199 @Override 200 public Map<String, PrimitiveProvenance<?>> getInstanceValues() { 201 Map<String,PrimitiveProvenance<?>> map = new HashMap<>(); 202 203 map.put(FILE_MODIFIED_TIME,fileModifiedTime); 204 map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime); 205 map.put(RESOURCE_HASH,sha256Hash); 206 207 return map; 208 } 209 } 210 211}