001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression;
018
019import com.oracle.labs.mlrg.olcut.util.MutableDouble;
020import com.oracle.labs.mlrg.olcut.util.MutableLong;
021import org.tribuo.MutableOutputInfo;
022
023import java.util.Map;
024
025/**
026 * A {@link MutableOutputInfo} for {@link Regressor}s. All observed Regressors must
027 * contain the same named dimensions.
028 */
029public class MutableRegressionInfo extends RegressionInfo implements MutableOutputInfo<Regressor> {
030    private static final long serialVersionUID = 2L;
031
032    MutableRegressionInfo() {
033        super();
034    }
035
036    public MutableRegressionInfo(RegressionInfo info) {
037        super(info);
038    }
039
040    @Override
041    public void observe(Regressor output) {
042        if (output == RegressionFactory.UNKNOWN_MULTIPLE_REGRESSOR) {
043            unknownCount++;
044        } else {
045            if (overallCount != 0) {
046                // Validate that the dimensions in this regressor are the same as the ones already observed.
047                String[] names = output.getNames();
048                if (names.length != countMap.size()) {
049                    throw new IllegalArgumentException("Expected this Regressor to contain " + countMap.size() + " dimensions, found " + names.length);
050                }
051                for (String name : names) {
052                    if (!countMap.containsKey(name)) {
053                        throw new IllegalArgumentException("Regressor contains unexpected dimension named '" +name + "'");
054                    }
055                }
056            }
057            for (Regressor.DimensionTuple r : output) {
058                String name = r.getName();
059                double value = r.getValue();
060
061                // Update max and min
062                minMap.merge(name, new MutableDouble(value), (a, b) -> a.doubleValue() < b.doubleValue() ? a : b);
063                maxMap.merge(name, new MutableDouble(value), (a, b) -> a.doubleValue() > b.doubleValue() ? a : b);
064
065                // Update count
066                MutableLong countValue = countMap.computeIfAbsent(name, k -> new MutableLong());
067                countValue.increment();
068
069                // Update mean
070                MutableDouble meanValue = meanMap.computeIfAbsent(name, k -> new MutableDouble());
071                double delta = value - meanValue.doubleValue();
072                meanValue.increment(delta / countValue.longValue());
073
074                // Update running sum of squares
075                double delta2 = value - meanValue.doubleValue();
076                MutableDouble sumSquaresValue = sumSquaresMap.computeIfAbsent(name, k -> new MutableDouble());
077                sumSquaresValue.increment(delta * delta2);
078            }
079            overallCount++;
080        }
081    }
082
083    @Override
084    public void clear() {
085        maxMap.clear();
086        minMap.clear();
087        meanMap.clear();
088        sumSquaresMap.clear();
089        countMap.clear();
090    }
091
092    @Override
093    public MutableRegressionInfo copy() {
094        return new MutableRegressionInfo(this);
095    }
096
097    @Override
098    public String toString() {
099        StringBuilder builder = new StringBuilder();
100        builder.append("MultipleRegressionOutput(");
101        for (Map.Entry<String,MutableLong> e : countMap.entrySet()) {
102            String name = e.getKey();
103            long count = e.getValue().longValue();
104            builder.append(String.format("{name=%s,count=%d,maxMap=%f,min=%f,mean=%f,variance=%f},",
105                    name,
106                    count,
107                    maxMap.get(name).doubleValue(),
108                    minMap.get(name).doubleValue(),
109                    meanMap.get(name).doubleValue(),
110                    (sumSquaresMap.get(name).doubleValue() / (count - 1))
111            ));
112        }
113        builder.deleteCharAt(builder.length()-1);
114        builder.append(")");
115        return builder.toString();
116    }
117
118    @Override
119    public String toReadableString() {
120        return toString();
121    }
122}