/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.ensemble;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.PriorityQueue;

/**
 * A model which contains a list of other {@link Model}s.
 *
 * @param <T> The output type shared by the ensemble and all of its members.
 */
public abstract class EnsembleModel<T extends Output<T>> extends Model<T> {
    private static final long serialVersionUID = 1L;

    /**
     * The ensemble members, exposed through an unmodifiable wrapper.
     */
    protected final List<Model<T>> models;

    /**
     * Builds an ensemble from the supplied models.
     * <p>
     * The supplied list is wrapped in an unmodifiable view rather than copied, so
     * later changes the caller makes to {@code newModels} remain visible through
     * this ensemble.
     *
     * @param name The model name, passed to the superclass.
     * @param description The ensemble provenance, passed to the superclass.
     * @param featureIDMap The feature domain, passed to the superclass.
     * @param outputIDInfo The output domain, passed to the superclass.
     * @param newModels The ensemble members.
     */
    public EnsembleModel(String name, EnsembleModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Model<T>> newModels) {
        // NOTE(review): the trailing 'true' is interpreted by Model's constructor,
        // presumably a "generates probabilities" flag - confirm against Model.
        super(name,description,featureIDMap,outputIDInfo,true);
        models = Collections.unmodifiableList(newModels);
    }

    /**
     * Returns an unmodifiable view on the ensemble members.
     * @return The ensemble members.
     */
    public List<Model<T>> getModels() {
        return models;
    }

    /**
     * The number of ensemble members.
     * @return The ensemble size.
     */
    public int getNumModels() {
        return models.size();
    }

    /**
     * Re-declared as abstract so that every concrete ensemble must supply its own
     * excuse generation.
     *
     * @param example The example to explain.
     * @return An optional excuse for the supplied example.
     */
    @Override
    public abstract Optional<Excuse<T>> getExcuse(Example<T> example);

    /**
     * Returns this model's provenance narrowed to the ensemble provenance type.
     *
     * @return The ensemble provenance.
     */
    @Override
    public EnsembleModelProvenance getProvenance() {
        return (EnsembleModelProvenance) provenance;
    }

    /**
     * Copies this ensemble under a new name and provenance.
     * <p>
     * Delegates to {@link #copy(String, EnsembleModelProvenance, List)} with a
     * fresh {@link ArrayList} holding the current members.
     *
     * @param name The name for the copy.
     * @param newProvenance The provenance for the copy; it is cast to
     *                      {@link EnsembleModelProvenance}.
     * @return The copied model.
     * @throws ClassCastException If {@code newProvenance} is not an {@link EnsembleModelProvenance}.
     */
    @Override
    protected Model<T> copy(String name, ModelProvenance newProvenance) {
        return copy(name,(EnsembleModelProvenance)newProvenance,new ArrayList<>(models));
    }

    /**
     * Constructs a copy of this ensemble with the supplied name, provenance and members.
     *
     * @param name The name for the copy.
     * @param newProvenance The provenance for the copy.
     * @param newModels The members for the copy.
     * @return The copied ensemble.
     */
    protected abstract EnsembleModel<T> copy(String name, EnsembleModelProvenance newProvenance, List<Model<T>> newModels);

    /**
     * Aggregates the top features of every ensemble member.
     * <p>
     * Each member is asked for its own top {@code n} features. Every returned
     * score is divided by the ensemble size and the results are summed per
     * feature name, under the same outer key the member reported it with. For
     * each outer key the features with the largest absolute aggregated score are
     * then returned, largest first.
     * <p>
     * A negative {@code n} uses the size of this model's feature map as the limit.
     * A limit of zero yields an empty list for every key.
     *
     * @param n The maximum number of features to return per key, or a negative
     *          value to use the feature map size as the limit.
     * @return A map from the keys reported by the members to the selected
     *         (feature name, aggregated score) pairs.
     */
    @Override
    public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
        Map<String, Map<String,Pair<String,Double>>> featureMap = new HashMap<>();
        for (Model<T> model : models) {
            Map<String, List<Pair<String,Double>>> scoredFeatures = model.getTopFeatures(n);
            for (Entry<String,List<Pair<String,Double>>> e : scoredFeatures.entrySet()) {
                Map<String, Pair<String, Double>> curSet = featureMap.computeIfAbsent(e.getKey(), k -> new HashMap<>());
                for (Pair<String,Double> f : e.getValue()) {
                    // Scale before summing so the accumulated value is the score
                    // total divided by the ensemble size.
                    Pair<String,Double> tmp = new Pair<>(f.getA(),f.getB()/models.size());
                    curSet.merge(tmp.getA(),tmp,(Pair<String,Double> p1, Pair<String,Double> p2) -> new Pair<>(p1.getA(),p1.getB()+p2.getB()) );
                }
            }
        }

        int maxFeatures = n < 0 ? featureIDMap.size() : n;

        Comparator<Pair<String,Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        Map<String, List<Pair<String,Double>>> map = new HashMap<>();
        for (Entry<String, Map<String, Pair<String,Double>>> e : featureMap.entrySet()) {
            map.put(e.getKey(), selectTopFeatures(e.getValue(), maxFeatures, comparator));
        }

        return map;
    }

    /**
     * Selects up to {@code maxFeatures} entries which rank highest under
     * {@code comparator}, returned from highest to lowest.
     *
     * @param scores The candidate features, keyed by feature name.
     * @param maxFeatures The maximum number of features to keep; values below one select nothing.
     * @param comparator The ordering used to rank candidates.
     * @return A new mutable list of the selected features, highest ranked first.
     */
    private static List<Pair<String,Double>> selectTopFeatures(Map<String, Pair<String,Double>> scores, int maxFeatures, Comparator<Pair<String,Double>> comparator) {
        List<Pair<String,Double>> list = new ArrayList<>();
        // PriorityQueue rejects an initial capacity below one, so a zero limit
        // (or an empty candidate set) is answered without building a queue.
        if (maxFeatures < 1 || scores.isEmpty()) {
            return list;
        }

        // Never reserve more slots than there are candidates; a very large limit
        // would otherwise allocate a backing array of that size up front.
        int capacity = Math.min(maxFeatures, scores.size());
        PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(capacity, comparator);
        for (Pair<String,Double> cur : scores.values()) {
            if (q.size() < maxFeatures) {
                q.offer(cur);
            } else if (comparator.compare(cur, q.peek()) > 0) {
                // The queue head is the lowest ranked element kept so far, evict it.
                q.poll();
                q.offer(cur);
            }
        }

        // Polling drains lowest first, so reverse to put the highest ranked first.
        while (q.size() > 0) {
            list.add(q.poll());
        }
        Collections.reverse(list);
        return list;
    }

}