001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.sequence;
018
019import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
020import org.tribuo.Example;
021import org.tribuo.Feature;
022import org.tribuo.FeatureMap;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.Output;
026import org.tribuo.OutputFactory;
027import org.tribuo.OutputInfo;
028import org.tribuo.VariableInfo;
029import org.tribuo.provenance.DataProvenance;
030import org.tribuo.provenance.DatasetProvenance;
031import org.tribuo.util.Merger;
032
033import java.io.Serializable;
034import java.util.ArrayList;
035import java.util.List;
036import java.util.Set;
037
038/**
039 * This is a {@link SequenceDataset} which has an {@link ImmutableFeatureMap} to store the feature information.
040 * Whenever an example is added to this dataset it removes features that do not exist in the FeatureMap.
041 * The dataset is immutable after construction (unless the examples are modified).
042 */
043public class ImmutableSequenceDataset<T extends Output<T>> extends SequenceDataset<T> implements Serializable {
044    private static final long serialVersionUID = 1L;
045
046    /**
047     * A map from labels to IDs for the labels found in this dataset.
048     */
049    protected ImmutableOutputInfo<T> outputIDInfo;
050
051    /**
052     * A map from feature names to IDs for the features found in this dataset.
053     */
054    protected ImmutableFeatureMap featureIDMap;
055
056    private DatasetProvenance provenance;
057
058    /**
059     * If you call this it's your job to setup outputIDInfo and featureIDMap.
060     * @param sourceProvenance A description of the dataset including preprocessing steps.
061     * @param outputFactory The output factory.
062     */
063    protected ImmutableSequenceDataset(DataProvenance sourceProvenance, OutputFactory<T> outputFactory) {
064        super(sourceProvenance,outputFactory);
065    }
066
067    public ImmutableSequenceDataset(SequenceDataSource<T> dataSource, SequenceModel<T> model) {
068        this(dataSource,dataSource.getProvenance(),model.getFeatureIDMap(),model.getOutputIDInfo(),dataSource.getOutputFactory());
069    }
070
071    public ImmutableSequenceDataset(SequenceDataSource<T> dataSource, FeatureMap featureIDMap, OutputInfo<T> outputIDInfo) {
072        this(dataSource,dataSource.getProvenance(),featureIDMap,outputIDInfo,dataSource.getOutputFactory());
073    }
074
075    /**
076     * Creates a dataset from a data source. This method will create the output
077     * and feature ID maps that are needed for training and evaluating classifiers.
078     * @param dataSource The input data.
079     * @param sourceProvenance A description of the data.
080     * @param featureIDMap The feature map, used to remove unknown features.
081     * @param outputIDInfo The output map.
082     * @param outputFactory The output factory.
083     */
084    public ImmutableSequenceDataset(Iterable<SequenceExample<T>> dataSource, DataProvenance sourceProvenance, FeatureMap featureIDMap, OutputInfo<T> outputIDInfo, OutputFactory<T> outputFactory) {
085        this(dataSource,sourceProvenance, new ImmutableFeatureMap(featureIDMap), outputIDInfo.generateImmutableOutputInfo(),outputFactory);
086    }
087
088    /**
089     * Creates a dataset from a data source.
090     * @param dataSource The input data.
091     * @param sourceProvenance A description of the data.
092     * @param featureIDMap The feature id map, used to remove unknown features.
093     * @param outputIDInfo The output id map.
094     * @param outputFactory The output factory.
095     */
096    public ImmutableSequenceDataset(Iterable<SequenceExample<T>> dataSource, DataProvenance sourceProvenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, OutputFactory<T> outputFactory) {
097        super(sourceProvenance,outputFactory);
098        this.featureIDMap = featureIDMap;
099        this.outputIDInfo = outputIDInfo;
100
101        for (SequenceExample<T> ex : dataSource) {
102            add(ex);
103        }
104    }
105
106    /**
107     * This is dangerous, and should not be used unless you've overridden everything in ImmutableSequenceDataset.
108     * @param sourceProvenance A description of the data, including all preprocessing.
109     * @param featureIDMap The feature id map, used to remove unknown features.
110     * @param outputIDInfo The output id map.
111     */
112    protected ImmutableSequenceDataset(DataProvenance sourceProvenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo) {
113        super(sourceProvenance,null);
114        this.featureIDMap = featureIDMap;
115        this.outputIDInfo = outputIDInfo;
116    }
117
118    /**
119     * Adds a {@link SequenceExample} to the dataset, which will insert feature ids, remove unknown features
120     * and sort the examples by the feature ids.
121     * @param ex The example to add.
122     */
123    protected void add(SequenceExample<T> ex) {
124        if (ex.size() > 0) {
125            List<Feature> featuresToRemove = new ArrayList<>();
126            for (Example<T> e : ex) {
127                featuresToRemove.clear();
128                for (Feature f : e) {
129                    VariableInfo info = featureIDMap.get(f.getName());
130                    if (info == null) {
131                        featuresToRemove.add(f);
132                    }
133                }
134                e.removeFeatures(featuresToRemove);
135                if (!e.validateExample()) {
136                    throw new IllegalStateException("Duplicate features or invalid features inside the Example, or all features were removed.");
137                }
138            }
139            data.add(ex);
140            ex.canonicalise(featureIDMap);
141        } else {
142            throw new IllegalArgumentException("SequenceExample is empty.");
143        }
144    }
145
146    /**
147     * Adds an {@link SequenceExample} to the dataset. Use only
148     * when the example has already been validated.
149     * @param ex An {@link SequenceExample} to add to the dataset.
150     */
151    private void unsafeAdd(SequenceExample<T> ex) {
152        data.add(ex);
153    }
154
155    /**
156     * Adds a {@link SequenceExample} to the dataset, which will insert feature ids, remove unknown features
157     * and sort the examples by the feature ids.
158     * @param ex The example to add.
159     * @param merger The merger to use to remove duplicate features.
160     */
161    protected void add(SequenceExample<T> ex, Merger merger) {
162        if (ex.size() > 0) {
163            data.add(ex);
164            List<Feature> featuresToRemove = new ArrayList<>();
165            for (Example<T> e : ex) {
166                featuresToRemove.clear();
167                for (Feature f : e) {
168                    VariableInfo info = featureIDMap.get(f.getName());
169                    if (info == null) {
170                        featuresToRemove.add(f);
171                    }
172                }
173                e.removeFeatures(featuresToRemove);
174                e.reduceByName(merger);
175                if (!e.validateExample()) {
176                    throw new IllegalStateException("Duplicate features or invalid features inside the Example, or all features were removed.");
177                }
178            }
179        } else {
180            throw new IllegalArgumentException("SequenceExample is empty.");
181        }
182    }
183
184    @Override
185    public Set<T> getOutputs() {
186        return outputIDInfo.getDomain();
187    }
188
189    @Override
190    public ImmutableFeatureMap getFeatureIDMap() {
191        return featureIDMap;
192    }
193
194    @Override
195    public ImmutableFeatureMap getFeatureMap() {
196        return featureIDMap;
197    }
198
199    @Override
200    public ImmutableOutputInfo<T> getOutputIDInfo() {
201        return outputIDInfo;
202    }
203
204    @Override
205    public ImmutableOutputInfo<T> getOutputInfo() {
206        return outputIDInfo;
207    }
208
209    @Override
210    public String toString(){
211        return "ImmutableSequenceDataset(source="+ sourceProvenance.toString()+")";
212    }
213
214    @Override
215    public synchronized DatasetProvenance getProvenance() {
216        if (provenance == null) {
217            provenance = cacheProvenance();
218        }
219        return provenance;
220    }
221
222    private DatasetProvenance cacheProvenance() {
223        return new DatasetProvenance(sourceProvenance,new ListProvenance<>(),this);
224    }
225
226    /**
227     * Creates an immutable deep copy of the supplied dataset.
228     * @param dataset The dataset to copy.
229     * @param <T> The type of output.
230     * @return An immutable copy of the dataset.
231     */
232    public static <T extends Output<T>> ImmutableSequenceDataset<T> copyDataset(SequenceDataset<T> dataset) {
233        ArrayList<SequenceExample<T>> newData = new ArrayList<>();
234        for (SequenceExample<T> e : dataset) {
235            newData.add(e.copy());
236        }
237        return new ImmutableSequenceDataset<>(newData,dataset.getSourceProvenance(),dataset.getFeatureIDMap(),dataset.getOutputInfo(),dataset.getOutputFactory());
238    }
239
240    /**
241     * Creates an immutable deep copy of the supplied dataset, using a different feature and output map.
242     * @param dataset The dataset to copy.
243     * @param featureIDMap The new feature map to use. Removes features which are not found in this map.
244     * @param outputIDInfo The new output info to use.
245     * @param <T> The type of output.
246     * @return An immutable copy of the dataset.
247     */
248    public static <T extends Output<T>> ImmutableSequenceDataset<T> copyDataset(SequenceDataset<T> dataset, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo) {
249        ArrayList<SequenceExample<T>> newData = new ArrayList<>();
250        for (SequenceExample<T> e : dataset) {
251            newData.add(e.copy());
252        }
253        return new ImmutableSequenceDataset<>(newData,dataset.getSourceProvenance(),featureIDMap,outputIDInfo,dataset.getOutputFactory());
254    }
255
256    /**
257     * Creates an immutable deep copy of the supplied dataset.
258     * @param dataset The dataset to copy.
259     * @param featureIDMap The new feature map to use. Removes features which are not found in this map.
260     * @param outputIDInfo The new output info to use.
261     * @param merger The merge function to use to reduce features given new ids.
262     * @param <T> The type of output.
263     * @return An immutable copy of the dataset.
264     */
265    public static <T extends Output<T>> ImmutableSequenceDataset<T> copyDataset(SequenceDataset<T> dataset, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, Merger merger) {
266        ImmutableSequenceDataset<T> copy = new ImmutableSequenceDataset<>(dataset.getProvenance(),featureIDMap,outputIDInfo);
267        for (SequenceExample<T> e : dataset) {
268            copy.add(e.copy(),merger);
269        }
270        return copy;
271    }
272
273    static <T extends Output<T>> ImmutableSequenceDataset<T> changeFeatureMap(SequenceDataset<T> dataset, ImmutableFeatureMap featureIDMap) {
274        ImmutableSequenceDataset<T> copy = new ImmutableSequenceDataset<>(dataset.getProvenance(),featureIDMap,dataset.getOutputIDInfo());
275        for (SequenceExample<T> e : dataset) {
276            copy.unsafeAdd(e);
277        }
278        return copy;
279    }
280}