/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.sequence;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Logger;

import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.FeatureMap;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.MutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.VariableInfo;
import org.tribuo.impl.ArrayExample;
import org.tribuo.impl.BinaryFeaturesExample;
import org.tribuo.provenance.DatasetProvenance;

import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;

/**
 * This class creates a pruned dataset in which low-frequency features (those that
 * occur fewer times than the supplied minimum cardinality) have been removed. This
 * can be useful when the dataset is very large due to many low-frequency features.
 * A new dataset is created so that the feature counts are recalculated and the
 * original, passed-in dataset is left unmodified. The returned dataset may contain
 * fewer sequence examples: if any example within a sequence example has no features
 * left after the minimum cardinality filter has been applied, the whole sequence
 * example is not added to the constructed dataset.
 *
 * @param <T> The type of the outputs in this {@link SequenceDataset}.
 */
public class MinimumCardinalitySequenceDataset<T extends Output<T>> extends ImmutableSequenceDataset<T> {
    private static final long serialVersionUID = 1L;

    private static final Logger logger = Logger.getLogger(MinimumCardinalitySequenceDataset.class.getName());

    private final int minCardinality;

    private int numExamplesRemoved = 0;

    private final Set<String> removedFeatureNames = new HashSet<>();

    /**
     * Constructs a pruned copy of the supplied sequence dataset, removing every
     * feature which occurs fewer than minCardinality times.
     *
     * @param sequenceDataset This dataset is left untouched and is used to populate
     *                        the constructed dataset.
     * @param minCardinality  Features with a frequency less than minCardinality
     *                        will be removed.
     */
    public MinimumCardinalitySequenceDataset(SequenceDataset<T> sequenceDataset, int minCardinality) {
        super(sequenceDataset.getProvenance(), sequenceDataset.getOutputFactory());
        this.minCardinality = minCardinality;

        MutableFeatureMap featureInfos = new MutableFeatureMap();

        List<Feature> features = new ArrayList<>();

        // Rebuild the data list using only the features that meet the minimum cardinality threshold.
        FeatureMap featureMap = sequenceDataset.getFeatureMap();
        for (SequenceExample<T> sequenceExample : sequenceDataset) {
            boolean add = true;
            List<Example<T>> newExamples = new ArrayList<>(sequenceExample.size());
            for (Example<T> example : sequenceExample) {
                features.clear();
                Example<T> newExample;
                if (example instanceof BinaryFeaturesExample) {
                    newExample = new BinaryFeaturesExample<>(example.getOutput());
                } else {
                    newExample = new ArrayExample<>(example.getOutput());
                }
                newExample.setWeight(example.getWeight());
                for (Feature feature : example) {
                    VariableInfo featureInfo = featureMap.get(feature.getName());
                    if (featureInfo == null || featureInfo.getCount() < minCardinality) {
                        // The feature info might be null if we see a feature at
                        // prediction time that we didn't see at training time.
                        removedFeatureNames.add(feature.getName());
                    } else {
                        features.add(feature);
                    }
                }
                newExample.addAll(features);
                if (newExample.size() > 0) {
                    if (!newExample.validateExample()) {
                        throw new IllegalStateException("Duplicate features found in example " + newExample.toString());
                    }
                    newExamples.add(newExample);
                } else {
                    numExamplesRemoved++;
                    add = false;
                    break;
                }
            }
            if (add) {
                SequenceExample<T> newSequenceExample = new SequenceExample<>(newExamples);
                data.add(newSequenceExample);
            }
        }

        // Copy out the feature infos above the threshold.
        for (VariableInfo info : featureMap) {
            if (info.getCount() >= minCardinality) {
                featureInfos.put(info.copy());
            }
        }

        this.outputIDInfo = sequenceDataset.getOutputIDInfo();
        this.featureIDMap = new ImmutableFeatureMap(featureInfos);

        if (numExamplesRemoved > 0) {
            logger.info(String.format(
                    "Filtered out %d sequence examples because at least one of their examples had zero features after the minimum frequency threshold was applied.",
                    numExamplesRemoved));
        }
    }

    /**
     * The feature names that were removed.
     *
     * @return The feature names.
     */
    public Set<String> getRemoved() {
        return removedFeatureNames;
    }

    /**
     * The number of sequence examples removed due to a lack of features.
     *
     * @return The number of removed examples.
     */
    public int getNumExamplesRemoved() {
        return numExamplesRemoved;
    }

    /**
     * The minimum cardinality threshold for the features.
     *
     * @return The cardinality threshold.
     */
    public int getMinCardinality() {
        return minCardinality;
    }

    @Override
    public DatasetProvenance getProvenance() {
        return new MinimumCardinalitySequenceDatasetProvenance(this);
    }

    /**
     * Provenance for {@link MinimumCardinalitySequenceDataset}.
     */
    public static class MinimumCardinalitySequenceDatasetProvenance extends DatasetProvenance {
        private static final long serialVersionUID = 1L;

        private static final String MIN_CARDINALITY = "min-cardinality";

        private final IntProvenance minCardinality;

        <T extends Output<T>> MinimumCardinalitySequenceDatasetProvenance(MinimumCardinalitySequenceDataset<T> dataset) {
            super(dataset.sourceProvenance, new ListProvenance<>(), dataset);
            this.minCardinality = new IntProvenance(MIN_CARDINALITY, dataset.minCardinality);
        }

        /**
         * Deserialization constructor.
         *
         * @param map The provenances.
         */
        public MinimumCardinalitySequenceDatasetProvenance(Map<String, Provenance> map) {
            super(map);
            this.minCardinality = ObjectProvenance.checkAndExtractProvenance(map, MIN_CARDINALITY, IntProvenance.class,
                    MinimumCardinalitySequenceDatasetProvenance.class.getSimpleName());
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof MinimumCardinalitySequenceDatasetProvenance)) {
                return false;
            }
            if (!super.equals(o)) {
                return false;
            }
            MinimumCardinalitySequenceDatasetProvenance other = (MinimumCardinalitySequenceDatasetProvenance) o;
            return minCardinality.equals(other.minCardinality);
        }

        @Override
        public int hashCode() {
            return Objects.hash(super.hashCode(), minCardinality);
        }

        @Override
        protected List<Pair<String, Provenance>> allProvenances() {
            List<Pair<String, Provenance>> provenances = super.allProvenances();
            provenances.add(new Pair<>(MIN_CARDINALITY, minCardinality));
            return provenances;
        }
    }
}
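// A minimal usage sketch (not part of the original source): assuming a previously
// constructed SequenceDataset<Label> named "train", features seen fewer than 5 times
// could be pruned and the removal statistics inspected like this:
//
//     MinimumCardinalitySequenceDataset<Label> pruned =
//             new MinimumCardinalitySequenceDataset<>(train, 5);
//     System.out.println("Removed feature names: " + pruned.getRemoved().size());
//     System.out.println("Dropped sequence examples: " + pruned.getNumExamplesRemoved());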