001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.text;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
022import com.oracle.labs.mlrg.olcut.provenance.Provenance;
023import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
024import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
025import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
026import org.tribuo.ConfigurableDataSource;
027import org.tribuo.Example;
028import org.tribuo.Output;
029import org.tribuo.OutputFactory;
030import org.tribuo.provenance.ConfiguredDataSourceProvenance;
031
032import java.io.BufferedReader;
033import java.io.FileInputStream;
034import java.io.IOException;
035import java.io.InputStreamReader;
036import java.nio.charset.Charset;
037import java.nio.charset.StandardCharsets;
038import java.nio.file.DirectoryStream;
039import java.nio.file.Files;
040import java.nio.file.Path;
041import java.nio.file.Paths;
042import java.time.Instant;
043import java.time.OffsetDateTime;
044import java.time.ZoneId;
045import java.util.ArrayDeque;
046import java.util.ArrayList;
047import java.util.Arrays;
048import java.util.HashMap;
049import java.util.Iterator;
050import java.util.List;
051import java.util.Map;
052import java.util.NoSuchElementException;
053import java.util.Objects;
054import java.util.Queue;
055import java.util.logging.Logger;
056
057/**
058 * A data source for a somewhat-common format for text classification datasets:
059 * a top level directory that contains a number of subdirectories. Each of these
060 * subdirectories contains the data for a output whose name is the name of the
061 * subdirectory.
062 * <p>
063 * In these subdirectories are a number of files. Each file represents a single
064 * document that should be labeled with the name of the subdirectory.
065 * <p>
066 * This data source will produce appropriately labeled {@code Examples<T>}
067 * from each of these files.
068 *
069 * @param <T> The type of the features built by the underlying text processing
070 * infrastructure.
071 */
072public class DirectoryFileSource<T extends Output<T>> implements ConfigurableDataSource<T> {
073
074    private static final Logger logger = Logger.getLogger(DirectoryFileSource.class.getName());
075
076    /**
077     * The top-level directory containing the data set.
078     */
079    @Config(description="The top-level directory containing the data set.")
080    private Path dataDir = Paths.get(".");
081
082    private final Charset enc = StandardCharsets.UTF_8;
083
084    /**
085     * Document preprocessors that should be run on the documents that make up
086     * this data set.
087     */
088    @Config(description="The preprocessors to apply to the input documents.")
089    protected List<DocumentPreprocessor> preprocessors = new ArrayList<>();
090
091    /**
092     * The factory that converts a String into an {@link Output}.
093     */
094    @Config(mandatory=true,description="The output factory to use.")
095    protected OutputFactory<T> outputFactory;
096
097    /**
098     * The extractor that we'll use to turn text into examples.
099     */
100    @Config(mandatory=true,description="The feature extractor that converts text into examples.")
101    protected TextFeatureExtractor<T> extractor;
102
103    /**
104     * for olcut
105     */
106    protected DirectoryFileSource() {}
107
108    /**
109     * Creates a data source that will use the given feature extractor and
110     * document preprocessors on the data read from the files in the directories
111     * representing classes.
112     *
113     * @param outputFactory The output factory used to generate the outputs.
114     * @param extractor The text feature extractor that will run on the
115     * documents.
116     * @param preprocessors Pre-processors that we will run on the documents
117     * before extracting their features.
118     */
119    public DirectoryFileSource(OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor, DocumentPreprocessor... preprocessors) {
120        this.outputFactory = outputFactory;
121        this.extractor = extractor;
122        this.preprocessors.addAll(Arrays.asList(preprocessors));
123    }
124
125    public DirectoryFileSource(Path newsDir, OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor, DocumentPreprocessor... preprocessors) {
126        this.dataDir = newsDir;
127        this.outputFactory = outputFactory;
128        this.extractor = extractor;
129        this.preprocessors.addAll(Arrays.asList(preprocessors));
130    }
131
132    @Override
133    public String toString() {
134        return "DirectoryDataSource(directory="+dataDir.toString()+",extractor="+extractor.toString()+",preprocessors="+preprocessors.toString()+")";
135    }
136
137    @Override
138    public OutputFactory<T> getOutputFactory() {
139        return outputFactory;
140    }
141
142    @Override
143    public Iterator<Example<T>> iterator() {
144        return new DirectoryIterator();
145    }
146
147    private class DirectoryIterator implements Iterator<Example<T>> {
148
149        /**
150         * The top-level paths in the provided directory, which is to say the
151         * directories that give the labels their names.
152         */
153        private final Queue<Path> labelDirs = new ArrayDeque<>();
154
155        /**
156         * The path for the current output, resolved against the top-level
157         * directory.
158         */
159        private Path labelPath;
160
161        /**
162         * The current output to apply to docs.
163         */
164        private String label;
165
166        /**
167         * The paths for the files in a particular output directory.
168         */
169        private final Queue<Path> labelPaths = new ArrayDeque<>();
170
171        private final StringBuilder db = new StringBuilder();
172
173        public DirectoryIterator() {
174            //
175            // Get the top-level paths AKA the tags.
176            try (DirectoryStream<Path> stream = Files.newDirectoryStream(dataDir)) {
177                for (Path entry : stream) {
178                    labelDirs.offer(entry);
179                }
180            } catch (IOException ex) {
181                throw new IllegalStateException("Can't open directory " + dataDir, ex);
182            }
183            logger.info(String.format("Got %d output directories in %s", labelDirs.size(), dataDir));
184        }
185
186        @Override
187        public boolean hasNext() {
188            if (labelPaths.isEmpty()) {
189                return !labelDirs.isEmpty();
190            }
191            return true;
192        }
193
194        @Override
195        public Example<T> next() {
196            if (labelPaths.isEmpty()) {
197                if (labelDirs.isEmpty()) {
198                    throw new NoSuchElementException("No more files");
199                } else {
200                    labelPath = labelDirs.poll();
201                    label = labelPath.getFileName().toString();
202                    try (DirectoryStream<Path> stream = Files.newDirectoryStream(labelPath)) {
203                        for (Path entry : stream) {
204                            labelPaths.offer(entry);
205                        }
206                        logger.info(String.format("Got %d paths in %s", labelPaths.size(), labelPath));
207                    } catch (IOException ex) {
208                        throw new IllegalStateException("Can't open directory " + labelPath, ex);
209                    }
210                }
211            }
212            Path p = labelPaths.poll();
213            db.delete(0, db.length());
214            try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(p.toFile()), enc))) {
215                String line;
216                while ((line = br.readLine()) != null) {
217                    line = line.trim();
218                    if (line.isEmpty()) {
219                        db.append('\n');
220                    } else {
221                        db.append(line);
222                    }
223                    db.append('\n');
224                }
225                String postproc = db.toString();
226                for (DocumentPreprocessor preproc : preprocessors) {
227                    postproc = preproc.processDoc(postproc);
228                    if (postproc == null) {
229                        break;
230                    }
231                }
232                if (postproc != null) {
233                    Example<T> ret = extractor.extract(outputFactory.generateOutput(label), postproc);
234                    return ret;
235                } else {
236                    //
237                    // Uh, it got post processed away. See if there's another one
238                    // and return it.
239                    if (!hasNext()) {
240                        throw new NoSuchElementException("No more files");
241                    }
242                    return next();
243                }
244            } catch (IOException ex) {
245                throw new IllegalStateException("Error reading path " + p, ex);
246            }
247        }
248
249    }
250
251    @Override
252    public ConfiguredDataSourceProvenance getProvenance() {
253        return new DirectoryFileSourceProvenance(this);
254    }
255
256    /**
257     * Provenance for {@link DirectoryFileSource}.
258     */
259    public static class DirectoryFileSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
260        private static final long serialVersionUID = 1L;
261
262        private final DateTimeProvenance fileModifiedTime;
263        private final DateTimeProvenance dataSourceCreationTime;
264
265        <T extends Output<T>> DirectoryFileSourceProvenance(DirectoryFileSource<T> host) {
266            super(host,"DataSource");
267            this.fileModifiedTime = new DateTimeProvenance(FILE_MODIFIED_TIME,OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.dataDir.toFile().lastModified()), ZoneId.systemDefault()));
268            this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now());
269        }
270
271        public DirectoryFileSourceProvenance(Map<String,Provenance> map) {
272            this(extractProvenanceInfo(map));
273        }
274
275        private DirectoryFileSourceProvenance(ExtractedInfo info) {
276            super(info);
277            this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME);
278            this.fileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FILE_MODIFIED_TIME);
279        }
280
281        protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
282            Map<String,Provenance> configuredParameters = new HashMap<>(map);
283            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, DirectoryFileSourceProvenance.class.getSimpleName()).getValue();
284            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, DirectoryFileSourceProvenance.class.getSimpleName()).getValue();
285
286            Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
287            instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, DirectoryFileSourceProvenance.class.getSimpleName()));
288
289            return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters);
290        }
291
292        @Override
293        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
294            Map<String,PrimitiveProvenance<?>> map = new HashMap<>();
295
296            map.put(FILE_MODIFIED_TIME,fileModifiedTime);
297            map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime);
298
299            return map;
300        }
301
302        @Override
303        public boolean equals(Object o) {
304            if (this == o) return true;
305            if (!(o instanceof DirectoryFileSourceProvenance)) return false;
306            if (!super.equals(o)) return false;
307            DirectoryFileSourceProvenance pairs = (DirectoryFileSourceProvenance) o;
308            return fileModifiedTime.equals(pairs.fileModifiedTime) &&
309                    dataSourceCreationTime.equals(pairs.dataSourceCreationTime);
310        }
311
312        @Override
313        public int hashCode() {
314            return Objects.hash(super.hashCode(), fileModifiedTime, dataSourceCreationTime);
315        }
316    }
317}