001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.ensemble; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.Excuse; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Model; 025import org.tribuo.Output; 026import org.tribuo.Prediction; 027import org.tribuo.provenance.EnsembleModelProvenance; 028import org.tribuo.util.Util; 029 030import java.util.ArrayList; 031import java.util.Arrays; 032import java.util.HashMap; 033import java.util.List; 034import java.util.Map; 035import java.util.Map.Entry; 036import java.util.Optional; 037 038/** 039 * An ensemble model that uses weights to combine the ensemble member predictions. 040 */ 041public final class WeightedEnsembleModel<T extends Output<T>> extends EnsembleModel<T> { 042 private static final long serialVersionUID = 1L; 043 044 protected final float[] weights; 045 046 protected final EnsembleCombiner<T> combiner; 047 048 public WeightedEnsembleModel(String name, EnsembleModelProvenance description, ImmutableFeatureMap featureIDMap, 049 ImmutableOutputInfo<T> outputIDInfo, 050 List<Model<T>> newModels, EnsembleCombiner<T> combiner) { 051 this(name,description,featureIDMap,outputIDInfo,newModels, combiner, Util.generateUniformVector(newModels.size(), 1.0f/newModels.size())); 052 } 053 054 public WeightedEnsembleModel(String name, EnsembleModelProvenance description, ImmutableFeatureMap featureIDMap, 055 ImmutableOutputInfo<T> outputIDInfo, 056 List<Model<T>> newModels, EnsembleCombiner<T> combiner, float[] weights) { 057 super(name,description,featureIDMap,outputIDInfo,newModels); 058 this.weights = Arrays.copyOf(weights,weights.length); 059 this.combiner = combiner; 060 } 061 062 @Override 063 public Prediction<T> predict(Example<T> example) { 064 List<Prediction<T>> predictions = new ArrayList<>(); 065 for (Model<T> model : models) { 066 predictions.add(model.predict(example)); 067 } 068 069 return combiner.combine(outputIDInfo,predictions,weights); 070 } 071 072 @Override 073 public Optional<Excuse<T>> getExcuse(Example<T> example) { 074 Map<String, Map<String,Double>> map = new HashMap<>(); 075 Prediction<T> prediction = predict(example); 076 List<Excuse<T>> excuses = new ArrayList<>(); 077 078 for (int i = 0; i < models.size(); i++) { 079 Optional<Excuse<T>> excuse = models.get(i).getExcuse(example); 080 if (excuse.isPresent()) { 081 excuses.add(excuse.get()); 082 Map<String, List<Pair<String,Double>>> m = excuse.get().getScores(); 083 for (Entry<String, List<Pair<String,Double>>> e : m.entrySet()) { 084 Map<String, Double> innerMap = map.computeIfAbsent(e.getKey(), k -> new HashMap<>()); 085 for (Pair<String,Double> p : e.getValue()) { 086 innerMap.merge(p.getA(), p.getB() * weights[i], Double::sum); 087 } 088 } 089 } 090 } 091 092 if (map.isEmpty()) { 093 return Optional.empty(); 094 } else { 095 Map<String, List<Pair<String, Double>>> outputMap = new HashMap<>(); 096 for (Entry<String, Map<String, Double>> label : map.entrySet()) { 097 List<Pair<String, Double>> list = new ArrayList<>(); 098 099 for (Entry<String, Double> entry : label.getValue().entrySet()) { 100 list.add(new Pair<>(entry.getKey(), entry.getValue())); 101 } 102 103 list.sort((Pair<String, Double> o1, Pair<String, Double> o2) -> o2.getB().compareTo(o1.getB())); 104 outputMap.put(label.getKey(), list); 105 } 106 107 return Optional.of(new EnsembleExcuse<>(example, prediction, outputMap, excuses)); 108 } 109 } 110 111 @Override 112 protected EnsembleModel<T> copy(String name, EnsembleModelProvenance newProvenance, List<Model<T>> newModels) { 113 return new WeightedEnsembleModel<>(name,newProvenance,featureIDMap,outputIDInfo,newModels,combiner); 114 } 115}