001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.transform.transformations;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
023import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
024import org.tribuo.transform.TransformStatistics;
025import org.tribuo.transform.Transformation;
026import org.tribuo.transform.TransformationProvenance;
027import org.tribuo.transform.Transformer;
028
029import java.util.Arrays;
030import java.util.Collections;
031import java.util.HashMap;
032import java.util.Map;
033import java.util.Objects;
034
035/**
036 * A Transformation which bins values.
037 * <p>
038 * Three binning types are implemented:
039 * <ul>
040 * <li>Equal width bins, based on the observed min and max.</li>
041 * <li>Equal frequency bins, based on the observed data.   </li>
042 * <li>Standard deviation width bins, based on the observed standard deviation and mean.</li>
043 * </ul>
044 * <p>
045 * The equal frequency {@link TransformStatistics} needs to
046 * store all the observed feature values, and thus has much higher
047 * memory usage than all other binning types.
048 * <p>
049 * The binned values are in the range [1, numBins].
050 */
051public final class BinningTransformation implements Transformation {
052
053    /**
054     * The allowed binning types.
055     */
056    public enum BinningType { EQUAL_WIDTH, EQUAL_FREQUENCY, STD_DEVS }
057
058    private static final String NUM_BINS = "numBins";
059    private static final String TYPE = "type";
060
061    @Config(description="Number of bins.")
062    private int numBins;
063
064    @Config(description="Binning algorithm to use.")
065    private BinningType type;
066
067    /**
068     * For olcut.
069     */
070    private BinningTransformation() { }
071
072    private BinningTransformation(BinningType type, int numBins) {
073        this.type = type;
074        this.numBins = numBins;
075        postConfig();
076    }
077
078    /**
079     * Used by the OLCUT configuration system, and should not be called by external code.
080     */
081    @Override
082    public void postConfig() {
083        if (numBins < 2) {
084            throw new IllegalArgumentException("Number of bins must be 2 or greater, found " + numBins);
085        } else if (type == BinningType.STD_DEVS && ((numBins & 1) == 1)) {
086            throw new IllegalArgumentException("Std dev must have an even number of bins, found " + numBins);
087        }
088    }
089
090    @Override
091    public TransformStatistics createStats() {
092        switch (type) {
093            case EQUAL_WIDTH:
094                return new EqualWidthStats(numBins);
095            case EQUAL_FREQUENCY:
096                return new EqualFreqStats(numBins);
097            case STD_DEVS:
098                return new StdDevStats(numBins);
099            default:
100                throw new IllegalStateException("Unknown binning type " + type);
101        }
102    }
103
104    @Override
105    public TransformationProvenance getProvenance() {
106        return new BinningTransformationProvenance(this);
107    }
108
109    /**
110     * Provenance for {@link BinningTransformation}.
111     */
112    public final static class BinningTransformationProvenance implements TransformationProvenance {
113        private static final long serialVersionUID = 1L;
114
115        private final IntProvenance numBins;
116        private final EnumProvenance<BinningType> type;
117
118        BinningTransformationProvenance(BinningTransformation host) {
119            this.numBins = new IntProvenance(NUM_BINS,host.numBins);
120            this.type = new EnumProvenance<>(TYPE,host.type);
121        }
122
123        @SuppressWarnings("unchecked") // Enum cast
124        public BinningTransformationProvenance(Map<String,Provenance> map) {
125            numBins = ObjectProvenance.checkAndExtractProvenance(map,NUM_BINS,IntProvenance.class,BinningTransformationProvenance.class.getSimpleName());
126            type = ObjectProvenance.checkAndExtractProvenance(map,TYPE,EnumProvenance.class,BinningTransformationProvenance.class.getSimpleName());
127        }
128
129        @Override
130        public String getClassName() {
131            return BinningTransformation.class.getName();
132        }
133
134        @Override
135        public boolean equals(Object o) {
136            if (this == o) return true;
137            if (!(o instanceof BinningTransformationProvenance)) return false;
138            BinningTransformationProvenance pairs = (BinningTransformationProvenance) o;
139            return numBins.equals(pairs.numBins) &&
140                    type.equals(pairs.type);
141        }
142
143        @Override
144        public int hashCode() {
145            return Objects.hash(numBins, type);
146        }
147
148        @Override
149        public Map<String, Provenance> getConfiguredParameters() {
150            Map<String,Provenance> map = new HashMap<>();
151            map.put(NUM_BINS,numBins);
152            map.put(TYPE,type);
153            return Collections.unmodifiableMap(map);
154        }
155    }
156
157    @Override
158    public String toString() {
159        return "BinningTransformation(type="+type+",numBins="+numBins+")";
160    }
161
162    /**
163     * Returns a BinningTransformation which generates
164     * fixed equal width bins between the observed min and max
165     * values.
166     * <p>
167     * Values outside the observed range are clamped to either
168     * the minimum or maximum bin. Bins are numbered in the range
169     * [1,numBins].
170     * @param numBins The number of bins to generate.
171     * @return An equal width binning.
172     */
173    public static BinningTransformation equalWidth(int numBins) {
174        return new BinningTransformation(BinningType.EQUAL_WIDTH,numBins);
175    }
176
177    /**
178     * Returns a BinningTransformation which generates
179     * bins which contain the same amount of training data
180     * that is, each bin has an equal probability of occurrence
181     * in the training data.
182     * <p>
183     * Values outside the observed range are clamped to either
184     * the minimum or maximum bin. Bins are numbered in the range
185     * [1,numBins].
186     * @param numBins The number of bins to generate.
187     * @return An equal frequency binning.
188     */
189    public static BinningTransformation equalFrequency(int numBins) {
190        return new BinningTransformation(BinningType.EQUAL_FREQUENCY,numBins);
191    }
192
193    /**
194     * Returns a BinningTransformation which generates bins
195     * based on the observed standard deviation of the training
196     * data. Each bin is a standard deviation wide, except for
197     * the bins at the edges which encompass all lower or higher
198     * values.
199     * <p>
200     * Bins are numbered in the range [1,numDeviations*2]. The middle two
201     * bins are either side of the mean, the lowest bin is the mean minus
202     * numDeviations * observed standard deviation, the highest bin is the
203     * mean plus numDeviations * observed standard deviation.
204     * @param numDeviations The number of standard deviations to bin.
205     * @return A standard deviation based binning.
206     */
207    public static BinningTransformation stdDevs(int numDeviations) {
208        return new BinningTransformation(BinningType.STD_DEVS,numDeviations*2);
209    }
210
211    private static class EqualWidthStats implements TransformStatistics {
212        private final int numBins;
213
214        private double min = Double.POSITIVE_INFINITY;
215        private double max = Double.NEGATIVE_INFINITY;
216
217        public EqualWidthStats(int numBins) {
218            this.numBins = numBins;
219        }
220
221        @Override
222        public void observeValue(double value) {
223            if (value < min) {
224                min = value;
225            }
226            if (value > max) {
227                max = value;
228            }
229        }
230
231        @Override
232        public void observeSparse() { }
233
234        @Override
235        public void observeSparse(int count) { }
236
237        @Override
238        public Transformer generateTransformer() {
239            double range = Math.abs(max - min);
240            double increment = range / numBins;
241            double[] bins = new double[numBins];
242            double[] values = new double[numBins];
243
244            for (int i = 0; i < bins.length; i++) {
245                bins[i] = min + ((i+1) * increment);
246                values[i] = i+1;
247            }
248
249            return new BinningTransformer(BinningType.EQUAL_WIDTH,bins,values);
250        }
251
252        @Override
253        public String toString() {
254            return "EqualWidthStatistics(min="+min+",max="+max+",numBins="+numBins+")";
255        }
256    }
257
258    private static class EqualFreqStats implements TransformStatistics {
259        private static final int DEFAULT_SIZE = 50;
260        private final int numBins;
261
262        private double[] observedValues;
263        private int count;
264
265        public EqualFreqStats(int numBins) {
266            this.numBins = numBins;
267            this.observedValues = new double[DEFAULT_SIZE];
268            this.count = 0;
269        }
270
271        @Override
272        public void observeValue(double value) {
273            if (observedValues.length == count + 1) {
274                growArray();
275            }
276            observedValues[count] = value;
277            count++;
278        }
279
280        protected void growArray(int minCapacity) {
281            int newCapacity = newCapacity(minCapacity);
282            observedValues = Arrays.copyOf(observedValues,newCapacity);
283        }
284
285        /**
286         * Returns a capacity at least as large as the given minimum capacity.
287         * Returns the current capacity increased by 50% if that suffices.
288         * Will not return a capacity greater than MAX_ARRAY_SIZE unless
289         * the given minimum capacity is greater than MAX_ARRAY_SIZE.
290         *
291         * @param minCapacity the desired minimum capacity
292         * @throws OutOfMemoryError if minCapacity is less than zero
293         */
294        private int newCapacity(int minCapacity) {
295            // overflow-conscious code
296            int oldCapacity = observedValues.length;
297            int newCapacity = oldCapacity + (oldCapacity >> 1);
298            if (newCapacity - minCapacity <= 0) {
299                if (minCapacity < 0) // overflow
300                    throw new OutOfMemoryError();
301                return minCapacity;
302            }
303            return newCapacity;
304        }
305
306        protected void growArray() {
307            growArray(count+1);
308        }
309
310        @Override
311        public void observeSparse() { }
312
313        @Override
314        public void observeSparse(int count) { }
315
316        @Override
317        public Transformer generateTransformer() {
318            if (numBins > observedValues.length) {
319                throw new IllegalStateException("Needs more values than bins, requested " + numBins + " bins, but only found " + observedValues.length + " values.");
320            }
321            Arrays.sort(observedValues,0,count);
322            double[] bins = new double[numBins];
323            double[] values = new double[numBins];
324            int increment = count / numBins;
325            for (int i = 0; i < numBins-1; i++) {
326                bins[i] = observedValues[(i+1)*increment];
327                values[i] = i+1;
328            }
329            bins[numBins-1] = observedValues[count-1];
330            values[numBins-1] = numBins;
331            return new BinningTransformer(BinningType.EQUAL_FREQUENCY, bins, values);
332        }
333
334        @Override
335        public String toString() {
336            return "EqualFreqStatistics(count="+count+",numBins="+numBins+")";
337        }
338    }
339
340    private static class StdDevStats implements TransformStatistics {
341        private final int numBins;
342
343        private double mean = 0;
344        private double sumSquares = 0;
345        private long count = 0;
346
347        public StdDevStats(int numBins) {
348            this.numBins = numBins;
349        }
350
351        @Override
352        public void observeValue(double value) {
353            count++;
354            double delta = value - mean;
355            mean += delta / count;
356            double delta2 = value - mean;
357            sumSquares += delta * delta2;
358        }
359
360        @Override
361        public void observeSparse() { }
362
363        @Override
364        public void observeSparse(int count) { }
365
366        @Override
367        public Transformer generateTransformer() {
368            double[] bins = new double[numBins];
369            double[] values = new double[numBins];
370
371            double stdDev = Math.sqrt(sumSquares/(count-1));
372
373            int binCount = -(numBins/2);
374
375            for (int i = 0; i < bins.length; i++) {
376                values[i] = i+1;
377                binCount++;
378                bins[i] = mean + (binCount * stdDev);
379            }
380
381            return new BinningTransformer(BinningType.STD_DEVS,bins,values);
382        }
383
384        @Override
385        public String toString() {
386            return "StdDevStatistics(mean="+mean+",sumSquares="+sumSquares+",count="+count+",numBins="+numBins+")";
387        }
388    }
389
390    private static class BinningTransformer implements Transformer {
391        private static final long serialVersionUID = 1L;
392
393        private final BinningType type;
394        private final double[] bins;
395        private final double[] values;
396
397        public BinningTransformer(BinningType type, double[] bins, double[] values) {
398            this.type = type;
399            this.bins = bins;
400            this.values = values;
401        }
402
403        @Override
404        public double transform(double input) {
405            if (input > bins[bins.length-1]) {
406                return values[bins.length-1];
407            } else {
408                int index = Arrays.binarySearch(bins,input);
409                if (index < 0) {
410                    return values[- 1 - index];
411                } else {
412                    return values[index];
413                }
414            }
415        }
416
417        @Override
418        public String toString() {
419            return "BinningTransformer(type="+type+",bins="+Arrays.toString(bins)+",values="+Arrays.toString(values)+")";
420        }
421    }
422}