/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package org.tribuo;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020
021import java.util.Map;
022
023/**
024 * Same as a {@link CategoricalInfo}, but with an additional int id field.
025 */
026public class CategoricalIDInfo extends CategoricalInfo implements VariableIDInfo {
027    private static final long serialVersionUID = 2L;
028
029    private final int id;
030
031    /**
032     * Constructs a categorical id info copying the information from the supplied info, with the specified id.
033     * @param info The info to copy.
034     * @param id The id number to use.
035     */
036    public CategoricalIDInfo(CategoricalInfo info, int id) {
037        super(info);
038        this.id = id;
039    }
040
041    /**
042     * Constructs a copy of the supplied categorical id info with the new name.
043     * <p>
044     * Used in the feature hashing system.
045     * @param info The info to copy.
046     * @param newName The new feature name.
047     */
048    private CategoricalIDInfo(CategoricalIDInfo info, String newName) {
049        super(info,newName);
050        this.id = info.id;
051    }
052
053    @Override
054    public int getID() {
055        return id;
056    }
057
058    /**
059     * Generates a {@link RealIDInfo} that matches this CategoricalInfo and
060     * also contains an id number.
061     */
062    @Override
063    public RealIDInfo generateRealInfo() {
064        double min = Double.POSITIVE_INFINITY;
065        double max = Double.NEGATIVE_INFINITY;
066        double sum = 0.0;
067        double sumSquares = 0.0;
068        double mean;
069
070        if (valueCounts != null) {
071            for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) {
072                double value = e.getKey();
073                double valCount = e.getValue().longValue();
074                if (value > max) {
075                    max = value;
076                }
077                if (value < min) {
078                    min = value;
079                }
080                sum += value * valCount;
081            }
082            mean = sum / count;
083
084            for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) {
085                double value = e.getKey();
086                double valCount = e.getValue().longValue();
087                sumSquares += (value - mean) * (value - mean) * valCount;
088            }
089        } else {
090            min = observedValue;
091            max = observedValue;
092            mean = observedValue;
093            sumSquares = 0.0;
094        }
095
096        return new RealIDInfo(name,count,max,min,mean,sumSquares,id);
097    }
098
099    @Override
100    public CategoricalIDInfo copy() {
101        return new CategoricalIDInfo(this,name);
102    }
103
104    @Override
105    public CategoricalIDInfo makeIDInfo(int id) {
106        return new CategoricalIDInfo(this,id);
107    }
108
109    @Override
110    public CategoricalIDInfo rename(String newName) {
111        return new CategoricalIDInfo(this,newName);
112    }
113
114    @Override
115    public String toString() {
116        if (valueCounts != null) {
117            return "CategoricalFeature(name=" + name + ",id=" + id + ",count=" + count + ",map=" + valueCounts.toString() + ")";
118        } else {
119            return "CategoricalFeature(name=" + name + ",id=" + id + ",count=" + count + ",map={" +observedValue+","+observedCount+"})";
120        }
121    }
122}