001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.tree; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.Excuse; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Model; 025import org.tribuo.Output; 026import org.tribuo.Prediction; 027import org.tribuo.SparseModel; 028import org.tribuo.math.la.SparseVector; 029import org.tribuo.provenance.ModelProvenance; 030 031import java.util.ArrayList; 032import java.util.Collections; 033import java.util.Comparator; 034import java.util.HashMap; 035import java.util.HashSet; 036import java.util.LinkedHashSet; 037import java.util.LinkedList; 038import java.util.List; 039import java.util.Map; 040import java.util.Optional; 041import java.util.PriorityQueue; 042import java.util.Queue; 043import java.util.Set; 044 045/** 046 * A {@link Model} wrapped around a decision tree root {@link Node}. 047 */ 048public class TreeModel<T extends Output<T>> extends SparseModel<T> { 049 private static final long serialVersionUID = 3L; 050 051 private final Node<T> root; 052 053 /** 054 * Constructs a trained decision tree model. 055 * @param name The model name. 056 * @param description The model provenance. 057 * @param featureIDMap The feature id map. 058 * @param outputIDInfo The output info. 059 * @param generatesProbabilities Does this model emit probabilities. 060 * @param root The root node of the tree. 061 */ 062 TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Node<T> root) { 063 super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap,root)); 064 this.root = root; 065 } 066 067 /** 068 * Constructs a trained decision tree model. 069 * <p> 070 * Only used when the tree has multiple roots, should only be called from 071 * subclassed when *all* other methods are overridden. 072 * @param name The model name. 073 * @param description The model provenance. 074 * @param featureIDMap The feature id map. 075 * @param outputIDInfo The output info. 076 * @param generatesProbabilities Does this model emit probabilities. 077 * @param activeFeatures The active feature set of the model. 078 */ 079 protected TreeModel(String name, ModelProvenance description, 080 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, 081 boolean generatesProbabilities, Map<String,List<String>> activeFeatures) { 082 super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, activeFeatures); 083 this.root = null; 084 } 085 086 private static <T extends Output<T>> Map<String,List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Node<T> root) { 087 Set<String> activeFeatures = new LinkedHashSet<>(); 088 089 Queue<Node<T>> nodeQueue = new LinkedList<>(); 090 091 nodeQueue.offer(root); 092 093 while (!nodeQueue.isEmpty()) { 094 Node<T> node = nodeQueue.poll(); 095 if ((node != null) && (!node.isLeaf())) { 096 SplitNode<T> splitNode = (SplitNode<T>) node; 097 String featureName = fMap.get(splitNode.getFeatureID()).getName(); 098 activeFeatures.add(featureName); 099 nodeQueue.offer(splitNode.getGreaterThan()); 100 nodeQueue.offer(splitNode.getLessThanOrEqual()); 101 } 102 } 103 return Collections.singletonMap(Model.ALL_OUTPUTS,new ArrayList<>(activeFeatures)); 104 } 105 106 @Override 107 public Prediction<T> predict(Example<T> example) { 108 // 109 // Ensures we handle collisions correctly 110 SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false); 111 if (vec.numActiveElements() == 0) { 112 throw new IllegalArgumentException("No features found in Example " + example.toString()); 113 } 114 Node<T> oldNode = root; 115 Node<T> curNode = root; 116 117 while (curNode != null) { 118 oldNode = curNode; 119 curNode = oldNode.getNextNode(vec); 120 } 121 122 // 123 // oldNode must be a LeafNode. 124 return ((LeafNode<T>) oldNode).getPrediction(vec.numActiveElements(),example); 125 } 126 127 @Override 128 public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) { 129 int maxFeatures = n < 0 ? featureIDMap.size() : n; 130 Map<String,Integer> featureCounts = new HashMap<>(); 131 132 Queue<Node<T>> nodeQueue = new LinkedList<>(); 133 134 nodeQueue.offer(root); 135 136 while (!nodeQueue.isEmpty()) { 137 Node<T> node = nodeQueue.poll(); 138 if ((node != null) && !node.isLeaf()) { 139 SplitNode<T> splitNode = (SplitNode<T>) node; 140 String featureName = featureIDMap.get(splitNode.getFeatureID()).getName(); 141 featureCounts.put(featureName, featureCounts.getOrDefault(featureName, 0) + 1); 142 nodeQueue.offer(splitNode.getGreaterThan()); 143 nodeQueue.offer(splitNode.getLessThanOrEqual()); 144 } 145 } 146 147 Comparator<Pair<String,Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB())); 148 PriorityQueue<Pair<String,Double>> q = new PriorityQueue<>(maxFeatures, comparator); 149 150 for (Map.Entry<String, Integer> e : featureCounts.entrySet()) { 151 Pair<String,Double> cur = new Pair<>(e.getKey(), (double) e.getValue()); 152 if (q.size() < maxFeatures) { 153 q.offer(cur); 154 } else if (comparator.compare(cur, q.peek()) > 0) { 155 q.poll(); 156 q.offer(cur); 157 } 158 } 159 List<Pair<String,Double>> list = new ArrayList<>(); 160 while (q.size() > 0) { 161 list.add(q.poll()); 162 } 163 Collections.reverse(list); 164 165 Map<String,List<Pair<String,Double>>> map = new HashMap<>(); 166 map.put(Model.ALL_OUTPUTS, list); 167 168 return map; 169 } 170 171 @Override 172 public Optional<Excuse<T>> getExcuse(Example<T> example) { 173 List<String> list = new ArrayList<>(); 174 // 175 // Ensures we handle collisions correctly 176 SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false); 177 Node<T> oldNode = root; 178 Node<T> curNode = root; 179 180 while (curNode != null) { 181 oldNode = curNode; 182 if (oldNode instanceof SplitNode) { 183 SplitNode<T> node = (SplitNode<T>) curNode; 184 list.add(featureIDMap.get(node.getFeatureID()).getName()); 185 } 186 curNode = oldNode.getNextNode(vec); 187 } 188 189 // 190 // oldNode must be a LeafNode. 191 Prediction<T> pred = ((LeafNode<T>) oldNode).getPrediction(vec.numActiveElements(),example); 192 193 List<Pair<String,Double>> pairs = new ArrayList<>(); 194 int i = list.size() + 1; 195 for (String s : list) { 196 pairs.add(new Pair<>(s,i+0.0)); 197 i--; 198 } 199 200 Map<String,List<Pair<String,Double>>> map = new HashMap<>(); 201 map.put(Model.ALL_OUTPUTS,pairs); 202 203 return Optional.of(new Excuse<>(example,pred,map)); 204 } 205 206 @Override 207 protected TreeModel<T> copy(String newName, ModelProvenance newProvenance) { 208 return new TreeModel<>(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,root.copy()); 209 } 210 211 /** 212 * Returns the set of features which are split on in this tree. 213 * @return The feature names used by this tree. 214 */ 215 public Set<String> getFeatures() { 216 Set<String> features = new HashSet<>(); 217 218 Queue<Node<T>> nodeQueue = new LinkedList<>(); 219 220 nodeQueue.offer(root); 221 222 while (!nodeQueue.isEmpty()) { 223 Node<T> node = nodeQueue.poll(); 224 if ((node != null) && !node.isLeaf()) { 225 SplitNode<T> splitNode = (SplitNode<T>) node; 226 features.add(featureIDMap.get(splitNode.getFeatureID()).getName()); 227 nodeQueue.offer(splitNode.getGreaterThan()); 228 nodeQueue.offer(splitNode.getLessThanOrEqual()); 229 } 230 } 231 232 return features; 233 } 234 235 @Override 236 public String toString() { 237 return "TreeModel(description="+provenance.toString()+",\n\t\ttree="+root.toString()+")"; 238 } 239 240}