001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.tree;
018
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.SparseModel;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
044
045/**
046 * A {@link Model} wrapped around a decision tree root {@link Node}.
047 */
048public class TreeModel<T extends Output<T>> extends SparseModel<T> {
049    private static final long serialVersionUID = 3L;
050
051    private final Node<T> root;
052
053    /**
054     * Constructs a trained decision tree model.
055     * @param name The model name.
056     * @param description The model provenance.
057     * @param featureIDMap The feature id map.
058     * @param outputIDInfo The output info.
059     * @param generatesProbabilities Does this model emit probabilities.
060     * @param root The root node of the tree.
061     */
062    TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Node<T> root) {
063        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap,root));
064        this.root = root;
065    }
066
067    /**
068     * Constructs a trained decision tree model.
069     * <p>
070     * Only used when the tree has multiple roots, should only be called from
071     * subclassed when *all* other methods are overridden.
072     * @param name The model name.
073     * @param description The model provenance.
074     * @param featureIDMap The feature id map.
075     * @param outputIDInfo The output info.
076     * @param generatesProbabilities Does this model emit probabilities.
077     * @param activeFeatures The active feature set of the model.
078     */
079    protected TreeModel(String name, ModelProvenance description,
080                        ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo,
081                        boolean generatesProbabilities, Map<String,List<String>> activeFeatures) {
082        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, activeFeatures);
083        this.root = null;
084    }
085
086    private static <T extends Output<T>> Map<String,List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Node<T> root) {
087        Set<String> activeFeatures = new LinkedHashSet<>();
088
089        Queue<Node<T>> nodeQueue = new LinkedList<>();
090
091        nodeQueue.offer(root);
092
093        while (!nodeQueue.isEmpty()) {
094            Node<T> node = nodeQueue.poll();
095            if ((node != null) && (!node.isLeaf())) {
096                SplitNode<T> splitNode = (SplitNode<T>) node;
097                String featureName = fMap.get(splitNode.getFeatureID()).getName();
098                activeFeatures.add(featureName);
099                nodeQueue.offer(splitNode.getGreaterThan());
100                nodeQueue.offer(splitNode.getLessThanOrEqual());
101            }
102        }
103        return Collections.singletonMap(Model.ALL_OUTPUTS,new ArrayList<>(activeFeatures));
104    }
105
106    @Override
107    public Prediction<T> predict(Example<T> example) {
108        //
109        // Ensures we handle collisions correctly
110        SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false);
111        if (vec.numActiveElements() == 0) {
112            throw new IllegalArgumentException("No features found in Example " + example.toString());
113        }
114        Node<T> oldNode = root;
115        Node<T> curNode = root;
116
117        while (curNode != null) {
118            oldNode = curNode;
119            curNode = oldNode.getNextNode(vec);
120        }
121
122        //
123        // oldNode must be a LeafNode.
124        return ((LeafNode<T>) oldNode).getPrediction(vec.numActiveElements(),example);
125    }
126
127    @Override
128    public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
129        int maxFeatures = n < 0 ? featureIDMap.size() : n;
130        Map<String,Integer> featureCounts = new HashMap<>();
131
132        Queue<Node<T>> nodeQueue = new LinkedList<>();
133
134        nodeQueue.offer(root);
135
136        while (!nodeQueue.isEmpty()) {
137            Node<T> node = nodeQueue.poll();
138            if ((node != null) && !node.isLeaf()) {
139                SplitNode<T> splitNode = (SplitNode<T>) node;
140                String featureName = featureIDMap.get(splitNode.getFeatureID()).getName();
141                featureCounts.put(featureName, featureCounts.getOrDefault(featureName, 0) + 1);
142                nodeQueue.offer(splitNode.getGreaterThan());
143                nodeQueue.offer(splitNode.getLessThanOrEqual());
144            }
145        }
146
147        Comparator<Pair<String,Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
148        PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures, comparator);
149
150        for (Map.Entry<String, Integer> e : featureCounts.entrySet()) {
151            Pair<String,Double> cur = new Pair<>(e.getKey(), (double) e.getValue());
152            if (q.size() < maxFeatures) {
153                q.offer(cur);
154            } else if (comparator.compare(cur, q.peek()) > 0) {
155                q.poll();
156                q.offer(cur);
157            }
158        }
159        List<Pair<String,Double>> list = new ArrayList<>();
160        while (q.size() > 0) {
161            list.add(q.poll());
162        }
163        Collections.reverse(list);
164
165        Map<String,List<Pair<String,Double>>> map = new HashMap<>();
166        map.put(Model.ALL_OUTPUTS, list);
167
168        return map;
169    }
170
171    @Override
172    public Optional<Excuse<T>> getExcuse(Example<T> example) {
173        List<String> list = new ArrayList<>();
174        //
175        // Ensures we handle collisions correctly
176        SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false);
177        Node<T> oldNode = root;
178        Node<T> curNode = root;
179
180        while (curNode != null) {
181            oldNode = curNode;
182            if (oldNode instanceof SplitNode) {
183                SplitNode<T> node = (SplitNode<T>) curNode;
184                list.add(featureIDMap.get(node.getFeatureID()).getName());
185            }
186            curNode = oldNode.getNextNode(vec);
187        }
188
189        //
190        // oldNode must be a LeafNode.
191        Prediction<T> pred = ((LeafNode<T>) oldNode).getPrediction(vec.numActiveElements(),example);
192
193        List<Pair<String,Double>> pairs = new ArrayList<>();
194        int i = list.size() + 1;
195        for (String s : list) {
196            pairs.add(new Pair<>(s,i+0.0));
197            i--;
198        }
199
200        Map<String,List<Pair<String,Double>>> map = new HashMap<>();
201        map.put(Model.ALL_OUTPUTS,pairs);
202
203        return Optional.of(new Excuse<>(example,pred,map));
204    }
205
206    @Override
207    protected TreeModel<T> copy(String newName, ModelProvenance newProvenance) {
208        return new TreeModel<>(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,root.copy());
209    }
210
211    /**
212     * Returns the set of features which are split on in this tree.
213     * @return The feature names used by this tree.
214     */
215    public Set<String> getFeatures() {
216        Set<String> features = new HashSet<>();
217
218        Queue<Node<T>> nodeQueue = new LinkedList<>();
219
220        nodeQueue.offer(root);
221
222        while (!nodeQueue.isEmpty()) {
223            Node<T> node = nodeQueue.poll();
224            if ((node != null) && !node.isLeaf()) {
225                SplitNode<T> splitNode = (SplitNode<T>) node;
226                features.add(featureIDMap.get(splitNode.getFeatureID()).getName());
227                nodeQueue.offer(splitNode.getGreaterThan());
228                nodeQueue.offer(splitNode.getLessThanOrEqual());
229            }
230        }
231
232        return features;
233    }
234
235    @Override
236    public String toString() {
237        return "TreeModel(description="+provenance.toString()+",\n\t\ttree="+root.toString()+")";
238    }
239
240}