001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.ensemble;
018
019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
021import org.tribuo.Example;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Prediction;
024import org.tribuo.classification.Label;
025import org.tribuo.ensemble.EnsembleCombiner;
026
027import java.util.LinkedHashMap;
028import java.util.List;
029import java.util.Map;
030
031/**
032 * A combiner which performs a weighted or unweighted vote across the predicted labels.
033 * <p>
034 * This uses the most likely prediction from each ensemble member, unlike {@link FullyWeightedVotingCombiner}
035 * which uses the full distribution of predictions for each ensemble member.
036 */
037public final class VotingCombiner implements EnsembleCombiner<Label> {
038    private static final long serialVersionUID = 1L;
039
040    public VotingCombiner() {}
041
042    @Override
043    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) {
044        int numPredictions = predictions.size();
045        int numUsed = 0;
046        double weight = 1.0 / numPredictions;
047        double[] score = new double[outputInfo.size()];
048        for (Prediction<Label> p : predictions) {
049            if (numUsed < p.getNumActiveFeatures()) {
050                numUsed = p.getNumActiveFeatures();
051            }
052            score[outputInfo.getID(p.getOutput())] += weight;
053        }
054
055        double maxScore = Double.NEGATIVE_INFINITY;
056        Label maxLabel = null;
057        Map<String,Label> predictionMap = new LinkedHashMap<>();
058        for (int i = 0; i < score.length; i++) {
059            String name = outputInfo.getOutput(i).getLabel();
060            Label label = new Label(name,score[i]);
061            predictionMap.put(name,label);
062            if (label.getScore() > maxScore) {
063                maxScore = label.getScore();
064                maxLabel = label;
065            }
066        }
067
068        Example<Label> example = predictions.get(0).getExample();
069
070        return new Prediction<>(maxLabel,predictionMap,numUsed,example,true);
071    }
072
073    @Override
074    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) {
075        if (predictions.size() != weights.length) {
076            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length);
077        }
078        int numUsed = 0;
079        double sum = 0.0;
080        double[] score = new double[outputInfo.size()];
081        for (int i = 0; i < weights.length; i++) {
082            Prediction<Label> p = predictions.get(i);
083            if (numUsed < p.getNumActiveFeatures()) {
084                numUsed = p.getNumActiveFeatures();
085            }
086            score[outputInfo.getID(p.getOutput())] += weights[i];
087            sum += weights[i];
088        }
089
090        double maxScore = Double.NEGATIVE_INFINITY;
091        Label maxLabel = null;
092        Map<String,Label> predictionMap = new LinkedHashMap<>();
093        for (int i = 0; i < score.length; i++) {
094            String name = outputInfo.getOutput(i).getLabel();
095            Label label = new Label(name,score[i]/sum);
096            predictionMap.put(name,label);
097            if (label.getScore() > maxScore) {
098                maxScore = label.getScore();
099                maxLabel = label;
100            }
101        }
102
103        Example<Label> example = predictions.get(0).getExample();
104
105        return new Prediction<>(maxLabel,predictionMap,numUsed,example,true);
106    }
107
108    @Override
109    public String toString() {
110        return "VotingCombiner()";
111    }
112
113    @Override
114    public ConfiguredObjectProvenance getProvenance() {
115        return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner");
116    }
117}