001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.transform.transformations; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance; 023import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance; 024import org.tribuo.transform.TransformStatistics; 025import org.tribuo.transform.Transformation; 026import org.tribuo.transform.TransformationProvenance; 027import org.tribuo.transform.Transformer; 028 029import java.util.Arrays; 030import java.util.Collections; 031import java.util.HashMap; 032import java.util.Map; 033import java.util.Objects; 034 035/** 036 * A Transformation which bins values. 037 * <p> 038 * Three binning types are implemented: 039 * <ul> 040 * <li>Equal width bins, based on the observed min and max.</li> 041 * <li>Equal frequency bins, based on the observed data. </li> 042 * <li>Standard deviation width bins, based on the observed standard deviation and mean.</li> 043 * </ul> 044 * <p> 045 * The equal frequency {@link TransformStatistics} needs to 046 * store all the observed feature values, and thus has much higher 047 * memory usage than all other binning types. 048 * <p> 049 * The binned values are in the range [1, numBins]. 050 */ 051public final class BinningTransformation implements Transformation { 052 053 /** 054 * The allowed binning types. 055 */ 056 public enum BinningType { EQUAL_WIDTH, EQUAL_FREQUENCY, STD_DEVS } 057 058 private static final String NUM_BINS = "numBins"; 059 private static final String TYPE = "type"; 060 061 @Config(description="Number of bins.") 062 private int numBins; 063 064 @Config(description="Binning algorithm to use.") 065 private BinningType type; 066 067 /** 068 * For olcut. 069 */ 070 private BinningTransformation() { } 071 072 private BinningTransformation(BinningType type, int numBins) { 073 this.type = type; 074 this.numBins = numBins; 075 postConfig(); 076 } 077 078 /** 079 * Used by the OLCUT configuration system, and should not be called by external code. 080 */ 081 @Override 082 public void postConfig() { 083 if (numBins < 2) { 084 throw new IllegalArgumentException("Number of bins must be 2 or greater, found " + numBins); 085 } else if (type == BinningType.STD_DEVS && ((numBins & 1) == 1)) { 086 throw new IllegalArgumentException("Std dev must have an even number of bins, found " + numBins); 087 } 088 } 089 090 @Override 091 public TransformStatistics createStats() { 092 switch (type) { 093 case EQUAL_WIDTH: 094 return new EqualWidthStats(numBins); 095 case EQUAL_FREQUENCY: 096 return new EqualFreqStats(numBins); 097 case STD_DEVS: 098 return new StdDevStats(numBins); 099 default: 100 throw new IllegalStateException("Unknown binning type " + type); 101 } 102 } 103 104 @Override 105 public TransformationProvenance getProvenance() { 106 return new BinningTransformationProvenance(this); 107 } 108 109 /** 110 * Provenance for {@link BinningTransformation}. 111 */ 112 public final static class BinningTransformationProvenance implements TransformationProvenance { 113 private static final long serialVersionUID = 1L; 114 115 private final IntProvenance numBins; 116 private final EnumProvenance<BinningType> type; 117 118 BinningTransformationProvenance(BinningTransformation host) { 119 this.numBins = new IntProvenance(NUM_BINS,host.numBins); 120 this.type = new EnumProvenance<>(TYPE,host.type); 121 } 122 123 @SuppressWarnings("unchecked") // Enum cast 124 public BinningTransformationProvenance(Map<String,Provenance> map) { 125 numBins = ObjectProvenance.checkAndExtractProvenance(map,NUM_BINS,IntProvenance.class,BinningTransformationProvenance.class.getSimpleName()); 126 type = ObjectProvenance.checkAndExtractProvenance(map,TYPE,EnumProvenance.class,BinningTransformationProvenance.class.getSimpleName()); 127 } 128 129 @Override 130 public String getClassName() { 131 return BinningTransformation.class.getName(); 132 } 133 134 @Override 135 public boolean equals(Object o) { 136 if (this == o) return true; 137 if (!(o instanceof BinningTransformationProvenance)) return false; 138 BinningTransformationProvenance pairs = (BinningTransformationProvenance) o; 139 return numBins.equals(pairs.numBins) && 140 type.equals(pairs.type); 141 } 142 143 @Override 144 public int hashCode() { 145 return Objects.hash(numBins, type); 146 } 147 148 @Override 149 public Map<String, Provenance> getConfiguredParameters() { 150 Map<String,Provenance> map = new HashMap<>(); 151 map.put(NUM_BINS,numBins); 152 map.put(TYPE,type); 153 return Collections.unmodifiableMap(map); 154 } 155 } 156 157 @Override 158 public String toString() { 159 return "BinningTransformation(type="+type+",numBins="+numBins+")"; 160 } 161 162 /** 163 * Returns a BinningTransformation which generates 164 * fixed equal width bins between the observed min and max 165 * values. 166 * <p> 167 * Values outside the observed range are clamped to either 168 * the minimum or maximum bin. Bins are numbered in the range 169 * [1,numBins]. 170 * @param numBins The number of bins to generate. 171 * @return An equal width binning. 172 */ 173 public static BinningTransformation equalWidth(int numBins) { 174 return new BinningTransformation(BinningType.EQUAL_WIDTH,numBins); 175 } 176 177 /** 178 * Returns a BinningTransformation which generates 179 * bins which contain the same amount of training data 180 * that is, each bin has an equal probability of occurrence 181 * in the training data. 182 * <p> 183 * Values outside the observed range are clamped to either 184 * the minimum or maximum bin. Bins are numbered in the range 185 * [1,numBins]. 186 * @param numBins The number of bins to generate. 187 * @return An equal frequency binning. 188 */ 189 public static BinningTransformation equalFrequency(int numBins) { 190 return new BinningTransformation(BinningType.EQUAL_FREQUENCY,numBins); 191 } 192 193 /** 194 * Returns a BinningTransformation which generates bins 195 * based on the observed standard deviation of the training 196 * data. Each bin is a standard deviation wide, except for 197 * the bins at the edges which encompass all lower or higher 198 * values. 199 * <p> 200 * Bins are numbered in the range [1,numDeviations*2]. The middle two 201 * bins are either side of the mean, the lowest bin is the mean minus 202 * numDeviations * observed standard deviation, the highest bin is the 203 * mean plus numDeviations * observed standard deviation. 204 * @param numDeviations The number of standard deviations to bin. 205 * @return A standard deviation based binning. 206 */ 207 public static BinningTransformation stdDevs(int numDeviations) { 208 return new BinningTransformation(BinningType.STD_DEVS,numDeviations*2); 209 } 210 211 private static class EqualWidthStats implements TransformStatistics { 212 private final int numBins; 213 214 private double min = Double.POSITIVE_INFINITY; 215 private double max = Double.NEGATIVE_INFINITY; 216 217 public EqualWidthStats(int numBins) { 218 this.numBins = numBins; 219 } 220 221 @Override 222 public void observeValue(double value) { 223 if (value < min) { 224 min = value; 225 } 226 if (value > max) { 227 max = value; 228 } 229 } 230 231 @Override 232 public void observeSparse() { } 233 234 @Override 235 public void observeSparse(int count) { } 236 237 @Override 238 public Transformer generateTransformer() { 239 double range = Math.abs(max - min); 240 double increment = range / numBins; 241 double[] bins = new double[numBins]; 242 double[] values = new double[numBins]; 243 244 for (int i = 0; i < bins.length; i++) { 245 bins[i] = min + ((i+1) * increment); 246 values[i] = i+1; 247 } 248 249 return new BinningTransformer(BinningType.EQUAL_WIDTH,bins,values); 250 } 251 252 @Override 253 public String toString() { 254 return "EqualWidthStatistics(min="+min+",max="+max+",numBins="+numBins+")"; 255 } 256 } 257 258 private static class EqualFreqStats implements TransformStatistics { 259 private static final int DEFAULT_SIZE = 50; 260 private final int numBins; 261 262 private double[] observedValues; 263 private int count; 264 265 public EqualFreqStats(int numBins) { 266 this.numBins = numBins; 267 this.observedValues = new double[DEFAULT_SIZE]; 268 this.count = 0; 269 } 270 271 @Override 272 public void observeValue(double value) { 273 if (observedValues.length == count + 1) { 274 growArray(); 275 } 276 observedValues[count] = value; 277 count++; 278 } 279 280 protected void growArray(int minCapacity) { 281 int newCapacity = newCapacity(minCapacity); 282 observedValues = Arrays.copyOf(observedValues,newCapacity); 283 } 284 285 /** 286 * Returns a capacity at least as large as the given minimum capacity. 287 * Returns the current capacity increased by 50% if that suffices. 288 * Will not return a capacity greater than MAX_ARRAY_SIZE unless 289 * the given minimum capacity is greater than MAX_ARRAY_SIZE. 290 * 291 * @param minCapacity the desired minimum capacity 292 * @throws OutOfMemoryError if minCapacity is less than zero 293 */ 294 private int newCapacity(int minCapacity) { 295 // overflow-conscious code 296 int oldCapacity = observedValues.length; 297 int newCapacity = oldCapacity + (oldCapacity >> 1); 298 if (newCapacity - minCapacity <= 0) { 299 if (minCapacity < 0) // overflow 300 throw new OutOfMemoryError(); 301 return minCapacity; 302 } 303 return newCapacity; 304 } 305 306 protected void growArray() { 307 growArray(count+1); 308 } 309 310 @Override 311 public void observeSparse() { } 312 313 @Override 314 public void observeSparse(int count) { } 315 316 @Override 317 public Transformer generateTransformer() { 318 if (numBins > observedValues.length) { 319 throw new IllegalStateException("Needs more values than bins, requested " + numBins + " bins, but only found " + observedValues.length + " values."); 320 } 321 Arrays.sort(observedValues,0,count); 322 double[] bins = new double[numBins]; 323 double[] values = new double[numBins]; 324 int increment = count / numBins; 325 for (int i = 0; i < numBins-1; i++) { 326 bins[i] = observedValues[(i+1)*increment]; 327 values[i] = i+1; 328 } 329 bins[numBins-1] = observedValues[count-1]; 330 values[numBins-1] = numBins; 331 return new BinningTransformer(BinningType.EQUAL_FREQUENCY, bins, values); 332 } 333 334 @Override 335 public String toString() { 336 return "EqualFreqStatistics(count="+count+",numBins="+numBins+")"; 337 } 338 } 339 340 private static class StdDevStats implements TransformStatistics { 341 private final int numBins; 342 343 private double mean = 0; 344 private double sumSquares = 0; 345 private long count = 0; 346 347 public StdDevStats(int numBins) { 348 this.numBins = numBins; 349 } 350 351 @Override 352 public void observeValue(double value) { 353 count++; 354 double delta = value - mean; 355 mean += delta / count; 356 double delta2 = value - mean; 357 sumSquares += delta * delta2; 358 } 359 360 @Override 361 public void observeSparse() { } 362 363 @Override 364 public void observeSparse(int count) { } 365 366 @Override 367 public Transformer generateTransformer() { 368 double[] bins = new double[numBins]; 369 double[] values = new double[numBins]; 370 371 double stdDev = Math.sqrt(sumSquares/(count-1)); 372 373 int binCount = -(numBins/2); 374 375 for (int i = 0; i < bins.length; i++) { 376 values[i] = i+1; 377 binCount++; 378 bins[i] = mean + (binCount * stdDev); 379 } 380 381 return new BinningTransformer(BinningType.STD_DEVS,bins,values); 382 } 383 384 @Override 385 public String toString() { 386 return "StdDevStatistics(mean="+mean+",sumSquares="+sumSquares+",count="+count+",numBins="+numBins+")"; 387 } 388 } 389 390 private static class BinningTransformer implements Transformer { 391 private static final long serialVersionUID = 1L; 392 393 private final BinningType type; 394 private final double[] bins; 395 private final double[] values; 396 397 public BinningTransformer(BinningType type, double[] bins, double[] values) { 398 this.type = type; 399 this.bins = bins; 400 this.values = values; 401 } 402 403 @Override 404 public double transform(double input) { 405 if (input > bins[bins.length-1]) { 406 return values[bins.length-1]; 407 } else { 408 int index = Arrays.binarySearch(bins,input); 409 if (index < 0) { 410 return values[- 1 - index]; 411 } else { 412 return values[index]; 413 } 414 } 415 } 416 417 @Override 418 public String toString() { 419 return "BinningTransformer(type="+type+",bins="+Arrays.toString(bins)+",values="+Arrays.toString(values)+")"; 420 } 421 } 422}