001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.ensemble; 018 019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 021import org.tribuo.Example; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.Prediction; 024import org.tribuo.classification.Label; 025import org.tribuo.ensemble.EnsembleCombiner; 026 027import java.util.LinkedHashMap; 028import java.util.List; 029import java.util.Map; 030 031/** 032 * A combiner which performs a weighted or unweighted vote across the predicted labels. 033 * <p> 034 * This uses the most likely prediction from each ensemble member, unlike {@link FullyWeightedVotingCombiner} 035 * which uses the full distribution of predictions for each ensemble member. 036 */ 037public final class VotingCombiner implements EnsembleCombiner<Label> { 038 private static final long serialVersionUID = 1L; 039 040 public VotingCombiner() {} 041 042 @Override 043 public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) { 044 int numPredictions = predictions.size(); 045 int numUsed = 0; 046 double weight = 1.0 / numPredictions; 047 double[] score = new double[outputInfo.size()]; 048 for (Prediction<Label> p : predictions) { 049 if (numUsed < p.getNumActiveFeatures()) { 050 numUsed = p.getNumActiveFeatures(); 051 } 052 score[outputInfo.getID(p.getOutput())] += weight; 053 } 054 055 double maxScore = Double.NEGATIVE_INFINITY; 056 Label maxLabel = null; 057 Map<String,Label> predictionMap = new LinkedHashMap<>(); 058 for (int i = 0; i < score.length; i++) { 059 String name = outputInfo.getOutput(i).getLabel(); 060 Label label = new Label(name,score[i]); 061 predictionMap.put(name,label); 062 if (label.getScore() > maxScore) { 063 maxScore = label.getScore(); 064 maxLabel = label; 065 } 066 } 067 068 Example<Label> example = predictions.get(0).getExample(); 069 070 return new Prediction<>(maxLabel,predictionMap,numUsed,example,true); 071 } 072 073 @Override 074 public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) { 075 if (predictions.size() != weights.length) { 076 throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length); 077 } 078 int numUsed = 0; 079 double sum = 0.0; 080 double[] score = new double[outputInfo.size()]; 081 for (int i = 0; i < weights.length; i++) { 082 Prediction<Label> p = predictions.get(i); 083 if (numUsed < p.getNumActiveFeatures()) { 084 numUsed = p.getNumActiveFeatures(); 085 } 086 score[outputInfo.getID(p.getOutput())] += weights[i]; 087 sum += weights[i]; 088 } 089 090 double maxScore = Double.NEGATIVE_INFINITY; 091 Label maxLabel = null; 092 Map<String,Label> predictionMap = new LinkedHashMap<>(); 093 for (int i = 0; i < score.length; i++) { 094 String name = outputInfo.getOutput(i).getLabel(); 095 Label label = new Label(name,score[i]/sum); 096 predictionMap.put(name,label); 097 if (label.getScore() > maxScore) { 098 maxScore = label.getScore(); 099 maxLabel = label; 100 } 101 } 102 103 Example<Label> example = predictions.get(0).getExample(); 104 105 return new Prediction<>(maxLabel,predictionMap,numUsed,example,true); 106 } 107 108 @Override 109 public String toString() { 110 return "VotingCombiner()"; 111 } 112 113 @Override 114 public ConfiguredObjectProvenance getProvenance() { 115 return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner"); 116 } 117}