001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.rtree.impl; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Dataset; 021import org.tribuo.Example; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.common.tree.AbstractTrainingNode; 025import org.tribuo.common.tree.LeafNode; 026import org.tribuo.common.tree.Node; 027import org.tribuo.common.tree.SplitNode; 028import org.tribuo.common.tree.impl.IntArrayContainer; 029import org.tribuo.math.la.SparseVector; 030import org.tribuo.math.la.VectorTuple; 031import org.tribuo.regression.Regressor; 032import org.tribuo.regression.rtree.impurity.RegressorImpurity; 033import org.tribuo.regression.rtree.impurity.RegressorImpurity.ImpurityTuple; 034import org.tribuo.util.Util; 035 036import java.io.IOException; 037import java.io.NotSerializableException; 038import java.util.ArrayList; 039import java.util.Collections; 040import java.util.List; 041import java.util.logging.Logger; 042 043/** 044 * A decision tree node used at training time. 045 * Contains a list of the example indices currently found in this node, 046 * the current impurity and a bunch of other statistics. 047 */ 048public class JointRegressorTrainingNode extends AbstractTrainingNode<Regressor> { 049 private static final long serialVersionUID = 1L; 050 051 private static final Logger logger = Logger.getLogger(JointRegressorTrainingNode.class.getName()); 052 053 private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 054 private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 055 private static final ThreadLocal<IntArrayContainer> mergeBufferThree = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 056 private static final ThreadLocal<IntArrayContainer> mergeBufferFour = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 057 private static final ThreadLocal<IntArrayContainer> mergeBufferFive = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 058 059 private transient ArrayList<TreeFeature> data; 060 061 private final boolean normalize; 062 063 private final ImmutableOutputInfo<Regressor> labelIDMap; 064 065 private final ImmutableFeatureMap featureIDMap; 066 067 private final RegressorImpurity impurity; 068 069 private final int[] indices; 070 071 private final float[][] targets; 072 073 private final float[] weights; 074 075 /** 076 * Constructor which creates the inverted file. 077 * @param impurity The impurity function to use. 078 * @param examples The training data. 079 * @param normalize Normalizes the leaves so each leaf has a distribution which sums to 1.0. 080 */ 081 public JointRegressorTrainingNode(RegressorImpurity impurity, Dataset<Regressor> examples, boolean normalize) { 082 this(impurity, invertData(examples), examples.size(), examples.getFeatureIDMap(), examples.getOutputIDInfo(), normalize); 083 } 084 085 private JointRegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int numExamples, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo, boolean normalize) { 086 this(impurity,tuple.data,tuple.indices,tuple.targets,tuple.weights,numExamples,0,featureIDMap,outputInfo,normalize); 087 } 088 089 private JointRegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, boolean normalize) { 090 super(depth, numExamples); 091 this.data = data; 092 this.normalize = normalize; 093 this.featureIDMap = featureIDMap; 094 this.labelIDMap = labelIDMap; 095 this.impurity = impurity; 096 this.indices = indices; 097 this.targets = targets; 098 this.weights = weights; 099 } 100 101 @Override 102 public double getImpurity() { 103 double tmp = 0.0; 104 for (int i = 0; i < targets.length; i++) { 105 tmp += impurity.impurity(indices, targets[i], weights); 106 } 107 return tmp / targets.length; 108 } 109 110 /** 111 * Builds a tree according to CART (as it does not do multi-way splits on categorical values like C4.5). 112 * @param featureIDs Indices of the features available in this split. 113 * @return A possibly empty list of TrainingNodes. 114 */ 115 @Override 116 public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs) { 117 int bestID = -1; 118 double bestSplitValue = 0.0; 119 double weightSum = Util.sum(indices,indices.length,weights); 120 double bestScore = getImpurity(); 121 //logger.info("Cur node score = " + bestScore); 122 List<int[]> curIndices = new ArrayList<>(); 123 List<int[]> bestLeftIndices = new ArrayList<>(); 124 List<int[]> bestRightIndices = new ArrayList<>(); 125 for (int i = 0; i < featureIDs.length; i++) { 126 List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature(); 127 128 curIndices.clear(); 129 for (int j = 0; j < feature.size(); j++) { 130 InvertedFeature f = feature.get(j); 131 int[] curFeatureIndices = f.indices(); 132 curIndices.add(curFeatureIndices); 133 } 134 135 // searching for the intervals between features. 136 for (int j = 0; j < feature.size()-1; j++) { 137 List<int[]> curLeftIndices = curIndices.subList(0,j+1); 138 List<int[]> curRightIndices = curIndices.subList(j+1,feature.size()); 139 double lessThanScore = 0.0; 140 double greaterThanScore = 0.0; 141 for (int k = 0; k < targets.length; k++) { 142 ImpurityTuple left = impurity.impurityTuple(curLeftIndices,targets[k],weights); 143 lessThanScore += left.impurity * left.weight; 144 ImpurityTuple right = impurity.impurityTuple(curRightIndices,targets[k],weights); 145 greaterThanScore += right.impurity * right.weight; 146 } 147 double score = (lessThanScore + greaterThanScore) / (targets.length * weightSum); 148 if (score < bestScore) { 149 bestID = i; 150 bestScore = score; 151 bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0; 152 // Clear out the old best indices before storing the new ones. 153 bestLeftIndices.clear(); 154 bestLeftIndices.addAll(curLeftIndices); 155 bestRightIndices.clear(); 156 bestRightIndices.addAll(curRightIndices); 157 //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score); 158 //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size); 159 } 160 } 161 } 162 List<AbstractTrainingNode<Regressor>> output; 163 // If we found a split better than the current impurity. 164 if (bestID != -1) { 165 splitID = featureIDs[bestID]; 166 split = true; 167 splitValue = bestSplitValue; 168 IntArrayContainer firstBuffer = mergeBufferOne.get(); 169 firstBuffer.size = 0; 170 firstBuffer.grow(indices.length); 171 IntArrayContainer secondBuffer = mergeBufferTwo.get(); 172 secondBuffer.size = 0; 173 secondBuffer.grow(indices.length); 174 int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer); 175 int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer); 176 //logger.info("Splitting on feature " + bestID + " with value " + bestSplitValue + " at depth " + depth + ", " + numExamples + " examples in node."); 177 //logger.info("left indices length = " + leftIndices.length); 178 ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size()); 179 ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size()); 180 for (TreeFeature feature : data) { 181 Pair<TreeFeature,TreeFeature> split = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer); 182 lessThanData.add(split.getA()); 183 greaterThanData.add(split.getB()); 184 } 185 lessThanOrEqual = new JointRegressorTrainingNode(impurity, lessThanData, leftIndices, targets, weights, leftIndices.length, depth + 1, featureIDMap, labelIDMap, normalize); 186 greaterThan = new JointRegressorTrainingNode(impurity, greaterThanData, rightIndices, targets, weights, rightIndices.length, depth + 1, featureIDMap, labelIDMap, normalize); 187 output = new ArrayList<>(); 188 output.add(lessThanOrEqual); 189 output.add(greaterThan); 190 } else { 191 output = Collections.emptyList(); 192 } 193 data = null; 194 return output; 195 } 196 197 /** 198 * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node. 199 * @return A subtree using the SplitNode and LeafNode classes. 200 */ 201 @Override 202 public Node<Regressor> convertTree() { 203 if (split) { 204 // split node 205 Node<Regressor> newGreaterThan = greaterThan.convertTree(); 206 Node<Regressor> newLessThan = lessThanOrEqual.convertTree(); 207 return new SplitNode<>(splitValue,splitID,getImpurity(),newGreaterThan,newLessThan); 208 } else { 209 double weightSum = 0.0; 210 double[] mean = new double[targets.length]; 211 Regressor leafPred; 212 if (normalize) { 213 for (int i = 0; i < indices.length; i++) { 214 int idx = indices[i]; 215 float weight = weights[idx]; 216 weightSum += weight; 217 for (int j = 0; j < targets.length; j++) { 218 float value = targets[j][idx]; 219 220 double oldMean = mean[j]; 221 mean[j] += (weight / weightSum) * (value - oldMean); 222 } 223 } 224 String[] names = new String[targets.length]; 225 double sum = 0.0; 226 for (int i = 0; i < targets.length; i++) { 227 names[i] = labelIDMap.getOutput(i).getNames()[0]; 228 sum += mean[i]; 229 } 230 // Normalize all the outputs so they sum to 1.0. 231 for (int i = 0; i < targets.length; i++) { 232 mean[i] /= sum; 233 } 234 leafPred = new Regressor(names, mean); 235 } else { 236 double[] variance = new double[targets.length]; 237 for (int i = 0; i < indices.length; i++) { 238 int idx = indices[i]; 239 float weight = weights[idx]; 240 weightSum += weight; 241 for (int j = 0; j < targets.length; j++) { 242 float value = targets[j][idx]; 243 244 double oldMean = mean[j]; 245 mean[j] += (weight / weightSum) * (value - oldMean); 246 variance[j] += weight * (value - oldMean) * (value - mean[j]); 247 } 248 } 249 String[] names = new String[targets.length]; 250 for (int i = 0; i < targets.length; i++) { 251 names[i] = labelIDMap.getOutput(i).getNames()[0]; 252 variance[i] = indices.length > 1 ? variance[i] / (weightSum - 1) : 0; 253 } 254 leafPred = new Regressor(names, mean, variance); 255 } 256 return new LeafNode<>(getImpurity(),leafPred,Collections.emptyMap(),false); 257 } 258 } 259 260 /** 261 * Inverts a training dataset from row major to column major. This partially de-sparsifies the dataset 262 * so it's very expensive in terms of memory. 263 * @param examples An input dataset. 264 * @return A list of TreeFeatures which contain {@link InvertedFeature}s. 265 */ 266 private static InvertedData invertData(Dataset<Regressor> examples) { 267 ImmutableFeatureMap featureInfos = examples.getFeatureIDMap(); 268 ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo(); 269 int numLabels = labelInfo.size(); 270 int numFeatures = featureInfos.size(); 271 int[] indices = new int[examples.size()]; 272 float[][] targets = new float[labelInfo.size()][examples.size()]; 273 float[] weights = new float[examples.size()]; 274 275 logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs"); 276 ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size()); 277 278 for (int i = 0; i < featureInfos.size(); i++) { 279 data.add(new TreeFeature(i)); 280 } 281 282 for (int i = 0; i < examples.size(); i++) { 283 Example<Regressor> e = examples.getExample(i); 284 indices[i] = i; 285 weights[i] = e.getWeight(); 286 double[] newTargets = e.getOutput().getValues(); 287 for (int j = 0; j < targets.length; j++) { 288 targets[j][i] = (float) newTargets[j]; 289 } 290 SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false); 291 int lastID = 0; 292 for (VectorTuple f : vec) { 293 int curID = f.index; 294 for (int j = lastID; j < curID; j++) { 295 data.get(j).observeValue(0.0,i); 296 } 297 data.get(curID).observeValue(f.value,i); 298 // 299 // These two checks should never occur as SparseVector deals with 300 // collisions, and Dataset prevents repeated features. 301 // They are left in just to make sure. 302 if (lastID > curID) { 303 logger.severe("Example = " + e.toString()); 304 throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID); 305 } else if (lastID-1 == curID) { 306 logger.severe("Example = " + e.toString()); 307 throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID); 308 } 309 lastID = curID + 1; 310 } 311 for (int j = lastID; j < numFeatures; j++) { 312 data.get(j).observeValue(0.0,i); 313 } 314 if (i % 1000 == 0) { 315 logger.fine("Processed example " + i); 316 } 317 } 318 319 logger.fine("Sorting features"); 320 321 data.forEach(TreeFeature::sort); 322 323 logger.fine("Fixing InvertedFeature sizes"); 324 325 data.forEach(TreeFeature::fixSize); 326 327 logger.fine("Built initial List<TreeFeature>"); 328 329 return new InvertedData(data,indices,targets,weights); 330 } 331 332 private static class InvertedData { 333 final ArrayList<TreeFeature> data; 334 final int[] indices; 335 final float[][] targets; 336 final float[] weights; 337 338 InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) { 339 this.data = data; 340 this.indices = indices; 341 this.targets = targets; 342 this.weights = weights; 343 } 344 } 345 346 private void writeObject(java.io.ObjectOutputStream stream) 347 throws IOException { 348 throw new NotSerializableException("JointRegressorTrainingNode is a runtime class only, and should not be serialized."); 349 } 350}