001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.ensemble;
018
019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
021import org.tribuo.Example;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Prediction;
024import org.tribuo.classification.Label;
025import org.tribuo.ensemble.EnsembleCombiner;
026
027import java.util.LinkedHashMap;
028import java.util.List;
029import java.util.Map;
030
031/**
032 * A combiner which performs a weighted or unweighted vote across the predicted labels.
033 * <p>
034 * This uses the full distribution of predictions from each ensemble member, unlike {@link VotingCombiner}
035 * which uses the most likely prediction for each ensemble member.
036 */
037public final class FullyWeightedVotingCombiner implements EnsembleCombiner<Label> {
038    private static final long serialVersionUID = 1L;
039
040    public FullyWeightedVotingCombiner() {}
041
042    @Override
043    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) {
044        int numPredictions = predictions.size();
045        int numUsed = 0;
046        double weight = 1.0 / numPredictions;
047        double sum = 0.0;
048        double[] score = new double[outputInfo.size()];
049        for (Prediction<Label> p : predictions) {
050            if (numUsed < p.getNumActiveFeatures()) {
051                numUsed = p.getNumActiveFeatures();
052            }
053            for (Label e : p.getOutputScores().values()) {
054                double curScore = weight * e.getScore();
055                sum += curScore;
056                score[outputInfo.getID(e)] += curScore;
057            }
058        }
059
060        double maxScore = Double.NEGATIVE_INFINITY;
061        Label maxLabel = null;
062        Map<String,Label> predictionMap = new LinkedHashMap<>();
063        for (int i = 0; i < score.length; i++) {
064            String name = outputInfo.getOutput(i).getLabel();
065            Label label = new Label(name,score[i]/sum);
066            predictionMap.put(name,label);
067            if (label.getScore() > maxScore) {
068                maxScore = label.getScore();
069                maxLabel = label;
070            }
071        }
072
073        Example<Label> example = predictions.get(0).getExample();
074
075        return new Prediction<>(maxLabel,predictionMap,numUsed,example,true);
076    }
077
078    @Override
079    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) {
080        if (predictions.size() != weights.length) {
081            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length);
082        }
083        int numUsed = 0;
084        double sum = 0.0;
085        double[] score = new double[outputInfo.size()];
086        for (int i = 0; i < weights.length; i++) {
087            Prediction<Label> p = predictions.get(i);
088            if (numUsed < p.getNumActiveFeatures()) {
089                numUsed = p.getNumActiveFeatures();
090            }
091            for (Label e : p.getOutputScores().values()) {
092                double curScore = weights[i] * e.getScore();
093                sum += curScore;
094                score[outputInfo.getID(e)] += curScore;
095            }
096        }
097
098        double maxScore = Double.NEGATIVE_INFINITY;
099        Label maxLabel = null;
100        Map<String,Label> predictionMap = new LinkedHashMap<>();
101        for (int i = 0; i < score.length; i++) {
102            String name = outputInfo.getOutput(i).getLabel();
103            Label label = new Label(name,score[i]/sum);
104            predictionMap.put(name,label);
105            if (label.getScore() > maxScore) {
106                maxScore = label.getScore();
107                maxLabel = label;
108            }
109        }
110
111        Example<Label> example = predictions.get(0).getExample();
112
113        return new Prediction<>(maxLabel,predictionMap,numUsed,example,true);
114    }
115
116    @Override
117    public String toString() {
118        return "FullyWeightedVotingCombiner()";
119    }
120
121    @Override
122    public ConfiguredObjectProvenance getProvenance() {
123        return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner");
124    }
125}