/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.dtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.impurity.LabelImpurity;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.util.Util;

import java.io.IOException;
import java.io.NotSerializableException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * A decision tree node used at training time.
 * Contains a list of the example indices currently found in this node,
 * the current impurity and a bunch of other statistics.
 */
public class ClassifierTrainingNode extends AbstractTrainingNode<Label> {
    private static final long serialVersionUID = 1L;

    private static final Logger logger = Logger.getLogger(ClassifierTrainingNode.class.getName());

    // Per-thread scratch buffers reused by buildTree across every node trained on the
    // same thread, avoiding a fresh index-array allocation for each split.
    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferThree = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));

    // Inverted representation of this node's examples, one TreeFeature per feature id.
    // Transient (and nulled at the end of buildTree) because this class must never be
    // serialized - see writeObject.
    private transient ArrayList<TreeFeature> data;

    private final ImmutableOutputInfo<Label> labelIDMap;

    private final ImmutableFeatureMap featureIDMap;

    private final LabelImpurity impurity;

    // Weighted example count per label id for the examples in this node.
    private final float[] labelCounts;

    /**
     * Constructor which creates the inverted file.
     * @param impurity The impurity function to use.
     * @param examples The training data.
     */
    public ClassifierTrainingNode(LabelImpurity impurity, Dataset<Label> examples) {
        this(impurity,invertData(examples), examples.size(), 0, examples.getFeatureIDMap(), examples.getOutputIDInfo());
    }

    /**
     * Constructs a node over a pre-inverted slice of the data, used for the root and when
     * splitting a parent node.
     * @param impurity The impurity function to use.
     * @param data The inverted data for this node (must contain at least one feature).
     * @param numExamples The number of examples in this node.
     * @param depth The depth of this node in the tree (root is 0).
     * @param featureIDMap The feature id lookup.
     * @param labelIDMap The label id lookup.
     */
    private ClassifierTrainingNode(LabelImpurity impurity, ArrayList<TreeFeature> data, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
        super(depth, numExamples);
        this.data = data;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        // Label counts are read from the first feature; assumes every TreeFeature carries
        // the same per-label totals for this node's examples - TODO confirm TreeFeature contract.
        this.labelCounts = data.get(0).getLabelCounts();
    }

    /**
     * Builds a tree according to CART (as it does not do multi-way splits on categorical values like C4.5).
     * @param featureIDs Indices of the features available in this split.
     * @return A possibly empty list of TrainingNodes.
 */
    @Override
    public List<AbstractTrainingNode<Label>> buildTree(int[] featureIDs) {
        int bestID = -1;
        double bestSplitValue = 0.0;
        // Start from this node's own impurity; a candidate split is only adopted if it
        // scores strictly lower.
        // NOTE(review): this baseline uses impurity() while the candidate scores below use
        // impurityWeighted()/countsSum - presumably both are on the same scale; confirm
        // against the LabelImpurity implementations.
        double bestScore = impurity.impurity(labelCounts);
        float[] lessThanCounts = new float[labelCounts.length];
        float[] greaterThanCounts = new float[labelCounts.length];
        double countsSum = Util.sum(labelCounts);
        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
            // All counts start on the greater-than side; each candidate threshold moves one
            // distinct feature value's worth of label counts onto the less-than side.
            Arrays.fill(lessThanCounts,0.0f);
            System.arraycopy(labelCounts, 0, greaterThanCounts, 0, labelCounts.length);
            // searching for the intervals between features.
            for (int j = 0; j < feature.size()-1; j++) {
                InvertedFeature f = feature.get(j);
                float[] featureCounts = f.getLabelCounts();
                Util.inPlaceAdd(lessThanCounts,featureCounts);
                Util.inPlaceSubtract(greaterThanCounts,featureCounts);
                double lessThanScore = impurity.impurityWeighted(lessThanCounts);
                double greaterThanScore = impurity.impurityWeighted(greaterThanCounts);
                // NOTE(review): candidates where either side is (near) pure are skipped,
                // which also rejects splits that isolate a perfectly pure child - confirm
                // this is intentional.
                if ((lessThanScore > 1e-10) && (greaterThanScore > 1e-10)) {
                    double score = (lessThanScore + greaterThanScore) / countsSum;
                    if (score < bestScore) {
                        bestID = i;
                        bestScore = score;
                        // Split halfway between this value and the next distinct value.
                        bestSplitValue = (f.value + feature.get(j + 1).value) / 2.0;
                    }
                }
            }
        }
        List<AbstractTrainingNode<Label>> output;
        // If we found a split better than the current impurity.
        if (bestID != -1) {
            splitID = featureIDs[bestID];
            split = true;
            splitValue = bestSplitValue;
            // Accumulate the example indices that fall below the split value, using the
            // thread-local containers as reusable scratch space.
            IntArrayContainer lessThanIndices = mergeBufferOne.get();
            lessThanIndices.size = 0;
            IntArrayContainer buffer = mergeBufferTwo.get();
            buffer.size = 0;
            for (InvertedFeature f : data.get(splitID)) {
                if (f.value < splitValue) {
                    int[] indices = f.indices();
                    // merge appears to write the combined indices into buffer; the swap below
                    // keeps the accumulated result in lessThanIndices.
                    // TODO confirm IntArrayContainer.merge's output argument.
                    IntArrayContainer.merge(lessThanIndices,indices,buffer);
                    // Swap the buffers
                    IntArrayContainer tmp = lessThanIndices;
                    lessThanIndices = buffer;
                    buffer = tmp;
                } else {
                    // InvertedFeatures are sorted by value (TreeFeature::sort in invertData),
                    // so no later entry can fall below the split value.
                    break;
                }
            }
            //logger.info("Splitting on feature " + maxID + " with value " + maxSplitValue + " at depth " + depth + ", " + numExamples + " examples in node.");
            //logger.info("left indices length = " + lessThanIndices.size);
            IntArrayContainer secondBuffer = mergeBufferThree.get();
            secondBuffer.grow(lessThanIndices.size);
            ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size());
            ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size());
            for (TreeFeature feature : data) {
                // NOTE(review): this local Pair shadows the boolean field 'split' set above.
                Pair<TreeFeature,TreeFeature> split = feature.split(lessThanIndices,buffer,secondBuffer);
                lessThanData.add(split.getA());
                greaterThanData.add(split.getB());
            }

            lessThanOrEqual = new ClassifierTrainingNode(impurity, lessThanData, lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
            greaterThan = new ClassifierTrainingNode(impurity, greaterThanData, numExamples - lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
            output = new ArrayList<>();
            output.add(lessThanOrEqual);
            output.add(greaterThan);
        } else {
            output = Collections.emptyList();
        }
        // Free the inverted data - this node is now either split or a leaf, and any
        // children own their own slices.
        data = null;
        return output;
    }

    /**
     * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node.
     * @return A subtree using the SplitNode and LeafNode classes.
171 */ 172 @Override 173 public Node<Label> convertTree() { 174 if (split) { 175 // split node 176 Node<Label> newGreaterThan = greaterThan.convertTree(); 177 Node<Label> newLessThan = lessThanOrEqual.convertTree(); 178 return new SplitNode<>(splitValue,splitID,impurity.impurity(labelCounts),newGreaterThan,newLessThan); 179 } else { 180 // leaf node 181 double[] normedCounts = Util.normalizeToDistribution(labelCounts); 182 double maxScore = Double.NEGATIVE_INFINITY; 183 Label maxLabel = null; 184 Map<String,Label> counts = new LinkedHashMap<>(); 185 for (int i = 0; i < labelCounts.length; i++) { 186 String name = labelIDMap.getOutput(i).getLabel(); 187 Label label = new Label(name,normedCounts[i]); 188 counts.put(name, label); 189 if (label.getScore() > maxScore) { 190 maxScore = label.getScore(); 191 maxLabel = label; 192 } 193 } 194 return new LeafNode<>(impurity.impurity(labelCounts),maxLabel,counts,true); 195 } 196 } 197 198 @Override 199 public double getImpurity() { 200 return impurity.impurity(labelCounts); 201 } 202 203 /** 204 * Inverts a training dataset from row major to column major. This partially de-sparsifies the dataset 205 * so it's very expensive in terms of memory. 206 * @param examples An input dataset. 207 * @return A list of TreeFeatures which contain {@link InvertedFeature}s. 
 */
    private static ArrayList<TreeFeature> invertData(Dataset<Label> examples) {
        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
        ImmutableOutputInfo<Label> labelInfo = examples.getOutputIDInfo();
        int numLabels = labelInfo.size();
        int numFeatures = featureInfos.size();
        int numExamples = examples.size();

        // Per-example label ids and weights, indexed by example position in the dataset.
        int[] labels = new int[numExamples];
        float[] weights = new float[numExamples];

        int k = 0;
        for (Example<Label> e : examples) {
            weights[k] = e.getWeight();
            labels[k] = labelInfo.getID(e.getOutput());
            k++;
        }

        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " classes");
        ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size());

        for (int i = 0; i < featureInfos.size(); i++) {
            data.add(new TreeFeature(i,numLabels,labels,weights));
        }

        for (int i = 0; i < examples.size(); i++) {
            Example<Label> e = examples.getExample(i);
            SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false);
            int lastID = 0;
            for (VectorTuple f : vec) {
                int curID = f.index;
                // Features absent from the sparse vector are implicit zeroes - record them
                // explicitly so every feature observes every example (the de-sparsification).
                for (int j = lastID; j < curID; j++) {
                    data.get(j).observeValue(0.0,i);
                }
                data.get(curID).observeValue(f.value,i);
                //
                // These two checks should never occur as SparseVector deals with collisions, and Dataset prevents
                // repeated features.
                // They are left in just to make sure.
                if (lastID > curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                } else if (lastID-1 == curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                }
                lastID = curID + 1;
            }
            // Zero-fill any features with ids after the last one observed in this example.
            for (int j = lastID; j < numFeatures; j++) {
                data.get(j).observeValue(0.0,i);
            }
            if (i % 1000 == 0) {
                logger.fine("Processed example " + i);
            }
        }

        logger.fine("Sorting features");

        // Sort each feature's inverted list by feature value; buildTree relies on this
        // ordering when scanning split thresholds.
        data.forEach(TreeFeature::sort);

        /*
        for (TreeFeature f : data) {
            logger.info(f.toString());
        }
        */

        logger.fine("Fixing InvertedFeature sizes");

        // Presumably trims each InvertedFeature's backing array to its observed size,
        // releasing over-allocated capacity - confirm against InvertedFeature.
        data.forEach(TreeFeature::fixSize);

        logger.fine("Built initial List<TreeFeature>");

        return data;
    }

    /**
     * Always throws - this class is a training-time structure and must never be serialized.
     * @param stream The output stream (unused).
     * @throws IOException Always throws {@link NotSerializableException}.
     */
    private void writeObject(java.io.ObjectOutputStream stream)
        throws IOException {
        throw new NotSerializableException("ClassifierTrainingNode is a runtime class only, and should not be serialized.");
    }
}