001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree;
018
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
048
049/**
050 * A {@link Model} wrapped around a list of decision tree root {@link Node}s used
051 * to generate independent predictions for each dimension in a regression.
052 */
053public final class IndependentRegressionTreeModel extends TreeModel<Regressor> {
054    private static final long serialVersionUID = 1L;
055
056    private final Map<String,Node<Regressor>> roots;
057
058    IndependentRegressionTreeModel(String name, ModelProvenance description,
059                                          ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, boolean generatesProbabilities,
060                                          Map<String,Node<Regressor>> roots) {
061        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities, gatherActiveFeatures(featureIDMap,roots));
062        this.roots = roots;
063    }
064
065    private static Map<String,List<String>> gatherActiveFeatures(ImmutableFeatureMap fMap, Map<String,Node<Regressor>> roots) {
066        HashMap<String,List<String>> outputMap = new HashMap<>();
067        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
068            Set<String> activeFeatures = new LinkedHashSet<>();
069
070            Queue<Node<Regressor>> nodeQueue = new LinkedList<>();
071
072            nodeQueue.offer(e.getValue());
073
074            while (!nodeQueue.isEmpty()) {
075                Node<Regressor> node = nodeQueue.poll();
076                if ((node != null) && (!node.isLeaf())) {
077                    SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node;
078                    String featureName = fMap.get(splitNode.getFeatureID()).getName();
079                    activeFeatures.add(featureName);
080                    nodeQueue.offer(splitNode.getGreaterThan());
081                    nodeQueue.offer(splitNode.getLessThanOrEqual());
082                }
083            }
084            outputMap.put(e.getKey(), new ArrayList<>(activeFeatures));
085        }
086        return outputMap;
087    }
088
089    @Override
090    public Prediction<Regressor> predict(Example<Regressor> example) {
091        //
092        // Ensures we handle collisions correctly
093        SparseVector vec = SparseVector.createSparseVector(example,featureIDMap,false);
094        if (vec.numActiveElements() == 0) {
095            throw new IllegalArgumentException("No features found in Example " + example.toString());
096        }
097
098        List<Prediction<Regressor>> predictionList = new ArrayList<>();
099        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
100            Node<Regressor> oldNode = e.getValue();
101            Node<Regressor> curNode = e.getValue();
102
103            while (curNode != null) {
104                oldNode = curNode;
105                curNode = oldNode.getNextNode(vec);
106            }
107
108            //
109            // oldNode must be a LeafNode.
110            predictionList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example));
111        }
112        return combine(predictionList);
113    }
114
115    @Override
116    public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
117        int maxFeatures = n < 0 ? featureIDMap.size() : n;
118
119        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
120        Map<String, Integer> featureCounts = new HashMap<>();
121        Queue<Node<Regressor>> nodeQueue = new LinkedList<>();
122
123        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
124            featureCounts.clear();
125            nodeQueue.clear();
126
127            nodeQueue.offer(e.getValue());
128
129            while (!nodeQueue.isEmpty()) {
130                Node<Regressor> node = nodeQueue.poll();
131                if ((node != null) && !node.isLeaf()) {
132                    SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node;
133                    String featureName = featureIDMap.get(splitNode.getFeatureID()).getName();
134                    featureCounts.put(featureName, featureCounts.getOrDefault(featureName, 0) + 1);
135                    nodeQueue.offer(splitNode.getGreaterThan());
136                    nodeQueue.offer(splitNode.getLessThanOrEqual());
137                }
138            }
139
140            Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
141            PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);
142
143            for (Map.Entry<String, Integer> featureCount : featureCounts.entrySet()) {
144                Pair<String, Double> cur = new Pair<>(featureCount.getKey(), (double) featureCount.getValue());
145                if (q.size() < maxFeatures) {
146                    q.offer(cur);
147                } else if (comparator.compare(cur, q.peek()) > 0) {
148                    q.poll();
149                    q.offer(cur);
150                }
151            }
152            List<Pair<String, Double>> list = new ArrayList<>();
153            while (q.size() > 0) {
154                list.add(q.poll());
155            }
156            Collections.reverse(list);
157
158            map.put(e.getKey(), list);
159        }
160
161        return map;
162    }
163
164    @Override
165    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
166        SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
167        if (vec.numActiveElements() == 0) {
168            return Optional.empty();
169        }
170
171        List<String> list = new ArrayList<>();
172        List<Prediction<Regressor>> predList = new ArrayList<>();
173        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
174
175        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
176            list.clear();
177
178            //
179            // Ensures we handle collisions correctly
180            Node<Regressor> oldNode = e.getValue();
181            Node<Regressor> curNode = e.getValue();
182
183            while (curNode != null) {
184                oldNode = curNode;
185                if (oldNode instanceof SplitNode) {
186                    SplitNode<?> node = (SplitNode<?>) curNode;
187                    list.add(featureIDMap.get(node.getFeatureID()).getName());
188                }
189                curNode = oldNode.getNextNode(vec);
190            }
191
192            //
193            // oldNode must be a LeafNode.
194            predList.add(((LeafNode<Regressor>) oldNode).getPrediction(vec.numActiveElements(), example));
195
196            List<Pair<String, Double>> pairs = new ArrayList<>();
197            int i = list.size() + 1;
198            for (String s : list) {
199                pairs.add(new Pair<>(s, i + 0.0));
200                i--;
201            }
202
203            map.put(e.getKey(), pairs);
204        }
205        Prediction<Regressor> combinedPrediction = combine(predList);
206
207        return Optional.of(new Excuse<>(example,combinedPrediction,map));
208    }
209
210    @Override
211    protected IndependentRegressionTreeModel copy(String newName, ModelProvenance newProvenance) {
212        Map<String,Node<Regressor>> newRoots = new HashMap<>();
213        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
214            newRoots.put(e.getKey(),e.getValue().copy());
215        }
216        return new IndependentRegressionTreeModel(newName,newProvenance,featureIDMap,outputIDInfo,generatesProbabilities,newRoots);
217    }
218
219    private Prediction<Regressor> combine(List<Prediction<Regressor>> predictions) {
220        DimensionTuple[] tuples = new DimensionTuple[predictions.size()];
221        int numUsed = 0;
222        int i = 0;
223        for (Prediction<Regressor> p : predictions) {
224            if (numUsed < p.getNumActiveFeatures()) {
225                numUsed = p.getNumActiveFeatures();
226            }
227            Regressor output = p.getOutput();
228            if (output instanceof DimensionTuple) {
229                tuples[i] = (DimensionTuple)output;
230            } else {
231                throw new IllegalStateException("All the leaves should contain DimensionTuple not Regressor");
232            }
233            i++;
234        }
235
236        Example<Regressor> example = predictions.get(0).getExample();
237        return new Prediction<>(new Regressor(tuples),numUsed,example);
238    }
239
240    @Override
241    public Set<String> getFeatures() {
242        Set<String> features = new HashSet<>();
243
244        Queue<Node<Regressor>> nodeQueue = new LinkedList<>();
245
246        for (Map.Entry<String,Node<Regressor>> e : roots.entrySet()) {
247            nodeQueue.offer(e.getValue());
248
249            while (!nodeQueue.isEmpty()) {
250                Node<Regressor> node = nodeQueue.poll();
251                if ((node != null) && !node.isLeaf()) {
252                    SplitNode<Regressor> splitNode = (SplitNode<Regressor>) node;
253                    features.add(featureIDMap.get(splitNode.getFeatureID()).getName());
254                    nodeQueue.offer(splitNode.getGreaterThan());
255                    nodeQueue.offer(splitNode.getLessThanOrEqual());
256                }
257            }
258        }
259
260        return features;
261    }
262
263    @Override
264    public String toString() {
265        StringBuilder sb = new StringBuilder();
266        for (Map.Entry<String,Node<Regressor>> curRoot : roots.entrySet()) {
267            sb.append("Output '");
268            sb.append(curRoot.getKey());
269            sb.append("' - tree = ");
270            sb.append(curRoot.getValue().toString());
271            sb.append('\n');
272        }
273        return "IndependentTreeModel(description="+provenance.toString()+",\n"+sb.toString()+")";
274    }
275
276}