001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.ensemble;
018
019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
021import org.tribuo.Example;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Prediction;
024import org.tribuo.ensemble.EnsembleCombiner;
025import org.tribuo.regression.Regressor;
026
027import java.util.Arrays;
028import java.util.List;
029
030/**
031 * A combiner which performs a weighted or unweighted average of the predicted
032 * regressors independently across the output dimensions.
033 */
034public class AveragingCombiner implements EnsembleCombiner<Regressor> {
035    private static final long serialVersionUID = 1L;
036
037    @Override
038    public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions) {
039        int numPredictions = predictions.size();
040        int dimensions = outputInfo.size();
041        int numUsed = 0;
042        String[] names;
043        double[] mean = new double[dimensions];
044        double[] variance = new double[dimensions];
045        for (Prediction<Regressor> p : predictions) {
046            if (numUsed < p.getNumActiveFeatures()) {
047                numUsed = p.getNumActiveFeatures();
048            }
049            Regressor curValue = p.getOutput();
050            for (int i = 0; i < dimensions; i++) {
051                double value = curValue.getValues()[i];
052                double oldMean = mean[i];
053                mean[i] += (value - oldMean);
054                variance[i] += (value - oldMean) * (value - mean[i]);
055            }
056        }
057        names = predictions.get(0).getOutput().getNames();
058        if (numPredictions > 1) {
059            for (int i = 0; i < dimensions; i++) {
060                variance[i] /= (numPredictions-1);
061            }
062        } else {
063            Arrays.fill(variance,0);
064        }
065
066        Example<Regressor> example = predictions.get(0).getExample();
067        return new Prediction<>(new Regressor(names,mean,variance),numUsed,example);
068    }
069
070    @Override
071    public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions, float[] weights) {
072        if (predictions.size() != weights.length) {
073            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length);
074        }
075        int dimensions = outputInfo.size();
076        int numUsed = 0;
077        String[] names;
078        double[] mean = new double[dimensions];
079        double[] variance = new double[dimensions];
080        double weightSum = 0.0;
081        for (int i = 0; i < weights.length; i++) {
082            Prediction<Regressor> p = predictions.get(i);
083            if (numUsed < p.getNumActiveFeatures()) {
084                numUsed = p.getNumActiveFeatures();
085            }
086            Regressor curValue = p.getOutput();
087            float weight = weights[i];
088            weightSum += weight;
089            for (int j = 0; j < dimensions; j++) {
090                double value = curValue.getValues()[j];
091                double oldMean = mean[j];
092                mean[j] += (weight / weightSum) * (value - oldMean);
093                variance[j] += weight * (value - oldMean) * (value - mean[j]);
094            }
095        }
096        names = predictions.get(0).getOutput().getNames();
097        if (weights.length > 1) {
098            for (int i = 0; i < dimensions; i++) {
099                variance[i] /= (weightSum-1);
100            }
101        } else {
102            Arrays.fill(variance,0);
103        }
104
105        Example<Regressor> example = predictions.get(0).getExample();
106        return new Prediction<>(new Regressor(names,mean,variance),numUsed,example);
107    }
108
109    @Override
110    public String toString() {
111        return "MultipleOutputAveragingCombiner()";
112    }
113
114    @Override
115    public ConfiguredObjectProvenance getProvenance() {
116        return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner");
117    }
118}