001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.data.text.impl;
018
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.data.text.TextDataSource;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.logging.Logger;
import java.util.regex.Pattern;
048
049/**
050 * A dataset for a simple data format for text classification experiments. A line
051 * in the file looks like:
052 *
053 * <pre>
054 * OUTPUT##Document text
055 * </pre>
056 *
057 * Each line in the file specifies a single output and document pair. Leading and
058 * trailing spaces will be trimmed from outputs and documents. Outputs will be
059 * converted to upper case.
060 * 
061 * <p> 
062 * 
063 * As with all of our text data, the file should be in UTF-8.
064 */
065public class SimpleTextDataSource<T extends Output<T>> extends TextDataSource<T> {
066
067    private static final Logger logger = Logger.getLogger(SimpleTextDataSource.class.getName());
068
069    private static final Pattern splitPattern = Pattern.compile("##");
070
071    protected ConfiguredDataSourceProvenance provenance;
072
073    /**
074     * for olcut
075     */
076    protected SimpleTextDataSource() {}
077
078    public SimpleTextDataSource(Path path, OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor) throws IOException {
079        super(path, outputFactory, extractor);
080        postConfig();
081    }
082
083    public SimpleTextDataSource(File file, OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor) throws IOException {
084        super(file, outputFactory, extractor);
085        postConfig();
086    }
087
088    protected SimpleTextDataSource(OutputFactory<T> outputFactory, TextFeatureExtractor<T> extractor) {
089        super((Path)null,outputFactory,extractor);
090    }
091
092    /**
093     * Used by the OLCUT configuration system, and should not be called by external code.
094     */
095    @Override
096    public void postConfig() throws IOException {
097        read();
098        provenance = cacheProvenance();
099    }
100
101    protected Optional<Example<T>> parseLine(String line, int n) {
102        line = line.trim();
103        if(line.isEmpty()) {
104            return Optional.empty();
105        }
106        String[] fields = splitPattern.split(line);
107        if(fields.length != 2) {
108            logger.warning(String.format("Bad line in %s at %d: %s",
109                    path, n, line.substring(Math.min(50, line.length()))));
110            return Optional.empty();
111        }
112        T label = outputFactory.generateOutput(fields[0].trim().toUpperCase());
113        return Optional.of(extractor.extract(label, handleDoc(fields[1].trim())));
114    }
115
116    @Override
117    protected void read() throws IOException {
118        int n = 0;
119        for (String line : Files.readAllLines(path, StandardCharsets.UTF_8)) {
120            n++;
121            Optional<Example<T>> example = parseLine(line, n);
122            if (example.isPresent()) {
123                Example<T> e = example.get();
124                if (e.validateExample()) {
125                    data.add(e);
126                } else {
127                    logger.warning("Invalid example found after parsing line " + n);
128                }
129            }
130        }
131    }
132
133    @Override
134    public ConfiguredDataSourceProvenance getProvenance() {
135        return provenance;
136    }
137
138    protected ConfiguredDataSourceProvenance cacheProvenance() {
139        return new SimpleTextDataSourceProvenance(this);
140    }
141
142    /**
143     * Provenance for {@link SimpleTextDataSource}.
144     */
145    public static class SimpleTextDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
146        private static final long serialVersionUID = 1L;
147
148        private final DateTimeProvenance fileModifiedTime;
149        private final DateTimeProvenance dataSourceCreationTime;
150        private final HashProvenance sha256Hash;
151
152        <T extends Output<T>> SimpleTextDataSourceProvenance(SimpleTextDataSource<T> host) {
153            super(host,"DataSource");
154            this.fileModifiedTime = new DateTimeProvenance(FILE_MODIFIED_TIME,OffsetDateTime.ofInstant(Instant.ofEpochMilli(host.path.toFile().lastModified()), ZoneId.systemDefault()));
155            this.dataSourceCreationTime = new DateTimeProvenance(DATASOURCE_CREATION_TIME,OffsetDateTime.now());
156            this.sha256Hash = new HashProvenance(DEFAULT_HASH_TYPE,RESOURCE_HASH,ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE,host.path));
157        }
158
159        public SimpleTextDataSourceProvenance(Map<String,Provenance> map) {
160            this(extractProvenanceInfo(map));
161        }
162
163        private SimpleTextDataSourceProvenance(ExtractedInfo info) {
164            super(info);
165            this.fileModifiedTime = (DateTimeProvenance) info.instanceValues.get(FILE_MODIFIED_TIME);
166            this.dataSourceCreationTime = (DateTimeProvenance) info.instanceValues.get(DATASOURCE_CREATION_TIME);
167            this.sha256Hash = (HashProvenance) info.instanceValues.get(RESOURCE_HASH);
168        }
169
170        protected static ExtractedInfo extractProvenanceInfo(Map<String,Provenance> map) {
171            Map<String,Provenance> configuredParameters = new HashMap<>(map);
172            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName()).getValue();
173            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters,HOST_SHORT_NAME, StringProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName()).getValue();
174
175            Map<String,PrimitiveProvenance<?>> instanceParameters = new HashMap<>();
176            instanceParameters.put(FILE_MODIFIED_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,FILE_MODIFIED_TIME,DateTimeProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName()));
177            instanceParameters.put(DATASOURCE_CREATION_TIME,ObjectProvenance.checkAndExtractProvenance(configuredParameters,DATASOURCE_CREATION_TIME,DateTimeProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName()));
178            instanceParameters.put(RESOURCE_HASH,ObjectProvenance.checkAndExtractProvenance(configuredParameters,RESOURCE_HASH,HashProvenance.class, SimpleTextDataSourceProvenance.class.getSimpleName()));
179
180            return new ExtractedInfo(className,hostTypeStringName,configuredParameters,instanceParameters);
181        }
182
183        @Override
184        public boolean equals(Object o) {
185            if (this == o) return true;
186            if (!(o instanceof SimpleTextDataSourceProvenance)) return false;
187            if (!super.equals(o)) return false;
188            SimpleTextDataSourceProvenance pairs = (SimpleTextDataSourceProvenance) o;
189            return fileModifiedTime.equals(pairs.fileModifiedTime) &&
190                    dataSourceCreationTime.equals(pairs.dataSourceCreationTime) &&
191                    sha256Hash.equals(pairs.sha256Hash);
192        }
193
194        @Override
195        public int hashCode() {
196            return Objects.hash(super.hashCode(), fileModifiedTime, dataSourceCreationTime, sha256Hash);
197        }
198
199        @Override
200        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
201            Map<String,PrimitiveProvenance<?>> map = new HashMap<>();
202
203            map.put(FILE_MODIFIED_TIME,fileModifiedTime);
204            map.put(DATASOURCE_CREATION_TIME,dataSourceCreationTime);
205            map.put(RESOURCE_HASH,sha256Hash);
206
207            return map;
208        }
209    }
210
211}