001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression; 018 019import com.oracle.labs.mlrg.olcut.util.MutableDouble; 020import com.oracle.labs.mlrg.olcut.util.MutableLong; 021import com.oracle.labs.mlrg.olcut.util.MutableNumber; 022import com.oracle.labs.mlrg.olcut.util.Pair; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.MutableOutputInfo; 025import org.tribuo.OutputInfo; 026import org.tribuo.regression.Regressor.DimensionTuple; 027 028import java.util.ArrayList; 029import java.util.Collections; 030import java.util.Comparator; 031import java.util.LinkedHashMap; 032import java.util.Map; 033import java.util.Set; 034import java.util.SortedSet; 035import java.util.TreeMap; 036import java.util.TreeSet; 037 038/** 039 * The base class for regression information using {@link Regressor}s. 040 * <p> 041 * Stores the observed min, max, mean and variance for each dimension. 042 */ 043public abstract class RegressionInfo implements OutputInfo<Regressor> { 044 private static final long serialVersionUID = 2L; 045 046 private static final MutableDouble NAN = new MutableDouble(Double.NaN); 047 048 protected Map<String,MutableDouble> maxMap = new LinkedHashMap<>(); 049 protected Map<String,MutableDouble> minMap = new LinkedHashMap<>(); 050 051 protected Map<String,MutableDouble> meanMap = new LinkedHashMap<>(); 052 protected Map<String,MutableDouble> sumSquaresMap = new LinkedHashMap<>(); 053 054 protected Map<String,MutableLong> countMap = new TreeMap<>(); 055 056 protected long overallCount = 0; 057 058 protected int unknownCount = 0; 059 060 RegressionInfo() { } 061 062 RegressionInfo(RegressionInfo other) { 063 this.maxMap = MutableNumber.copyMap(other.maxMap); 064 this.minMap = MutableNumber.copyMap(other.minMap); 065 this.meanMap = MutableNumber.copyMap(other.meanMap); 066 this.sumSquaresMap = MutableNumber.copyMap(other.sumSquaresMap); 067 this.countMap = MutableNumber.copyMap(other.countMap); 068 this.overallCount = other.overallCount; 069 } 070 071 @Override 072 public int getUnknownCount() { 073 return unknownCount; 074 } 075 076 /** 077 * Returns a set containing a Regressor for each dimension with the minimum value observed. 078 * @return A set of Regressors, each with one active dimension. 079 */ 080 @Override 081 public Set<Regressor> getDomain() { 082 TreeSet<DimensionTuple> outputs = new TreeSet<>(Comparator.comparing(DimensionTuple::getName)); 083 for (Map.Entry<String,MutableDouble> e : minMap.entrySet()) { 084 outputs.add(new DimensionTuple(e.getKey(),e.getValue().doubleValue())); 085 } 086 @SuppressWarnings("unchecked") // DimensionTuple is a subtype of Regressor, and this set is immutable. 087 SortedSet<Regressor> setOutputs = (SortedSet<Regressor>) (SortedSet) Collections.unmodifiableSortedSet(outputs); 088 return setOutputs; 089 } 090 091 /** 092 * Gets the minimum value this RegressionInfo has seen, or NaN if it's not seen anything. 093 * @param name The dimension to check. 094 * @return The minimum value for that dimension. 095 */ 096 public double getMin(String name) { 097 return minMap.getOrDefault(name,NAN).doubleValue(); 098 } 099 100 /** 101 * Gets the maximum value this RegressionInfo has seen, or NaN if it's not seen that dimension. 102 * @param name The dimension to check. 103 * @return The maximum value for that dimension. 104 */ 105 public double getMax(String name) { 106 return maxMap.getOrDefault(name,NAN).doubleValue(); 107 } 108 109 /** 110 * Gets the mean value this RegressionInfo has seen, or NaN if it's not seen that dimension. 111 * @param name The dimension to check. 112 * @return The mean value for that dimension. 113 */ 114 public double getMean(String name) { 115 return meanMap.getOrDefault(name,NAN).doubleValue(); 116 } 117 118 /** 119 * Gets the variance this RegressionInfo has seen, or NaN if it's not seen that dimension. 120 * @param name The dimension to check. 121 * @return The variance for that dimension. 122 */ 123 public double getVariance(String name) { 124 MutableDouble sumSquaresDbl = sumSquaresMap.get(name); 125 if (sumSquaresDbl != null) { 126 return sumSquaresDbl.doubleValue() / (countMap.get(name).longValue()-1); 127 } else { 128 return Double.NaN; 129 } 130 } 131 132 /** 133 * The number of dimensions this OutputInfo has seen. 134 * @return The number of dimensions this OutputInfo has seen. 135 */ 136 @Override 137 public int size() { 138 return countMap.size(); 139 } 140 141 @Override 142 public ImmutableOutputInfo<Regressor> generateImmutableOutputInfo() { 143 return new ImmutableRegressionInfo(this); 144 } 145 146 @Override 147 public MutableOutputInfo<Regressor> generateMutableOutputInfo() { 148 return new MutableRegressionInfo(this); 149 } 150 151 @Override 152 public abstract RegressionInfo copy(); 153 154 @Override 155 public Iterable<Pair<String,Long>> outputCountsIterable() { 156 ArrayList<Pair<String,Long>> list = new ArrayList<>(); 157 158 for (Map.Entry<String,MutableLong> e : countMap.entrySet()) { 159 list.add(new Pair<>(e.getKey(), e.getValue().longValue())); 160 } 161 162 return list; 163 } 164}