/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;

/**
 * A {@link Model} wrapped around a list of decision tree root {@link Node}s used
 * to generate independent predictions for each dimension in a regression.
052 */ 053public final class IndependentRegressionTreeModel extends TreeModel<Regressor> { 054 private static final long serialVersionUID = 1L; 055 056 private final Map<String,Node<Regressor>> roots; 057 058 IndependentRegressionTreeModel(String name, ModelProvenance description, 059 ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, boolean generatesProbabilities, 060 Map<String,Node<Regressor>> roots) { 061 super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap,roots)); 062 this.roots = roots; 063 } 064 065 private static Map<String,List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Map<String,Node<Regressor>> roots) { 066 HashMap<String,List<String>> outputMap = new HashMap<>(); 067 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 068 Set<String> activeFeatures = new LinkedHashSet<>(); 069 070 Queue<Node<Regressor>> nodeQueue = new LinkedList<>(); 071 072 nodeQueue.offer(e.getValue()); 073 074 while (!nodeQueue.isEmpty()) { 075 Node<Regressor> node = nodeQueue.poll(); 076 if ((node != null) && (!node.isLeaf())) { 077 SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node; 078 String featureName = fMap.get(splitNode.getFeatureID()).getName(); 079 activeFeatures.add(featureName); 080 nodeQueue.offer(splitNode.getGreaterThan()); 081 nodeQueue.offer(splitNode.getLessThanOrEqual()); 082 } 083 } 084 outputMap.put(e.getKey(), new ArrayList<>(activeFeatures)); 085 } 086 return outputMap; 087 } 088 089 @Override 090 public Prediction<Regressor> predict(Example<Regressor> example) { 091 // 092 // Ensures we handle collisions correctly 093 SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false); 094 if (vec.numActiveElements() == 0) { 095 throw new IllegalArgumentException("No features found in Example " + example.toString()); 096 } 097 098 List<Prediction<Regressor>> predictionList = new ArrayList<>(); 099 for 
(Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 100 Node<Regressor> oldNode = e.getValue(); 101 Node<Regressor> curNode = e.getValue(); 102 103 while (curNode != null) { 104 oldNode = curNode; 105 curNode = oldNode.getNextNode(vec); 106 } 107 108 // 109 // oldNode must be a LeafNode. 110 predictionList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example)); 111 } 112 return combine(predictionList); 113 } 114 115 @Override 116 public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) { 117 int maxFeatures = n < 0 ? featureIDMap.size() : n; 118 119 Map<String, List<Pair<String, Double>>> map = new HashMap<>(); 120 Map<String, Integer> featureCounts = new HashMap<>(); 121 Queue<Node<Regressor>> nodeQueue = new LinkedList<>(); 122 123 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 124 featureCounts.clear(); 125 nodeQueue.clear(); 126 127 nodeQueue.offer(e.getValue()); 128 129 while (!nodeQueue.isEmpty()) { 130 Node<Regressor> node = nodeQueue.poll(); 131 if ((node != null) && !node.isLeaf()) { 132 SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node; 133 String featureName = featureIDMap.get(splitNode.getFeatureID()).getName(); 134 featureCounts.put(featureName, featureCounts.getOrDefault(featureName, 0) + 1); 135 nodeQueue.offer(splitNode.getGreaterThan()); 136 nodeQueue.offer(splitNode.getLessThanOrEqual()); 137 } 138 } 139 140 Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB())); 141 PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator); 142 143 for (Map.Entry<String, Integer> featureCount : featureCounts.entrySet()) { 144 Pair<String, Double> cur = new Pair<>(featureCount.getKey(), (double) featureCount.getValue()); 145 if (q.size() < maxFeatures) { 146 q.offer(cur); 147 } else if (comparator.compare(cur, q.peek()) > 0) { 148 q.poll(); 149 q.offer(cur); 150 } 151 } 152 List<Pair<String, Double>> list = new 
ArrayList<>(); 153 while (q.size() > 0) { 154 list.add(q.poll()); 155 } 156 Collections.reverse(list); 157 158 map.put(e.getKey(), list); 159 } 160 161 return map; 162 } 163 164 @Override 165 public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) { 166 SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false); 167 if (vec.numActiveElements() == 0) { 168 return Optional.empty(); 169 } 170 171 List<String> list = new ArrayList<>(); 172 List<Prediction<Regressor>> predList = new ArrayList<>(); 173 Map<String, List<Pair<String, Double>>> map = new HashMap<>(); 174 175 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 176 list.clear(); 177 178 // 179 // Ensures we handle collisions correctly 180 Node<Regressor> oldNode = e.getValue(); 181 Node<Regressor> curNode = e.getValue(); 182 183 while (curNode != null) { 184 oldNode = curNode; 185 if (oldNode instanceof SplitNode) { 186 SplitNode<?> node = (SplitNode<?>) curNode; 187 list.add(featureIDMap.get(node.getFeatureID()).getName()); 188 } 189 curNode = oldNode.getNextNode(vec); 190 } 191 192 // 193 // oldNode must be a LeafNode. 
194 predList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example)); 195 196 List<Pair<String, Double>> pairs = new ArrayList<>(); 197 int i = list.size() + 1; 198 for (String s : list) { 199 pairs.add(new Pair<>(s, i + 0.0)); 200 i--; 201 } 202 203 map.put(e.getKey(), pairs); 204 } 205 Prediction<Regressor> combinedPrediction = combine(predList); 206 207 return Optional.of(new Excuse<>(example,combinedPrediction,map)); 208 } 209 210 @Override 211 protected IndependentRegressionTreeModel copy(String newName, ModelProvenance newProvenance) { 212 Map<String,Node<Regressor>> newRoots = new HashMap<>(); 213 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 214 newRoots.put(e.getKey(),e.getValue().copy()); 215 } 216 return new IndependentRegressionTreeModel(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,newRoots); 217 } 218 219 private Prediction<Regressor> combine(List<Prediction<Regressor>> predictions) { 220 DimensionTuple[] tuples = new DimensionTuple[predictions.size()]; 221 int numUsed = 0; 222 int i = 0; 223 for (Prediction<Regressor> p : predictions) { 224 if (numUsed < p.getNumActiveFeatures()) { 225 numUsed = p.getNumActiveFeatures(); 226 } 227 Regressor output = p.getOutput(); 228 if (output instanceof DimensionTuple) { 229 tuples[i] = (DimensionTuple)output; 230 } else { 231 throw new IllegalStateException("All the leaves should contain DimensionTuple not Regressor"); 232 } 233 i++; 234 } 235 236 Example<Regressor> example = predictions.get(0).getExample(); 237 return new Prediction<>(new Regressor(tuples),numUsed,example); 238 } 239 240 @Override 241 public Set<String> getFeatures() { 242 Set<String> features = new HashSet<>(); 243 244 Queue<Node<Regressor>> nodeQueue = new LinkedList<>(); 245 246 for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) { 247 nodeQueue.offer(e.getValue()); 248 249 while (!nodeQueue.isEmpty()) { 250 Node<Regressor> node = nodeQueue.poll(); 251 if ((node 
!= null) && !node.isLeaf()) { 252 SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node; 253 features.add(featureIDMap.get(splitNode.getFeatureID()).getName()); 254 nodeQueue.offer(splitNode.getGreaterThan()); 255 nodeQueue.offer(splitNode.getLessThanOrEqual()); 256 } 257 } 258 } 259 260 return features; 261 } 262 263 @Override 264 public String toString() { 265 StringBuilder sb = new StringBuilder(); 266 for (Map.Entry<String,Node<Regressor>> curRoot : roots.entrySet()) { 267 sb.append("Output '"); 268 sb.append(curRoot.getKey()); 269 sb.append("' - tree = "); 270 sb.append(curRoot.getValue().toString()); 271 sb.append('\n'); 272 } 273 return "IndependentTreeModel(description="+provenance.toString()+",\n"+sb.toString()+")"; 274 } 275 276}