001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.ensemble; 018 019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 021import org.tribuo.Example; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.Prediction; 024import org.tribuo.classification.Label; 025import org.tribuo.ensemble.EnsembleCombiner; 026 027import java.util.LinkedHashMap; 028import java.util.List; 029import java.util.Map; 030 031/** 032 * A combiner which performs a weighted or unweighted vote across the predicted labels. 033 * <p> 034 * This uses the full distribution of predictions from each ensemble member, unlike {@link VotingCombiner} 035 * which uses the most likely prediction for each ensemble member. 036 */ 037public final class FullyWeightedVotingCombiner implements EnsembleCombiner<Label> { 038 private static final long serialVersionUID = 1L; 039 040 public FullyWeightedVotingCombiner() {} 041 042 @Override 043 public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) { 044 int numPredictions = predictions.size(); 045 int numUsed = 0; 046 double weight = 1.0 / numPredictions; 047 double sum = 0.0; 048 double[] score = new double[outputInfo.size()]; 049 for (Prediction<Label> p : predictions) { 050 if (numUsed < p.getNumActiveFeatures()) { 051 numUsed = p.getNumActiveFeatures(); 052 } 053 for (Label e : p.getOutputScores().values()) { 054 double curScore = weight * e.getScore(); 055 sum += curScore; 056 score[outputInfo.getID(e)] += curScore; 057 } 058 } 059 060 double maxScore = Double.NEGATIVE_INFINITY; 061 Label maxLabel = null; 062 Map<String,Label> predictionMap = new LinkedHashMap<>(); 063 for (int i = 0; i < score.length; i++) { 064 String name = outputInfo.getOutput(i).getLabel(); 065 Label label = new Label(name,score[i]/sum); 066 predictionMap.put(name,label); 067 if (label.getScore() > maxScore) { 068 maxScore = label.getScore(); 069 maxLabel = label; 070 } 071 } 072 073 Example<Label> example = predictions.get(0).getExample(); 074 075 return new Prediction<>(maxLabel,predictionMap,numUsed,example,true); 076 } 077 078 @Override 079 public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) { 080 if (predictions.size() != weights.length) { 081 throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length); 082 } 083 int numUsed = 0; 084 double sum = 0.0; 085 double[] score = new double[outputInfo.size()]; 086 for (int i = 0; i < weights.length; i++) { 087 Prediction<Label> p = predictions.get(i); 088 if (numUsed < p.getNumActiveFeatures()) { 089 numUsed = p.getNumActiveFeatures(); 090 } 091 for (Label e : p.getOutputScores().values()) { 092 double curScore = weights[i] * e.getScore(); 093 sum += curScore; 094 score[outputInfo.getID(e)] += curScore; 095 } 096 } 097 098 double maxScore = Double.NEGATIVE_INFINITY; 099 Label maxLabel = null; 100 Map<String,Label> predictionMap = new LinkedHashMap<>(); 101 for (int i = 0; i < score.length; i++) { 102 String name = outputInfo.getOutput(i).getLabel(); 103 Label label = new Label(name,score[i]/sum); 104 predictionMap.put(name,label); 105 if (label.getScore() > maxScore) { 106 maxScore = label.getScore(); 107 maxLabel = label; 108 } 109 } 110 111 Example<Label> example = predictions.get(0).getExample(); 112 113 return new Prediction<>(maxLabel,predictionMap,numUsed,example,true); 114 } 115 116 @Override 117 public String toString() { 118 return "FullyWeightedVotingCombiner()"; 119 } 120 121 @Override 122 public ConfiguredObjectProvenance getProvenance() { 123 return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner"); 124 } 125}