001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.sequence;
018
019import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
020import org.tribuo.Dataset;
021import org.tribuo.Example;
022import org.tribuo.FeatureMap;
023import org.tribuo.ImmutableDataset;
024import org.tribuo.ImmutableFeatureMap;
025import org.tribuo.ImmutableOutputInfo;
026import org.tribuo.Output;
027import org.tribuo.OutputFactory;
028import org.tribuo.OutputInfo;
029import org.tribuo.provenance.DataProvenance;
030import org.tribuo.provenance.DatasetProvenance;
031
032import java.io.Serializable;
033import java.util.ArrayList;
034import java.util.Collections;
035import java.util.Iterator;
036import java.util.List;
037import java.util.Set;
038import java.util.logging.Logger;
039
040/**
041 * A class for sets of data, which are used to train and evaluate classifiers.
042 * <p>
043 * Subclass either {@link MutableSequenceDataset} or {@link ImmutableSequenceDataset} rather than this class.
044 *
045 * @param <T> the type of the outputs in the data set.
046 */
047public abstract class SequenceDataset<T extends Output<T>> implements Iterable<SequenceExample<T>>, Provenancable<DatasetProvenance>, Serializable {
048    private static final Logger logger = Logger.getLogger(SequenceDataset.class.getName());
049    private static final long serialVersionUID = 2L;
050
051    /**
052     * A factory for making {@link OutputInfo} and {@link Output} of the appropriate type.
053     */
054    protected final OutputFactory<T> outputFactory;
055
056    /**
057     * The data in this data set.
058     */
059    protected final List<SequenceExample<T>> data = new ArrayList<>();
060
061    /**
062     * The provenance of the data source, extracted on construction.
063     */
064    protected final DataProvenance sourceProvenance;
065
066    protected SequenceDataset(DataProvenance sourceProvenance, OutputFactory<T> outputFactory) {
067        this.sourceProvenance = sourceProvenance;
068        this.outputFactory = outputFactory;
069    }
070
071    /**
072     * Returns the description of the source provenance.
073     * @return The source provenance in text form.
074     */
075    public String getSourceDescription() {
076        return "SequenceDataset(source=" + sourceProvenance.toString() + ")";
077    }
078
079    /**
080     * Returns an unmodifiable view on the data.
081     * @return The data.
082     */
083    public List<SequenceExample<T>> getData() {
084        return Collections.unmodifiableList(data);
085    }
086
087    /**
088     * Returns the source provenance.
089     * @return The source provenance.
090     */
091    public DataProvenance getSourceProvenance() {
092        return sourceProvenance;
093    }
094
095    /**
096     * Gets the set of labels that occur in the examples in this dataset.
097     *
098     * @return the set of labels that occur in the examples in this dataset.
099     */
100    public abstract Set<T> getOutputs();
101
102    /**
103     * Gets the example at the specified index, or throws IllegalArgumentException if
104     * the index is out of bounds.
105     * @param index The index.
106     * @return The example at that index.
107     */
108    public SequenceExample<T> getExample(int index) {
109        if ((index < 0) || (index >= size())) {
110            throw new IllegalArgumentException("Example index " + index + " is out of bounds.");
111        }
112        return data.get(index);
113    }
114
115    /**
116     * Returns a view on this SequenceDataset which aggregates all
117     * the examples and ignores the sequence structure.
118     *
119     * @return A flattened view on this dataset.
120     */
121    public Dataset<T> getFlatDataset() {
122        return new FlatDataset<>(this);
123    }
124
125    /**
126     * Gets the size of the data set.
127     *
128     * @return the size of the data set.
129     */
130    public int size() {
131        return data.size();
132    }
133
134    /**
135     * An immutable view on the output info in this dataset.
136     * @return The output info.
137     */
138    public abstract ImmutableOutputInfo<T> getOutputIDInfo();
139
140    /**
141     * The output info in this dataset.
142     * @return The output info.
143     */
144    public abstract OutputInfo<T> getOutputInfo();
145
146    /**
147     * An immutable view on the feature map.
148     * @return The feature map.
149     */
150    public abstract ImmutableFeatureMap getFeatureIDMap();
151
152    /**
153     * The feature map.
154     * @return The feature map.
155     */
156    public abstract FeatureMap getFeatureMap();
157
158    /**
159     * Gets the output factory.
160     * @return The output factory.
161     */
162    public OutputFactory<T> getOutputFactory() {
163        return outputFactory;
164    }
165
166    @Override
167    public Iterator<SequenceExample<T>> iterator() {
168        return data.iterator();
169    }
170
171    @Override
172    public String toString() {
173        return "SequenceDataset(source=" + sourceProvenance.toString() + ")";
174    }
175
176    private static class FlatDataset<T extends Output<T>> extends ImmutableDataset<T> {
177        private static final long serialVersionUID = 1L;
178
179        public FlatDataset(SequenceDataset<T> sequenceDataset) {
180            super(sequenceDataset.sourceProvenance, sequenceDataset.outputFactory, sequenceDataset.getFeatureIDMap(), sequenceDataset.getOutputIDInfo());
181            for (SequenceExample<T> seq : sequenceDataset) {
182                for (Example<T> e : seq) {
183                    data.add(e);
184                }
185            }
186        }
187    }
188
189}
190