001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.ensemble; 018 019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 021import org.tribuo.Example; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.Prediction; 024import org.tribuo.ensemble.EnsembleCombiner; 025import org.tribuo.regression.Regressor; 026 027import java.util.Arrays; 028import java.util.List; 029 030/** 031 * A combiner which performs a weighted or unweighted average of the predicted 032 * regressors independently across the output dimensions. 033 */ 034public class AveragingCombiner implements EnsembleCombiner<Regressor> { 035 private static final long serialVersionUID = 1L; 036 037 @Override 038 public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions) { 039 int numPredictions = predictions.size(); 040 int dimensions = outputInfo.size(); 041 int numUsed = 0; 042 String[] names; 043 double[] mean = new double[dimensions]; 044 double[] variance = new double[dimensions]; 045 for (Prediction<Regressor> p : predictions) { 046 if (numUsed < p.getNumActiveFeatures()) { 047 numUsed = p.getNumActiveFeatures(); 048 } 049 Regressor curValue = p.getOutput(); 050 for (int i = 0; i < dimensions; i++) { 051 double value = curValue.getValues()[i]; 052 double oldMean = mean[i]; 053 mean[i] += (value - oldMean); 054 variance[i] += (value - oldMean) * (value - mean[i]); 055 } 056 } 057 names = predictions.get(0).getOutput().getNames(); 058 if (numPredictions > 1) { 059 for (int i = 0; i < dimensions; i++) { 060 variance[i] /= (numPredictions-1); 061 } 062 } else { 063 Arrays.fill(variance,0); 064 } 065 066 Example<Regressor> example = predictions.get(0).getExample(); 067 return new Prediction<>(new Regressor(names,mean,variance),numUsed,example); 068 } 069 070 @Override 071 public Prediction<Regressor> combine(ImmutableOutputInfo<Regressor> outputInfo, List<Prediction<Regressor>> predictions, float[] weights) { 072 if (predictions.size() != weights.length) { 073 throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length); 074 } 075 int dimensions = outputInfo.size(); 076 int numUsed = 0; 077 String[] names; 078 double[] mean = new double[dimensions]; 079 double[] variance = new double[dimensions]; 080 double weightSum = 0.0; 081 for (int i = 0; i < weights.length; i++) { 082 Prediction<Regressor> p = predictions.get(i); 083 if (numUsed < p.getNumActiveFeatures()) { 084 numUsed = p.getNumActiveFeatures(); 085 } 086 Regressor curValue = p.getOutput(); 087 float weight = weights[i]; 088 weightSum += weight; 089 for (int j = 0; j < dimensions; j++) { 090 double value = curValue.getValues()[j]; 091 double oldMean = mean[j]; 092 mean[j] += (weight / weightSum) * (value - oldMean); 093 variance[j] += weight * (value - oldMean) * (value - mean[j]); 094 } 095 } 096 names = predictions.get(0).getOutput().getNames(); 097 if (weights.length > 1) { 098 for (int i = 0; i < dimensions; i++) { 099 variance[i] /= (weightSum-1); 100 } 101 } else { 102 Arrays.fill(variance,0); 103 } 104 105 Example<Regressor> example = predictions.get(0).getExample(); 106 return new Prediction<>(new Regressor(names,mean,variance),numUsed,example); 107 } 108 109 @Override 110 public String toString() { 111 return "MultipleOutputAveragingCombiner()"; 112 } 113 114 @Override 115 public ConfiguredObjectProvenance getProvenance() { 116 return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner"); 117 } 118}