/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.MutableNumber;
import org.tribuo.util.Util;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.SplittableRandom;

/**
 * Stores information about Categorical features.
 * <p>
 * Contains a mapping from values to observed counts for that value, has
 * an initial optimisation for the binary case to reduce memory consumption.
 * </p>
 * <p>
 * Can be transformed into a {@link RealInfo} if there are too many unique observed values.
 * </p>
 * <p>
 * Does not contain an id number, but can be transformed into {@link CategoricalIDInfo} which
 * does contain an id number.
 * </p>
 * <p>
 * Note that the synchronization in this class only protects instantiation where CDF and values
 * are recomputed. Care should be taken if data is read while {@link #observe(double)} is called.
 * </p>
 */
public class CategoricalInfo extends SkeletalVariableInfo {
    private static final long serialVersionUID = 2L;

    // Shared default returned by getObservationCount for values missing from the map.
    private static final MutableLong ZERO = new MutableLong(0);
    /**
     * The default threshold for converting a categorical info into a {@link RealInfo}.
     */
    public static final int THRESHOLD = 50;
    // Tolerance used when comparing a value against the single stored observedValue.
    private static final double COMPARISON_THRESHOLD = 1e-10;

    /**
     * The occurrence counts of each value.
     */
    protected Map<Double,MutableLong> valueCounts = null;

    /**
     * The observed value if it's only seen a single one.
     */
    protected double observedValue = Double.NaN;

    /**
     * The count of the observed value if it's only seen a single one.
     */
    protected long observedCount = 0;

    // These variables are used in the sampling methods, and regenerated after serialization if a sample is required.
    /**
     * The values array.
     */
    protected transient double[] values = null;
    /**
     * The total number of observations (including zeros).
     */
    protected transient long totalObservations = -1;
    /**
     * The CDF to sample from.
     */
    protected transient double[] cdf = null;

    /**
     * Constructs a new empty categorical info for the supplied feature name.
     * @param name The feature name.
     */
    public CategoricalInfo(String name) {
        super(name);
    }

    /**
     * Constructs a deep copy of the supplied categorical info.
     * @param info The info to copy.
     */
    protected CategoricalInfo(CategoricalInfo info) {
        this(info,info.name);
    }

    /**
     * Constructs a deep copy of the supplied categorical info, with the new feature name.
     * <p>
     * The transient sampling caches are not copied, they are rebuilt on demand.
     * </p>
     * @param info The info to copy.
     * @param newName The new feature name.
     */
    protected CategoricalInfo(CategoricalInfo info, String newName) {
        super(newName,info.count);
        if (info.valueCounts != null) {
            valueCounts = MutableNumber.copyMap(info.valueCounts);
        } else {
            observedValue = info.observedValue;
            observedCount = info.observedCount;
        }
    }

    /**
     * Records a single observation of this feature.
     * <p>
     * Values equal to zero are ignored. The first distinct non-zero value is stored
     * in {@link #observedValue}, and the map is only allocated once a second distinct
     * value is seen. Every recorded observation invalidates the cached sampling state.
     * </p>
     * @param value The observed value.
     */
    @Override
    protected void observe(double value) {
        if (value != 0.0) {
            super.observe(value);
            if (valueCounts != null) {
                MutableLong count = valueCounts.computeIfAbsent(value, k -> new MutableLong());
                count.increment();
            } else {
                if (Double.isNaN(observedValue)) {
                    observedValue = value;
                    observedCount++;
                } else if (Math.abs(value - observedValue) < COMPARISON_THRESHOLD) {
                    observedCount++;
                } else {
                    // Observed two values for this CategoricalInfo, now it needs a HashMap.
                    valueCounts = new HashMap<>(4);
                    valueCounts.put(observedValue, new MutableLong(observedCount));
                    valueCounts.put(value, new MutableLong(1));
                    observedValue = Double.NaN;
                    observedCount = 0;
                }
            }
            // A new observation changes the counts (and possibly the value set), so the
            // CDF is stale as well as the values array. Resetting all three together stops
            // frequencyBasedSample from pairing an old CDF with a null or rebuilt values array.
            values = null;
            cdf = null;
            totalObservations = -1;
        }
    }

    /**
     * Gets the number of times a specific value was observed, and zero if this value is unknown.
     * <p>
     * When the map is in use the lookup is an exact match on the value, otherwise the
     * value is compared against the single stored value with a small tolerance.
     * </p>
     * @param value The value to check.
     * @return The count of times this value was observed, zero otherwise.
     */
    public long getObservationCount(double value) {
        if (valueCounts != null) {
            return valueCounts.getOrDefault(value, ZERO).longValue();
        } else {
            if (Math.abs(value - observedValue) < COMPARISON_THRESHOLD) {
                return observedCount;
            } else {
                return 0;
            }
        }
    }

    /**
     * Gets the number of unique values this CategoricalInfo has observed.
     * @return An int representing the number of unique values.
     */
    public int getUniqueObservations() {
        if (valueCounts != null) {
            return valueCounts.size();
        } else {
            if (Double.isNaN(observedValue)) {
                return 0;
            } else {
                return 1;
            }
        }
    }

    /**
     * Generates a {@link RealInfo} using the currently observed counts to calculate
     * the min, max, mean and variance.
     * @return A RealInfo representing the data in this CategoricalInfo.
     */
    public RealInfo generateRealInfo() {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        double sum = 0.0;
        double sumSquares = 0.0;
        double mean;

        if (valueCounts != null) {
            // First pass: range and weighted sum.
            for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) {
                double value = e.getKey();
                double valCount = e.getValue().longValue();
                if (value > max) {
                    max = value;
                }
                if (value < min) {
                    min = value;
                }
                sum += value * valCount;
            }
            mean = sum / count;

            // Second pass: weighted sum of squared deviations from the mean.
            for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) {
                double value = e.getKey();
                double valCount = e.getValue().longValue();
                sumSquares += (value - mean) * (value - mean) * valCount;
            }
        } else {
            // Only a single distinct value (or none, in which case these are NaN).
            min = observedValue;
            max = observedValue;
            mean = observedValue;
            sumSquares = 0.0;
        }

        return new RealInfo(name,count,max,min,mean,sumSquares);
    }

    /**
     * Returns a deep copy of this info.
     * @return A copy of this CategoricalInfo.
     */
    @Override
    public CategoricalInfo copy() {
        return new CategoricalInfo(this);
    }

    /**
     * Creates a {@link CategoricalIDInfo} from this info and the supplied id.
     * @param id The id number.
     * @return A CategoricalIDInfo wrapping this info's data.
     */
    @Override
    public CategoricalIDInfo makeIDInfo(int id) {
        return new CategoricalIDInfo(this,id);
    }

    /**
     * Returns a deep copy of this info with a different feature name.
     * @param newName The new feature name.
     * @return A renamed copy of this CategoricalInfo.
     */
    @Override
    public CategoricalInfo rename(String newName) {
        return new CategoricalInfo(this,newName);
    }

    /**
     * Samples a value uniformly from the observed values plus zero,
     * rebuilding the values array first if it has been invalidated.
     * @param rng The RNG to use.
     * @return The sampled value.
     */
    @Override
    public synchronized double uniformSample(SplittableRandom rng) {
        if (values == null) {
            regenerateValues();
        }
        int sampleIdx = rng.nextInt(values.length);
        return values[sampleIdx];
    }

    /**
     * Samples a value from this feature according to the frequency of observation.
     * <p>
     * The CDF is rebuilt when it is missing or when the supplied total differs from
     * the total it was last built with.
     * </p>
     * @param rng The RNG to use.
     * @param totalObservations The observations including the implicit zeros.
     * @return The sampled value.
     */
    public double frequencyBasedSample(SplittableRandom rng, long totalObservations) {
        if ((totalObservations != this.totalObservations) || (cdf == null)) {
            regenerateCDF(totalObservations);
        }
        int lookup = Util.sampleFromCDF(cdf,rng);
        return values[lookup];
    }

    /**
     * Samples a value from this feature according to the frequency of observation.
     * <p>
     * The CDF is rebuilt when it is missing or when the supplied total differs from
     * the total it was last built with.
     * </p>
     * @param rng The RNG to use.
     * @param totalObservations The observations including the implicit zeros.
     * @return The sampled value.
     */
    public double frequencyBasedSample(Random rng, long totalObservations) {
        if ((totalObservations != this.totalObservations) || (cdf == null)) {
            regenerateCDF(totalObservations);
        }
        int lookup = Util.sampleFromCDF(cdf,rng);
        return values[lookup];
    }

    /**
     * Generates the CDF for sampling.
     * <p>
     * Also rebuilds the values array so that index {@code i} of the CDF corresponds to
     * {@code values[i]}; index zero is always the value zero.
     * </p>
     * @param newTotalObservations The new number of total observations including the implicit zeros.
     */
    private synchronized void regenerateCDF(long newTotalObservations) {
        long[] counts;
        if (valueCounts != null) {
            // This is tricksy as if valueCounts contains zero that means
            // we could have both observed zeros and unobserved zeros.
            if (valueCounts.containsKey(0.0)) {
                values = new double[valueCounts.size()];
                counts = new long[valueCounts.size()];
            } else {
                values = new double[valueCounts.size()+1];
                counts = new long[valueCounts.size()+1];
            }
            values[0] = 0;
            counts[0] = newTotalObservations;
            int counter = 1;
            long total = 0;
            for (Map.Entry<Double,MutableLong> e : valueCounts.entrySet()) {
                if (e.getKey() != 0.0) {
                    values[counter] = e.getKey();
                    counts[counter] = e.getValue().longValue();
                    total += counts[counter];
                    counter++;
                }
            }
            // Set the zero counts appropriately
            counts[0] -= total;
        } else {
            if (Double.isNaN(observedValue) || observedValue == 0.0) {
                values = new double[1];
                counts = new long[1];
                values[0] = 0;
                counts[0] = newTotalObservations;
            } else {
                values = new double[2];
                counts = new long[2];
                values[0] = 0;
                counts[0] = newTotalObservations - observedCount;
                values[1] = observedValue;
                counts[1] = observedCount;
            }
        }
        long sum = 0;
        for (int i = 0; i < counts.length; i++) {
            sum += counts[i];
        }
        // NOTE(review): counts[0] is derived as newTotalObservations minus the other counts in
        // every branch above, so this sum appears to equal newTotalObservations by construction
        // and a too-small total (negative counts[0]) is not rejected here - confirm intent.
        if (sum != newTotalObservations) {
            throw new IllegalStateException("Total counts = " + sum + ", supplied value = " + newTotalObservations);
        }
        cdf = Util.generateCDF(counts,sum);
        totalObservations = newTotalObservations;
    }

    /**
     * Recomputes the values array.
     * <p>
     * The array always contains zero, plus each observed value.
     * </p>
     */
    private synchronized void regenerateValues() {
        //
        // Recompute values array
        if (valueCounts != null) {
            int counter;
            if (valueCounts.containsKey(0.0)) {
                // Zero is already a key, so it is written by the loop below.
                values = new double[valueCounts.size()];
                counter = 0;
            } else {
                // Reserve slot zero for the implicit zero value.
                values = new double[valueCounts.size() + 1];
                values[0] = 0;
                counter = 1;
            }
            for (Double key : valueCounts.keySet()) {
                values[counter] = key;
                counter++;
            }
        } else {
            if (Double.isNaN(observedValue) || observedValue == 0.0) {
                values = new double[1];
                values[0] = 0;
            } else {
                values = new double[2];
                values[0] = 0;
                values[1] = observedValue;
            }
        }
    }

    /**
     * Returns a string containing the feature name, the count and the observed value counts.
     * @return A String describing this info.
     */
    @Override
    public String toString() {
        if (valueCounts != null) {
            return "CategoricalFeature(name=" + name + ",count=" + count + ",map=" + valueCounts.toString() + ")";
        } else {
            return "CategoricalFeature(name=" + name + ",count=" + count + ",map={" +observedValue+","+observedCount+"})";
        }
    }

    /**
     * Reads the serialized fields and resets the transient sampling caches,
     * so they are rebuilt the next time a sample is requested.
     * @param in The stream to read from.
     * @throws IOException If the stream could not be read.
     * @throws ClassNotFoundException If a class in the stream could not be found.
     */
    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        totalObservations = -1;
        values = null;
        cdf = null;
    }
}