/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression;

import com.oracle.labs.mlrg.olcut.util.MutableDouble;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.MutableOutputInfo;

import java.util.Map;

/**
 * A {@link MutableOutputInfo} for {@link Regressor}s. All observed Regressors must
 * contain the same named dimensions.
 * <p>
 * For every dimension this class maintains the running minimum, maximum, count,
 * mean and sum of squared deviations, updated one observation at a time.
 */
public class MutableRegressionInfo extends RegressionInfo implements MutableOutputInfo<Regressor> {
    private static final long serialVersionUID = 2L;

    /**
     * Constructs an info with no observations.
     */
    MutableRegressionInfo() {
        super();
    }

    /**
     * Constructs a mutable info initialised from the supplied info via the
     * {@link RegressionInfo} copy constructor.
     * @param info The info to copy.
     */
    public MutableRegressionInfo(RegressionInfo info) {
        super(info);
    }

    /**
     * Records a single regressor.
     * <p>
     * The unknown sentinel ({@link RegressionFactory#UNKNOWN_MULTIPLE_REGRESSOR}, compared
     * by identity) only increments the unknown count. Any other regressor updates the
     * per-dimension statistics and the overall count.
     * @param output The regressor to observe.
     * @throws IllegalArgumentException If at least one regressor has already been observed and
     * this regressor has a different number of dimensions, or a dimension name that has not
     * been seen before.
     */
    @Override
    public void observe(Regressor output) {
        if (output == RegressionFactory.UNKNOWN_MULTIPLE_REGRESSOR) {
            unknownCount++;
            return;
        }

        if (overallCount != 0) {
            // Validate that the dimensions in this regressor are the same as the ones already observed.
            // This runs before any statistic is touched, so a rejected regressor leaves the state unchanged.
            String[] names = output.getNames();
            if (names.length != countMap.size()) {
                throw new IllegalArgumentException("Expected this Regressor to contain " + countMap.size() + " dimensions, found " + names.length);
            }
            for (String name : names) {
                if (!countMap.containsKey(name)) {
                    throw new IllegalArgumentException("Regressor contains unexpected dimension named '" + name + "'");
                }
            }
        }

        for (Regressor.DimensionTuple r : output) {
            String name = r.getName();
            double value = r.getValue();

            // Update max and min, keeping whichever MutableDouble holds the more extreme value.
            minMap.merge(name, new MutableDouble(value), (a, b) -> a.doubleValue() < b.doubleValue() ? a : b);
            maxMap.merge(name, new MutableDouble(value), (a, b) -> a.doubleValue() > b.doubleValue() ? a : b);

            // Update count
            MutableLong countValue = countMap.computeIfAbsent(name, k -> new MutableLong());
            countValue.increment();

            // Update mean. The order matters: delta uses the mean *before* this update,
            // and the count has already been incremented to include this observation.
            MutableDouble meanValue = meanMap.computeIfAbsent(name, k -> new MutableDouble());
            double delta = value - meanValue.doubleValue();
            meanValue.increment(delta / countValue.longValue());

            // Update running sum of squares; delta2 uses the mean *after* the update above.
            double delta2 = value - meanValue.doubleValue();
            MutableDouble sumSquaresValue = sumSquaresMap.computeIfAbsent(name, k -> new MutableDouble());
            sumSquaresValue.increment(delta * delta2);
        }
        overallCount++;
    }

    /**
     * Removes everything this info has observed: the per-dimension statistics
     * and both observation counters.
     */
    @Override
    public void clear() {
        maxMap.clear();
        minMap.clear();
        meanMap.clear();
        sumSquaresMap.clear();
        countMap.clear();
        // The counters must be reset along with the maps. Otherwise observe() would see a
        // non-zero overallCount next to an empty countMap and reject every regressor.
        overallCount = 0;
        unknownCount = 0;
    }

    /**
     * Creates a new {@link MutableRegressionInfo} from this one using the copy constructor.
     * @return A copy of this info.
     */
    @Override
    public MutableRegressionInfo copy() {
        return new MutableRegressionInfo(this);
    }

    /**
     * Formats the per-dimension statistics as
     * {@code MultipleRegressionOutput({name=..,count=..,max=..,min=..,mean=..,variance=..},...)}.
     * <p>
     * The variance is the sum of squares divided by {@code count - 1}, so a dimension with a
     * single observation has a zero divisor.
     * @return A String describing every observed dimension.
     */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("MultipleRegressionOutput(");
        for (Map.Entry<String,MutableLong> e : countMap.entrySet()) {
            String name = e.getKey();
            long count = e.getValue().longValue();
            builder.append(String.format("{name=%s,count=%d,max=%f,min=%f,mean=%f,variance=%f},",
                    name,
                    count,
                    maxMap.get(name).doubleValue(),
                    minMap.get(name).doubleValue(),
                    meanMap.get(name).doubleValue(),
                    (sumSquaresMap.get(name).doubleValue() / (count - 1))
            ));
        }
        // Drop the trailing comma, but only if an entry was written; with no
        // dimensions the last character is the opening parenthesis.
        if (!countMap.isEmpty()) {
            builder.setLength(builder.length() - 1);
        }
        builder.append(")");
        return builder.toString();
    }

    /**
     * Returns the same text as {@link #toString()}.
     * @return A String describing every observed dimension.
     */
    @Override
    public String toReadableString() {
        return toString();
    }
}