001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression;
018
019import com.oracle.labs.mlrg.olcut.util.MutableDouble;
020import com.oracle.labs.mlrg.olcut.util.MutableLong;
021import com.oracle.labs.mlrg.olcut.util.MutableNumber;
022import com.oracle.labs.mlrg.olcut.util.Pair;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.MutableOutputInfo;
025import org.tribuo.OutputInfo;
026import org.tribuo.regression.Regressor.DimensionTuple;
027
028import java.util.ArrayList;
029import java.util.Collections;
030import java.util.Comparator;
031import java.util.LinkedHashMap;
032import java.util.Map;
033import java.util.Set;
034import java.util.SortedSet;
035import java.util.TreeMap;
036import java.util.TreeSet;
037
038/**
039 * The base class for regression information using {@link Regressor}s.
040 * <p>
041 * Stores the observed min, max, mean and variance for each dimension.
042 */
043public abstract class RegressionInfo implements OutputInfo<Regressor> {
044    private static final long serialVersionUID = 2L;
045
046    private static final MutableDouble NAN = new MutableDouble(Double.NaN);
047
048    protected Map<String,MutableDouble> maxMap = new LinkedHashMap<>();
049    protected Map<String,MutableDouble> minMap = new LinkedHashMap<>();
050
051    protected Map<String,MutableDouble> meanMap = new LinkedHashMap<>();
052    protected Map<String,MutableDouble> sumSquaresMap = new LinkedHashMap<>();
053
054    protected Map<String,MutableLong> countMap = new TreeMap<>();
055
056    protected long overallCount = 0;
057
058    protected int unknownCount = 0;
059
060    RegressionInfo() { }
061
062    RegressionInfo(RegressionInfo other) {
063        this.maxMap = MutableNumber.copyMap(other.maxMap);
064        this.minMap = MutableNumber.copyMap(other.minMap);
065        this.meanMap = MutableNumber.copyMap(other.meanMap);
066        this.sumSquaresMap = MutableNumber.copyMap(other.sumSquaresMap);
067        this.countMap = MutableNumber.copyMap(other.countMap);
068        this.overallCount = other.overallCount;
069    }
070
071    @Override
072    public int getUnknownCount() {
073        return unknownCount;
074    }
075
076    /**
077     * Returns a set containing a Regressor for each dimension with the minimum value observed.
078     * @return A set of Regressors, each with one active dimension.
079     */
080    @Override
081    public Set<Regressor> getDomain() {
082        TreeSet<DimensionTuple> outputs = new TreeSet<>(Comparator.comparing(DimensionTuple::getName));
083        for (Map.Entry<String,MutableDouble> e : minMap.entrySet()) {
084            outputs.add(new DimensionTuple(e.getKey(),e.getValue().doubleValue()));
085        }
086        @SuppressWarnings("unchecked") // DimensionTuple is a subtype of Regressor, and this set is immutable.
087        SortedSet<Regressor> setOutputs = (SortedSet<Regressor>) (SortedSet) Collections.unmodifiableSortedSet(outputs);
088        return setOutputs;
089    }
090
091    /**
092     * Gets the minimum value this RegressionInfo has seen, or NaN if it's not seen anything.
093     * @param name The dimension to check.
094     * @return The minimum value for that dimension.
095     */
096    public double getMin(String name) {
097        return minMap.getOrDefault(name,NAN).doubleValue();
098    }
099
100    /**
101     * Gets the maximum value this RegressionInfo has seen, or NaN if it's not seen that dimension.
102     * @param name The dimension to check.
103     * @return The maximum value for that dimension.
104     */
105    public double getMax(String name) {
106        return maxMap.getOrDefault(name,NAN).doubleValue();
107    }
108
109    /**
110     * Gets the mean value this RegressionInfo has seen, or NaN if it's not seen that dimension.
111     * @param name The dimension to check.
112     * @return The mean value for that dimension.
113     */
114    public double getMean(String name) {
115        return meanMap.getOrDefault(name,NAN).doubleValue();
116    }
117
118    /**
119     * Gets the variance this RegressionInfo has seen, or NaN if it's not seen that dimension.
120     * @param name The dimension to check.
121     * @return The variance for that dimension.
122     */
123    public double getVariance(String name) {
124        MutableDouble sumSquaresDbl = sumSquaresMap.get(name);
125        if (sumSquaresDbl != null) {
126            return sumSquaresDbl.doubleValue() / (countMap.get(name).longValue()-1);
127        } else {
128            return Double.NaN;
129        }
130    }
131
132    /**
133     * The number of dimensions this OutputInfo has seen.
134     * @return The number of dimensions this OutputInfo has seen.
135     */
136    @Override
137    public int size() {
138        return countMap.size();
139    }
140
141    @Override
142    public ImmutableOutputInfo<Regressor> generateImmutableOutputInfo() {
143        return new ImmutableRegressionInfo(this);
144    }
145
146    @Override
147    public MutableOutputInfo<Regressor> generateMutableOutputInfo() {
148        return new MutableRegressionInfo(this);
149    }
150
151    @Override
152    public abstract RegressionInfo copy();
153
154    @Override
155    public Iterable<Pair<String,Long>> outputCountsIterable() {
156        ArrayList<Pair<String,Long>> list = new ArrayList<>();
157
158        for (Map.Entry<String,MutableLong> e : countMap.entrySet()) {
159            list.add(new Pair<>(e.getKey(), e.getValue().longValue()));
160        }
161
162        return list;
163    }
164}