/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.dataset;

import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.MutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.VariableInfo;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DatasetProvenance;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Logger;

/**
 * This class creates a pruned dataset in which low-frequency features that
 * occur fewer times than the provided minimum cardinality have been removed.
 * This can be useful when the dataset is very large due to many low-frequency
 * features; for example, it can be used to remove rare words from a BoW
 * formatted dataset. A new dataset is created so that the feature counts are
 * recalculated and the original, passed-in dataset is not modified. The
 * returned dataset may have fewer examples: any example left with no features
 * once the minimum cardinality has been applied is not added to the
 * constructed dataset.
 *
 * @param <T> The type of the outputs in this {@link Dataset}.
 */
public class MinimumCardinalityDataset<T extends Output<T>> extends ImmutableDataset<T> {
    private static final long serialVersionUID = 1L;

    private static final Logger logger = Logger.getLogger(MinimumCardinalityDataset.class.getName());

    private final int minCardinality;

    private int numExamplesRemoved = 0;

    private final Set<String> removed = new HashSet<>();

    /**
     * Constructs a pruned dataset from the supplied dataset, removing all features which occur fewer than minCardinality times.
     * @param dataset This dataset is left untouched and is used to populate the constructed dataset.
     * @param minCardinality Features with a frequency less than minCardinality will be removed.
     */
    public MinimumCardinalityDataset(Dataset<T> dataset, int minCardinality) {
        super(dataset.getProvenance(), dataset.getOutputFactory());
        this.minCardinality = minCardinality;

        MutableFeatureMap featureInfos = new MutableFeatureMap();

        List<Feature> features = new ArrayList<>();

        // Rebuild the data list, keeping only the features that meet the minimum cardinality.
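        // For each example, copy its output and weight into a new ArrayExample, keep only the
        // features whose count in the original feature map meets the threshold, and drop (and
        // count) any example that is left with no features.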
        FeatureMap wfm = dataset.getFeatureMap();
        for (Example<T> ex : dataset) {
            features.clear();
            ArrayExample<T> nex = new ArrayExample<>(ex.getOutput());
            nex.setWeight(ex.getWeight());
            for (Feature f : ex) {
                VariableInfo finfo = wfm.get(f.getName());
                if (finfo == null || finfo.getCount() < minCardinality) {
                    // The feature info might be null if we have a feature at
                    // prediction time that we didn't see at training time.
                    removed.add(f.getName());
                } else {
                    features.add(f);
                }
            }
            nex.addAll(features);
            if (nex.size() > 0) {
                if (!nex.validateExample()) {
                    throw new IllegalStateException("Duplicate features found in example " + nex.toString());
                }
                data.add(nex);
            } else {
                numExamplesRemoved++;
            }
        }

        // Copy out the feature infos above the threshold.
        for (VariableInfo info : wfm) {
            if (info.getCount() >= minCardinality) {
                featureInfos.put(info.copy());
            }
        }

        outputIDInfo = dataset.getOutputIDInfo();
        featureIDMap = new ImmutableFeatureMap(featureInfos);

        if (numExamplesRemoved > 0) {
            logger.info(String.format("filtered out %d examples because they had zero features after the minimum frequency count was applied.", numExamplesRemoved));
        }
    }

    /**
     * The feature names that were removed.
     * @return The feature names.
     */
    public Set<String> getRemoved() {
        return removed;
    }

    /**
     * The number of examples removed due to a lack of features.
     * @return The number of removed examples.
     */
    public int getNumExamplesRemoved() {
        return numExamplesRemoved;
    }

    /**
     * The minimum cardinality threshold for the features.
     * @return The cardinality threshold.
     */
    public int getMinCardinality() {
        return minCardinality;
    }

    @Override
    public DatasetProvenance getProvenance() {
        return new MinimumCardinalityDatasetProvenance(this);
    }

    /**
     * Provenance for {@link MinimumCardinalityDataset}.
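     * Records the minimum cardinality threshold used to prune the dataset,
     * alongside the standard dataset provenance.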
     */
    public static class MinimumCardinalityDatasetProvenance extends DatasetProvenance {
        private static final long serialVersionUID = 1L;

        private static final String MIN_CARDINALITY = "min-cardinality";

        private final IntProvenance minCardinality;

        /**
         * Constructs the provenance for the supplied pruned dataset.
         * @param dataset The dataset.
         */
        <T extends Output<T>> MinimumCardinalityDatasetProvenance(MinimumCardinalityDataset<T> dataset) {
            super(dataset.sourceProvenance, new ListProvenance<>(), dataset);
            this.minCardinality = new IntProvenance(MIN_CARDINALITY, dataset.minCardinality);
        }

        /**
         * Deserialization constructor.
         * @param map The provenances.
         */
        public MinimumCardinalityDatasetProvenance(Map<String, Provenance> map) {
            super(map);
            this.minCardinality = ObjectProvenance.checkAndExtractProvenance(map, MIN_CARDINALITY, IntProvenance.class, MinimumCardinalityDatasetProvenance.class.getSimpleName());
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof MinimumCardinalityDatasetProvenance)) return false;
            if (!super.equals(o)) return false;
            MinimumCardinalityDatasetProvenance other = (MinimumCardinalityDatasetProvenance) o;
            return minCardinality.equals(other.minCardinality);
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), minCardinality);
        }

        @Override
        protected List<Pair<String, Provenance>> allProvenances() {
            List<Pair<String, Provenance>> provenances = super.allProvenances();
            provenances.add(new Pair<>(MIN_CARDINALITY, minCardinality));
            return provenances;
        }
    }
}