001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.ensemble;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Output;
026import org.tribuo.Prediction;
027import org.tribuo.provenance.EnsembleModelProvenance;
028import org.tribuo.util.Util;
029
030import java.util.ArrayList;
031import java.util.Arrays;
032import java.util.HashMap;
033import java.util.List;
034import java.util.Map;
035import java.util.Map.Entry;
036import java.util.Optional;
037
038/**
039 * An ensemble model that uses weights to combine the ensemble member predictions.
040 */
041public final class WeightedEnsembleModel<T extends Output<T>> extends EnsembleModel<T> {
042    private static final long serialVersionUID = 1L;
043
044    protected final float[] weights;
045
046    protected final EnsembleCombiner<T> combiner;
047
048    public WeightedEnsembleModel(String name, EnsembleModelProvenance description, ImmutableFeatureMap featureIDMap,
049                                 ImmutableOutputInfo<T> outputIDInfo,
050                                 List<Model<T>> newModels, EnsembleCombiner<T> combiner) {
051        this(name,description,featureIDMap,outputIDInfo,newModels, combiner, Util.generateUniformVector(newModels.size(), 1.0f/newModels.size()));
052    }
053
054    public WeightedEnsembleModel(String name, EnsembleModelProvenance description, ImmutableFeatureMap featureIDMap,
055                          ImmutableOutputInfo<T> outputIDInfo,
056                          List<Model<T>> newModels, EnsembleCombiner<T> combiner, float[] weights) {
057        super(name,description,featureIDMap,outputIDInfo,newModels);
058        this.weights = Arrays.copyOf(weights,weights.length);
059        this.combiner = combiner;
060    }
061
062    @Override
063    public Prediction<T> predict(Example<T> example) {
064        List<Prediction<T>> predictions = new ArrayList<>();
065        for (Model<T> model : models) {
066            predictions.add(model.predict(example));
067        }
068
069        return combiner.combine(outputIDInfo,predictions,weights);
070    }
071
072    @Override
073    public Optional<Excuse<T>> getExcuse(Example<T> example) {
074        Map<String, Map<String,Double>> map = new HashMap<>();
075        Prediction<T> prediction = predict(example);
076        List<Excuse<T>> excuses = new ArrayList<>();
077
078        for (int i = 0; i < models.size(); i++) {
079            Optional<Excuse<T>> excuse = models.get(i).getExcuse(example);
080            if (excuse.isPresent()) {
081                excuses.add(excuse.get());
082                Map<String, List<Pair<String,Double>>> m = excuse.get().getScores();
083                for (Entry<String, List<Pair<String,Double>>> e : m.entrySet()) {
084                    Map<String, Double> innerMap = map.computeIfAbsent(e.getKey(), k -> new HashMap<>());
085                    for (Pair<String,Double> p : e.getValue()) {
086                        innerMap.merge(p.getA(), p.getB() * weights[i], Double::sum);
087                    }
088                }
089            }
090        }
091
092        if (map.isEmpty()) {
093            return Optional.empty();
094        } else {
095            Map<String, List<Pair<String, Double>>> outputMap = new HashMap<>();
096            for (Entry<String, Map<String, Double>> label : map.entrySet()) {
097                List<Pair<String, Double>> list = new ArrayList<>();
098
099                for (Entry<String, Double> entry : label.getValue().entrySet()) {
100                    list.add(new Pair<>(entry.getKey(), entry.getValue()));
101                }
102
103                list.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB()));
104                outputMap.put(label.getKey(), list);
105            }
106
107            return Optional.of(new EnsembleExcuse<>(example, prediction, outputMap, excuses));
108        }
109    }
110
111    @Override
112    protected EnsembleModel<T> copy(String name, EnsembleModelProvenance newProvenance, List<Model<T>> newModels) {
113        return new WeightedEnsembleModel<>(name,newProvenance,featureIDMap,outputIDInfo,newModels,combiner);
114    }
115}