001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.sequence;
018
019import java.util.ArrayList;
020import java.util.HashSet;
021import java.util.List;
022import java.util.Map;
023import java.util.Objects;
024import java.util.Set;
025import java.util.logging.Logger;
026
027import org.tribuo.Example;
028import org.tribuo.Feature;
029import org.tribuo.FeatureMap;
030import org.tribuo.ImmutableFeatureMap;
031import org.tribuo.MutableFeatureMap;
032import org.tribuo.Output;
033import org.tribuo.VariableInfo;
034import org.tribuo.impl.ArrayExample;
035import org.tribuo.impl.BinaryFeaturesExample;
036import org.tribuo.provenance.DatasetProvenance;
037
038import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
039import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
040import com.oracle.labs.mlrg.olcut.provenance.Provenance;
041import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
042import com.oracle.labs.mlrg.olcut.util.Pair;
043
044/**
045 * This class creates a pruned dataset in which low frequency features that
046 * occur less than the provided minimum cardinality have been removed. This can
047 * be useful when the dataset is very large due to many low-frequency features.
048 * Here, a new dataset is created so that the feature counts are recalculated
049 * and so that the original, passed-in dataset is not modified. The returned
050 * dataset may have fewer sequence examples because if any of the sequence
051 * examples have examples with no features after the minimum cardinality has
052 * been applied, then those sequence examples will not be added to the
053 * constructed dataset.
054 * 
055 * @param <T> The type of the outputs in this {@link SequenceDataset}.
056 */
057public class MinimumCardinalitySequenceDataset<T extends Output<T>> extends ImmutableSequenceDataset<T> {
058    private static final long serialVersionUID = 1L;
059
060    private static final Logger logger = Logger.getLogger(MinimumCardinalitySequenceDataset.class.getName());
061
062    private final int minCardinality;
063
064    private int numExamplesRemoved = 0;
065
066    private final Set<String> removedFeatureNames = new HashSet<>();
067
068    /**
069     * @param sequenceDataset this dataset is left untouched and is used to populate
070     *                        the constructed dataset.
071     * @param minCardinality  features with a frequency less than minCardinality
072     *                        will be removed.
073     */
074    public MinimumCardinalitySequenceDataset(SequenceDataset<T> sequenceDataset, int minCardinality) {
075        super(sequenceDataset.getProvenance(), sequenceDataset.getOutputFactory());
076        this.minCardinality = minCardinality;
077
078        MutableFeatureMap featureInfos = new MutableFeatureMap();
079
080        List<Feature> features = new ArrayList<>();
081        //
082        // Rebuild the data list only with features that have a minimum cardinality.
083        FeatureMap featureMap = sequenceDataset.getFeatureMap();
084        for (SequenceExample<T> sequenceExample : sequenceDataset) {
085            boolean add = true;
086            List<Example<T>> newExamples = new ArrayList<>(sequenceExample.size());
087            for (Example<T> example : sequenceExample) {
088                features.clear();
089                Example<T> newExample;
090                if(example instanceof BinaryFeaturesExample) {
091                    newExample = new BinaryFeaturesExample<>(example.getOutput());
092                } else {
093                    newExample = new ArrayExample<>(example.getOutput());
094                }
095                newExample.setWeight(example.getWeight());
096                for (Feature feature : example) {
097                    VariableInfo featureInfo = featureMap.get(feature.getName());
098                    if (featureInfo == null || featureInfo.getCount() < minCardinality) {
099                        //
100                        // The feature info might be null if we have a feature at
101                        // prediction time that we didn't see
102                        // at training time.
103                        removedFeatureNames.add(feature.getName());
104                    } else {
105                        features.add(feature);
106                    }
107                }
108                newExample.addAll(features);
109                if (newExample.size() > 0) {
110                    if (!newExample.validateExample()) {
111                        throw new IllegalStateException("Duplicate features found in example " + newExample.toString());
112                    }
113                    newExamples.add(newExample);
114                } else {
115                    numExamplesRemoved++;
116                    add = false;
117                    break;
118                }
119            }
120            if (add) {
121                SequenceExample<T> newSequenceExample = new SequenceExample<>(newExamples);
122                data.add(newSequenceExample);
123            }
124        }
125
126        // Copy out the feature infos above the threshold.
127        for (VariableInfo info : featureMap) {
128            if (info.getCount() >= minCardinality) {
129                featureInfos.put(info.copy());
130            }
131        }
132
133        this.outputIDInfo = sequenceDataset.getOutputIDInfo();
134        this.featureIDMap = new ImmutableFeatureMap(featureInfos);
135
136        if (numExamplesRemoved > 0) {
137            logger.info(String.format(
138                    "filtered out %d sequence examples because (at least) one of its examples had zero features after the minimum frequency count was applied.",
139                    numExamplesRemoved));
140        }
141    }
142
143    /**
144     * The feature names that were removed.
145     * 
146     * @return The feature names.
147     */
148    public Set<String> getRemoved() {
149        return removedFeatureNames;
150    }
151
152    /**
153     * The number of examples removed due to a lack of features.
154     * 
155     * @return The number of removed examples.
156     */
157    public int getNumExamplesRemoved() {
158        return numExamplesRemoved;
159    }
160
161    /**
162     * The minimum cardinality threshold for the features.
163     * 
164     * @return The cardinality threshold.
165     */
166    public int getMinCardinality() {
167        return minCardinality;
168    }
169
170    @Override
171    public DatasetProvenance getProvenance() {
172        return new MinimumCardinalitySequenceDatasetProvenance(this);
173    }
174
175    /**
176     * Provenance for {@link MinimumCardinalitySequenceDataset}.
177     */
178    public static class MinimumCardinalitySequenceDatasetProvenance extends DatasetProvenance {
179        private static final long serialVersionUID = 1L;
180
181        private static final String MIN_CARDINALITY = "min-cardinality";
182
183        private final IntProvenance minCardinality;
184
185        <T extends Output<T>> MinimumCardinalitySequenceDatasetProvenance(
186                MinimumCardinalitySequenceDataset<T> dataset) {
187            super(dataset.sourceProvenance, new ListProvenance<>(), dataset);
188            this.minCardinality = new IntProvenance(MIN_CARDINALITY, dataset.minCardinality);
189        }
190
191        public MinimumCardinalitySequenceDatasetProvenance(Map<String, Provenance> map) {
192            super(map);
193            this.minCardinality = ObjectProvenance.checkAndExtractProvenance(map, MIN_CARDINALITY, IntProvenance.class,
194                    MinimumCardinalitySequenceDatasetProvenance.class.getSimpleName());
195        }
196
197        @Override
198        public boolean equals(Object o) {
199            if (this == o)
200                return true;
201            if (!(o instanceof MinimumCardinalitySequenceDatasetProvenance))
202                return false;
203            if (!super.equals(o))
204                return false;
205            MinimumCardinalitySequenceDatasetProvenance pairs = (MinimumCardinalitySequenceDatasetProvenance) o;
206            return minCardinality.equals(pairs.minCardinality);
207        }
208
209        @Override
210        public int hashCode() {
211            return Objects.hash(super.hashCode(), minCardinality);
212        }
213
214        @Override
215        protected List<Pair<String, Provenance>> allProvenances() {
216            List<Pair<String, Provenance>> provenances = super.allProvenances();
217            provenances.add(new Pair<>(MIN_CARDINALITY, minCardinality));
218            return provenances;
219        }
220    }
221}