001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import com.oracle.labs.mlrg.olcut.util.MutableNumber;
021import org.tribuo.util.Util;
022
023import java.io.IOException;
024import java.util.HashMap;
025import java.util.Map;
026import java.util.Random;
027import java.util.SplittableRandom;
028
029/**
030 * Stores information about Categorical features.
031 * <p>
032 * Contains a mapping from values to observed counts for that value, has
033 * an initial optimisation for the binary case to reduce memory consumption.
034 * </p>
035 * <p>
036 * Can be transformed into a {@link RealInfo} if there are too many unique observed values.
037 * </p>
038 * <p>
039 * Does not contain an id number, but can be transformed into {@link CategoricalIDInfo} which
040 * does contain an id number.
041 * </p>
042 * <p>
043 * Note that the synchronization in this class only protects instantiation where CDF and values
044 * are recomputed. Care should be taken if data is read while {@link #observe(double)} is called.
045 * </p>
046 */
047public class CategoricalInfo extends SkeletalVariableInfo {
048    private static final long serialVersionUID = 2L;
049
050    private static final MutableLong ZERO = new MutableLong(0);
051    /**
052     * The default threshold for converting a categorical info into a {@link RealInfo}.
053     */
054    public static final int THRESHOLD = 50;
055    private static final double COMPARISON_THRESHOLD = 1e-10;
056
057    /**
058     * The occurrence counts of each value.
059     */
060    protected Map<Double,MutableLong> valueCounts = null;
061
062    /**
063     * The observed value if it's only seen a single one.
064     */
065    protected double observedValue = Double.NaN;
066
067    /**
068     * The count of the observed value if it's only seen a single one.
069     */
070    protected long observedCount = 0;
071
072    // These variables are used in the sampling methods, and regenerated after serialization if a sample is required.
073    /**
074     * The values array.
075     */
076    protected transient double[] values = null;
077    /**
078     * The total number of observations (including zeros).
079     */
080    protected transient long totalObservations = -1;
081    /**
082     * The CDF to sample from.
083     */
084    protected transient double[] cdf = null;
085
086    /**
087     * Constructs a new empty categorical info for the supplied feature name.
088     * @param name The feature name.
089     */
090    public CategoricalInfo(String name) {
091        super(name);
092    }
093
094    /**
095     * Constructs a deep copy of the supplied categorical info.
096     * @param info The info to copy.
097     */
098    protected CategoricalInfo(CategoricalInfo info) {
099        this(info,info.name);
100    }
101
102    /**
103     * Constructs a deep copy of the supplied categorical info, with the new feature name.
104     * @param info The info to copy.
105     * @param newName The new feature name.
106     */
107    protected CategoricalInfo(CategoricalInfo info, String newName) {
108        super(newName,info.count);
109        if (info.valueCounts != null) {
110            valueCounts = MutableNumber.copyMap(info.valueCounts);
111        } else {
112            observedValue = info.observedValue;
113            observedCount = info.observedCount;
114        }
115    }
116
117    @Override
118    protected void observe(double value) {
119        if (value != 0.0) {
120            super.observe(value);
121            if (valueCounts != null) {
122                MutableLong count = valueCounts.computeIfAbsent(value, k -> new MutableLong());
123                count.increment();
124            } else {
125                if (Double.isNaN(observedValue)) {
126                    observedValue = value;
127                    observedCount++;
128                } else if (Math.abs(value - observedValue) < COMPARISON_THRESHOLD) {
129                    observedCount++;
130                } else {
131                    // Observed two values for this CategoricalInfo, now it needs a HashMap.
132                    valueCounts = new HashMap<>(4);
133                    valueCounts.put(observedValue, new MutableLong(observedCount));
134                    valueCounts.put(value, new MutableLong(1));
135                    observedValue = Double.NaN;
136                    observedCount = 0;
137                }
138            }
139            values = null;
140        }
141    }
142
143    /**
144     * Gets the number of times a specific value was observed, and zero if this value is unknown.
145     * @param value The value to check.
146     * @return The count of times this value was observed, zero otherwise.
147     */
148    public long getObservationCount(double value) {
149        if (valueCounts != null) {
150            return valueCounts.getOrDefault(value, ZERO).longValue();
151        } else {
152            if (Math.abs(value - observedValue) < COMPARISON_THRESHOLD) {
153                return observedCount;
154            } else {
155                return 0;
156            }
157        }
158    }
159
160    /**
161     * Gets the number of unique values this CategoricalInfo has observed.
162     * @return An int representing the number of unique values.
163     */
164    public int getUniqueObservations() {
165        if (valueCounts != null) {
166            return valueCounts.size();
167        } else {
168            if (Double.isNaN(observedValue)) {
169                return 0;
170            } else {
171                return 1;
172            }
173        }
174    }
175
176    /**
177     * Generates a {@link RealInfo} using the currently observed counts to calculate
178     * the min, max, mean and variance.
179     * @return A RealInfo representing the data in this CategoricalInfo.
180     */
181    public RealInfo generateRealInfo() {
182        double min = Double.POSITIVE_INFINITY;
183        double max = Double.NEGATIVE_INFINITY;
184        double sum = 0.0;
185        double sumSquares = 0.0;
186        double mean;
187
188        if (valueCounts != null) {
189            for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) {
190                double value = e.getKey();
191                double valCount = e.getValue().longValue();
192                if (value > max) {
193                    max = value;
194                }
195                if (value < min) {
196                    min = value;
197                }
198                sum += value * valCount;
199            }
200            mean = sum / count;
201
202            for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) {
203                double value = e.getKey();
204                double valCount = e.getValue().longValue();
205                sumSquares += (value - mean) * (value - mean) * valCount;
206            }
207        } else {
208            min = observedValue;
209            max = observedValue;
210            mean = observedValue;
211            sumSquares = 0.0;
212        }
213
214        return new RealInfo(name,count,max,min,mean,sumSquares);
215    }
216
217    @Override
218    public CategoricalInfo copy() {
219        return new CategoricalInfo(this);
220    }
221
222    @Override
223    public CategoricalIDInfo makeIDInfo(int id) {
224        return new CategoricalIDInfo(this,id);
225    }
226
227    @Override
228    public CategoricalInfo rename(String newName) {
229        return new CategoricalInfo(this,newName);
230    }
231
232    @Override
233    public synchronized double uniformSample(SplittableRandom rng) {
234        if (values == null) {
235            regenerateValues();
236        }
237        int sampleIdx = rng.nextInt(values.length);
238        return values[sampleIdx];
239    }
240
241    /**
242     * Samples a value from this feature according to the frequency of observation.
243     * @param rng The RNG to use.
244     * @param totalObservations The observations including the implicit zeros.
245     * @return The sampled value.
246     */
247    public double frequencyBasedSample(SplittableRandom rng, long totalObservations) {
248        if ((totalObservations != this.totalObservations) || (cdf == null)) {
249            regenerateCDF(totalObservations);
250        }
251        int lookup = Util.sampleFromCDF(cdf,rng);
252        return values[lookup];
253    }
254
255    /**
256     * Samples a value from this feature according to the frequency of observation.
257     * @param rng The RNG to use.
258     * @param totalObservations The observations including the implicit zeros.
259     * @return The sampled value.
260     */
261    public double frequencyBasedSample(Random rng, long totalObservations) {
262        if ((totalObservations != this.totalObservations) || (cdf == null)) {
263            regenerateCDF(totalObservations);
264        }
265        int lookup = Util.sampleFromCDF(cdf,rng);
266        return values[lookup];
267    }
268
269    /**
270     * Generates the CDF for sampling.
271     * @param newTotalObservations The new number of total observations including the implicit zeros.
272     */
273    private synchronized void regenerateCDF(long newTotalObservations) {
274        long[] counts;
275        if (valueCounts != null) {
276            // This is tricksy as if valueCounts contains zero that means
277            // we could have both observed zeros and unobserved zeros.
278            if (valueCounts.containsKey(0.0)) {
279                values = new double[valueCounts.size()];
280                counts = new long[valueCounts.size()];
281            } else {
282                values = new double[valueCounts.size()+1];
283                counts = new long[valueCounts.size()+1];
284            }
285            values[0] = 0;
286            counts[0] = newTotalObservations;
287            int counter = 1;
288            long total = 0;
289            for (Map.Entry<Double,MutableLong> e : valueCounts.entrySet()) {
290                if (e.getKey() != 0.0) {
291                    values[counter] = e.getKey();
292                    counts[counter] = e.getValue().longValue();
293                    total += counts[counter];
294                    counter++;
295                }
296            }
297            // Set the zero counts appropriately
298            counts[0] -= total;
299        } else {
300            if (Double.isNaN(observedValue) || observedValue == 0.0) {
301                values = new double[1];
302                counts = new long[1];
303                values[0] = 0;
304                counts[0] = newTotalObservations;
305            } else {
306                values = new double[2];
307                counts = new long[2];
308                values[0] = 0;
309                counts[0] = newTotalObservations - observedCount;
310                values[1] = observedValue;
311                counts[1] = observedCount;
312            }
313        }
314        long sum = 0;
315        for (int i = 0; i < counts.length; i++) {
316            sum += counts[i];
317        }
318        if (sum != newTotalObservations) {
319            throw new IllegalStateException("Total counts = " + sum + ", supplied value = " + newTotalObservations);
320        }
321        cdf = Util.generateCDF(counts,sum);
322        totalObservations = newTotalObservations;
323    }
324
325    /**
326     * Recomputes the values array.
327     */
328    private synchronized void regenerateValues() {
329        //
330        // Recompute values array
331        if (valueCounts != null) {
332            int counter;
333            if (valueCounts.containsKey(0.0)) {
334                values = new double[valueCounts.size()];
335                counter = 0;
336            } else {
337                values = new double[valueCounts.size() + 1];
338                values[0] = 0;
339                counter = 1;
340            }
341            for (Double key : valueCounts.keySet()) {
342                values[counter] = key;
343                counter++;
344            }
345        } else {
346            if (Double.isNaN(observedValue) || observedValue == 0.0) {
347                values = new double[1];
348                values[0] = 0;
349            } else {
350                values = new double[2];
351                values[0] = 0;
352                values[1] = observedValue;
353            }
354        }
355    }
356
357    @Override
358    public String toString() {
359        if (valueCounts != null) {
360            return "CategoricalFeature(name=" + name + ",count=" + count + ",map=" + valueCounts.toString() + ")";
361        } else {
362            return "CategoricalFeature(name=" + name + ",count=" + count + ",map={" +observedValue+","+observedCount+"})";
363        }
364    }
365
366    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
367        in.defaultReadObject();
368        totalObservations = -1;
369        values = null;
370        cdf = null;
371    }
372}