001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.ensemble;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Output;
026import org.tribuo.provenance.EnsembleModelProvenance;
027import org.tribuo.provenance.ModelProvenance;
028
029import java.util.ArrayList;
030import java.util.Collections;
031import java.util.Comparator;
032import java.util.HashMap;
033import java.util.List;
034import java.util.Map;
035import java.util.Map.Entry;
036import java.util.Optional;
037import java.util.PriorityQueue;
038
039/**
040 * A model which contains a list of other {@link Model}s.
041 */
042public abstract class EnsembleModel<T extends Output<T>> extends Model<T> {
043    private static final long serialVersionUID = 1L;
044
045    protected final List<Model<T>> models;
046
047    public EnsembleModel(String name, EnsembleModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Model<T>> newModels) {
048        super(name,description,featureIDMap,outputIDInfo,true);
049        models = Collections.unmodifiableList(newModels);
050    }
051
052    /**
053     * Returns an unmodifiable view on the ensemble members.
054     * @return The ensemble members.
055     */
056    public List<Model<T>> getModels() {
057        return models;
058    }
059
060    /**
061     * The number of ensemble members.
062     * @return The ensemble size.
063     */
064    public int getNumModels() {
065        return models.size();
066    }
067
068    @Override
069    public abstract Optional<Excuse<T>> getExcuse(Example<T> example);
070
071    @Override
072    public EnsembleModelProvenance getProvenance() {
073        return (EnsembleModelProvenance) provenance;
074    }
075
076    @Override
077    protected Model<T> copy(String name, ModelProvenance newProvenance) {
078        return copy(name,(EnsembleModelProvenance)newProvenance,new ArrayList<>(models));
079    }
080
081    protected abstract EnsembleModel<T> copy(String name, EnsembleModelProvenance newProvenance, List<Model<T>> newModels);
082
083    @Override
084    public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
085        Map<String, Map<String,Pair<String,Double>>> featureMap = new HashMap<>();
086        for (Model<T> model : models) {
087            Map<String, List<Pair<String,Double>>> scoredFeatures = model.getTopFeatures(n);
088            for (Entry<String,List<Pair<String,Double>>> e : scoredFeatures.entrySet()) {
089                Map<String, Pair<String, Double>> curSet = featureMap.computeIfAbsent(e.getKey(), k -> new HashMap<>());
090                for (Pair<String,Double> f : e.getValue()) {
091                    Pair<String,Double> tmp = new Pair<>(f.getA(),f.getB()/models.size());
092                    curSet.merge(tmp.getA(),tmp,(Pair<String,Double> p1, Pair<String,Double> p2) -> new Pair<>(p1.getA(),p1.getB()+p2.getB()) );
093                }
094            }
095        }
096
097        int maxFeatures = n < 0 ? featureIDMap.size() : n;
098
099        Comparator<Pair<String,Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
100        Map<String, List<Pair<String,Double>>> map = new HashMap<>();
101        for (Entry<String, Map<String, Pair<String,Double>>> e : featureMap.entrySet()) {
102
103            PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures, comparator);
104            for (Pair<String,Double> cur : e.getValue().values()) {
105                if (q.size() < maxFeatures) {
106                    q.offer(cur);
107                } else if (comparator.compare(cur, q.peek()) > 0) {
108                    q.poll();
109                    q.offer(cur);
110                }
111            }
112            List<Pair<String,Double>> list = new ArrayList<>();
113            while (q.size() > 0) {
114                list.add(q.poll());
115            }
116            Collections.reverse(list);
117            map.put(e.getKey(), list);
118        }
119
120        return map;
121    }
122
123}