001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.dataset;
018
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.MutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.VariableInfo;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DatasetProvenance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Logger;
043
044/**
045 * This class creates a pruned dataset in which low frequency features that
046 * occur less than the provided minimum cardinality have been removed. This can
047 * be useful when the dataset is very large due to many low-frequency features.
048 * For example, this class can be used to remove low frequency words from a BoW
049 * formatted dataset. Here, a new dataset is created so that the feature counts
050 * are recalculated and so that the original, passed-in dataset is not modified.
051 * The returned dataset may have fewer examples because if any of the examples
052 * have no features after the minimum cardinality has been applied, then those
053 * examples will not be added to the constructed dataset.
054 * 
055 * @param <T> The type of the outputs in this {@link Dataset}.
056 */
057public class MinimumCardinalityDataset<T extends Output<T>> extends ImmutableDataset<T> {
058    private static final long serialVersionUID = 1L;
059
060    private static final Logger logger = Logger.getLogger(MinimumCardinalityDataset.class.getName());
061
062    private final int minCardinality;
063    
064    private int numExamplesRemoved = 0;
065
066    private final Set<String> removed = new HashSet<>();
067
068    /**
069     * @param dataset this dataset is left untouched and is used to populate the constructed dataset.
070     * @param minCardinality features with a frequency less than minCardinality will be removed.  
071     */
072    public MinimumCardinalityDataset(Dataset<T> dataset, int minCardinality) {
073        super(dataset.getProvenance(), dataset.getOutputFactory());
074        this.minCardinality = minCardinality;
075
076        MutableFeatureMap featureInfos = new MutableFeatureMap();
077
078        List<Feature> features = new ArrayList<>();
079        //
080        // Rebuild the data list only with features that have a minimum cardinality.
081        FeatureMap wfm = dataset.getFeatureMap();
082        for (Example<T> ex : dataset) {
083            features.clear();
084            ArrayExample<T> nex = new ArrayExample<>(ex.getOutput());
085            nex.setWeight(ex.getWeight());
086            for (Feature f : ex) {
087                VariableInfo finfo = wfm.get(f.getName());
088                if (finfo == null || finfo.getCount() < minCardinality) {
089                    //
090                    // The feature info might be null if we have a feature at 
091                    // prediction time that we didn't see
092                    // at training time.
093                    removed.add(f.getName());
094                } else {
095                    features.add(f);
096                }
097            }
098            nex.addAll(features);
099            if (nex.size() > 0) {
100                if (!nex.validateExample()) {
101                    throw new IllegalStateException("Duplicate features found in example " + nex.toString());
102                }
103                data.add(nex);
104            } else {
105                numExamplesRemoved++;
106            }
107        }
108
109        // Copy out the feature infos above the threshold.
110        for (VariableInfo info : wfm) {
111            if (info.getCount() >= minCardinality) {
112                featureInfos.put(info.copy());
113            }
114        }
115
116        outputIDInfo = dataset.getOutputIDInfo();
117        featureIDMap = new ImmutableFeatureMap(featureInfos);
118
119        if(numExamplesRemoved > 0) {
120            logger.info(String.format("filtered out %d examples because it had zero features after the minimum frequency count was applied.", numExamplesRemoved));
121        }
122
123    }
124
125    /**
126     * The feature names that were removed.
127     * @return The feature names.
128     */
129    public Set<String> getRemoved() {
130        return removed;
131    }
132
133    /**
134     * The number of examples removed due to a lack of features.
135     * @return The number of removed examples.
136     */
137    public int getNumExamplesRemoved() {
138        return numExamplesRemoved;
139    }
140
141    /**
142     * The minimum cardinality threshold for the features.
143     * @return The cardinality threshold.
144     */
145    public int getMinCardinality() {
146        return minCardinality;
147    }
148
149    @Override
150    public DatasetProvenance getProvenance() {
151        return new MinimumCardinalityDatasetProvenance(this);
152    }
153
154    /**
155     * Provenance for {@link MinimumCardinalityDataset}.
156     */
157    public static class MinimumCardinalityDatasetProvenance extends DatasetProvenance {
158        private static final long serialVersionUID = 1L;
159
160        private static final String MIN_CARDINALITY = "min-cardinality";
161
162        private final IntProvenance minCardinality;
163
164        <T extends Output<T>> MinimumCardinalityDatasetProvenance(MinimumCardinalityDataset<T> dataset) {
165            super(dataset.sourceProvenance, new ListProvenance<>(), dataset);
166            this.minCardinality = new IntProvenance(MIN_CARDINALITY,dataset.minCardinality);
167        }
168
169        public MinimumCardinalityDatasetProvenance(Map<String,Provenance> map) {
170            super(map);
171            this.minCardinality = ObjectProvenance.checkAndExtractProvenance(map,MIN_CARDINALITY,IntProvenance.class,MinimumCardinalityDatasetProvenance.class.getSimpleName());
172        }
173
174        @Override
175        public boolean equals(Object o) {
176            if (this == o) return true;
177            if (!(o instanceof MinimumCardinalityDatasetProvenance)) return false;
178            if (!super.equals(o)) return false;
179            MinimumCardinalityDatasetProvenance pairs = (MinimumCardinalityDatasetProvenance) o;
180            return minCardinality.equals(pairs.minCardinality);
181        }
182
183        @Override
184        public int hashCode() {
185            return Objects.hash(super.hashCode(), minCardinality);
186        }
187
188        @Override
189        protected List<Pair<String, Provenance>> allProvenances() {
190            List<Pair<String,Provenance>> provenances = super.allProvenances();
191            provenances.add(new Pair<>(MIN_CARDINALITY,minCardinality));
192            return provenances;
193        }
194    }
195}