/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.data.text;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Queue;
import java.util.logging.Logger;

/**
 * A data source for a somewhat-common format for text classification datasets:
 * a top-level directory that contains a number of subdirectories. Each of these
 * subdirectories contains the data for an output whose name is the name of the
 * subdirectory.
 * <p>
 * In these subdirectories are a number of files. Each file represents a single
 * document that should be labeled with the name of the subdirectory.
 * <p>
 * This data source will produce appropriately labeled {@code Example<T>}s
 * from each of these files.
 *
 * @param <T> The type of the output that this data source produces.
 */
public class DirectoryFileSource<T extends Output<T>> implements ConfigurableDataSource<T> {

    private static final Logger logger = Logger.getLogger(DirectoryFileSource.class.getName());

    /**
     * The top-level directory containing the data set.
     */
    @Config(description="The top-level directory containing the data set.")
    private Path dataDir = Paths.get(".");

    private final Charset enc = StandardCharsets.UTF_8;

    /**
     * Document preprocessors that should be run on the documents that make up
     * this data set.
     */
    @Config(description="The preprocessors to apply to the input documents.")
    protected List<DocumentPreprocessor> preprocessors = new ArrayList<>();

    /**
     * The factory that converts a String into an {@link Output}.
     */
    @Config(mandatory=true,description="The output factory to use.")
    protected OutputFactory<T> outputFactory;

    /**
     * The extractor that we'll use to turn text into examples.
     */
    @Config(mandatory=true,description="The feature extractor that converts text into examples.")
    protected TextFeatureExtractor<T> extractor;

    /**
     * For OLCUT.
     */
    protected DirectoryFileSource() {}

    /**
     * Creates a data source that will use the given feature extractor and
     * document preprocessors on the data read from the files in the directories
     * representing classes.
     *
     * @param outputFactory The output factory used to generate the outputs.
     * @param extractor The text feature extractor that will run on the
     * documents.
     * @param preprocessors Pre-processors that we will run on the documents
     * before extracting their features.
     */
    public DirectoryFileSource(OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor, DocumentPreprocessor... preprocessors) {
        this.outputFactory = outputFactory;
        this.extractor = extractor;
        this.preprocessors.addAll(Arrays.asList(preprocessors));
    }

    /**
     * Creates a data source that will read files from the directories representing
     * classes under the supplied top-level directory, using the given feature
     * extractor and document preprocessors.
     *
     * @param newsDir The top-level directory containing the data set.
     * @param outputFactory The output factory used to generate the outputs.
     * @param extractor The text feature extractor that will run on the
     * documents.
     * @param preprocessors Pre-processors that we will run on the documents
     * before extracting their features.
     */
    public DirectoryFileSource(Path newsDir, OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor, DocumentPreprocessor... preprocessors) {
        this.dataDir = newsDir;
        this.outputFactory = outputFactory;
        this.extractor = extractor;
        this.preprocessors.addAll(Arrays.asList(preprocessors));
    }

    @Override
    public String toString() {
        return "DirectoryDataSource(directory="+dataDir.toString()+",extractor="+extractor.toString()+",preprocessors="+preprocessors.toString()+")";
    }

    @Override
    public OutputFactory<T> getOutputFactory() {
        return outputFactory;
    }

    @Override
    public Iterator<Example<T>> iterator() {
        return new DirectoryIterator();
    }

    private class DirectoryIterator implements Iterator<Example<T>> {

        /**
         * The top-level paths in the provided directory, which is to say the
         * directories that give the labels their names.
         */
        private final Queue<Path> labelDirs = new ArrayDeque<>();

        /**
         * The path for the current output, resolved against the top-level
         * directory.
         */
        private Path labelPath;

        /**
         * The current output to apply to docs.
         */
        private String label;

        /**
         * The paths for the files in a particular output directory.
         */
        private final Queue<Path> labelPaths = new ArrayDeque<>();

        private final StringBuilder db = new StringBuilder();

        public DirectoryIterator() {
            //
            // Get the top-level paths, AKA the labels.
            try (DirectoryStream<Path> stream = Files.newDirectoryStream(dataDir)) {
                for (Path entry : stream) {
                    labelDirs.offer(entry);
                }
            } catch (IOException ex) {
                throw new IllegalStateException("Can't open directory " + dataDir, ex);
            }
            logger.info(String.format("Got %d output directories in %s", labelDirs.size(), dataDir));
        }

        @Override
        public boolean hasNext() {
            if (labelPaths.isEmpty()) {
                return !labelDirs.isEmpty();
            }
            return true;
        }

        @Override
        public Example<T> next() {
            if (labelPaths.isEmpty()) {
                if (labelDirs.isEmpty()) {
                    throw new NoSuchElementException("No more files");
                } else {
                    labelPath = labelDirs.poll();
                    label = labelPath.getFileName().toString();
                    try (DirectoryStream<Path> stream = Files.newDirectoryStream(labelPath)) {
                        for (Path entry : stream) {
                            labelPaths.offer(entry);
                        }
                        logger.info(String.format("Got %d paths in %s", labelPaths.size(), labelPath));
                    } catch (IOException ex) {
                        throw new IllegalStateException("Can't open directory " + labelPath, ex);
                    }
                }
            }
            Path p = labelPaths.poll();
            db.delete(0, db.length());
            try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(p.toFile()), enc))) {
                String line;
                while ((line = br.readLine()) != null) {
                    line = line.trim();
                    if (line.isEmpty()) {
                        db.append('\n');
                    } else {
                        db.append(line);
                    }
                    db.append('\n');
                }
                String postproc = db.toString();
                for (DocumentPreprocessor preproc : preprocessors) {
                    postproc = preproc.processDoc(postproc);
                    if (postproc == null) {
                        break;
                    }
                }
                if (postproc != null) {
                    Example<T> ret = extractor.extract(outputFactory.generateOutput(label), postproc);
                    return ret;
                } else {
                    //
                    // The document got preprocessed away. See if there's another one
                    // and return it.
                    if (!hasNext()) {
                        throw new NoSuchElementException("No more files");
                    }
                    return next();
                }
            } catch (IOException ex) {
                throw new IllegalStateException("Error reading path " + p, ex);
            }
        }

    }

    @Override
    public ConfiguredDataSourceProvenance getProvenance() {
        return new DirectoryFileSourceProvenance(this);
    }

    /**
     * Provenance for {@link DirectoryFileSource}.
     */
    public static class DirectoryFileSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
        private static final long serialVersionUID = 1L;

        private final DateTimeProvenance fileModifiedTime;
        private final DateTimeProvenance dataSourceCreationTime;

        <T extends Output<T>> DirectoryFileSourceProvenance(DirectoryFileSource<T> host) {
            super(host,"DataSource");
            this.fileModifiedTime = new DateTimeProvenance(FILE_MODIFIED_TIME,OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.dataDir.toFile().lastModified()), ZoneId.systemDefault()));
            this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now());
        }

        public DirectoryFileSourceProvenance(Map<String,Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        private DirectoryFileSourceProvenance(ExtractedInfo info) {
            super(info);
            this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME);
            this.fileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FILE_MODIFIED_TIME);
        }

        protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
            Map<String,Provenance> configuredParameters = new HashMap<>(map);
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, DirectoryFileSourceProvenance.class.getSimpleName()).getValue();
            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, DirectoryFileSourceProvenance.class.getSimpleName()).getValue();

            Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
            instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, DirectoryFileSourceProvenance.class.getSimpleName()));

            return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String,PrimitiveProvenance<?>> map = new HashMap<>();

            map.put(FILE_MODIFIED_TIME,fileModifiedTime);
            map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime);

            return map;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof DirectoryFileSourceProvenance)) return false;
            if (!super.equals(o)) return false;
            DirectoryFileSourceProvenance pairs = (DirectoryFileSourceProvenance) o;
            return fileModifiedTime.equals(pairs.fileModifiedTime) &&
                    dataSourceCreationTime.equals(pairs.dataSourceCreationTime);
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), fileModifiedTime, dataSourceCreationTime);
        }
    }
}
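
// A minimal usage sketch (illustrative only, not part of this class): wiring the source into a
// classification dataset. It assumes a corpus laid out as one subdirectory per label, and the
// Tribuo classification LabelFactory plus text pipeline classes (TextFeatureExtractorImpl,
// BasicPipeline, BreakIteratorTokenizer) from org.tribuo.data.text.impl and
// org.tribuo.util.tokens.impl; the path "/path/to/corpus" is a placeholder.
//
//   OutputFactory<Label> outputFactory = new LabelFactory();
//   TextFeatureExtractor<Label> extractor =
//           new TextFeatureExtractorImpl<>(new BasicPipeline(new BreakIteratorTokenizer(Locale.US), 2));
//   DirectoryFileSource<Label> source =
//           new DirectoryFileSource<>(Paths.get("/path/to/corpus"), outputFactory, extractor);
//   MutableDataset<Label> dataset = new MutableDataset<>(source);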