001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.rtree.impl; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Dataset; 021import org.tribuo.Example; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.common.tree.AbstractTrainingNode; 025import org.tribuo.common.tree.LeafNode; 026import org.tribuo.common.tree.Node; 027import org.tribuo.common.tree.SplitNode; 028import org.tribuo.common.tree.impl.IntArrayContainer; 029import org.tribuo.math.la.SparseVector; 030import org.tribuo.math.la.VectorTuple; 031import org.tribuo.regression.Regressor; 032import org.tribuo.regression.Regressor.DimensionTuple; 033import org.tribuo.regression.rtree.impurity.RegressorImpurity; 034import org.tribuo.regression.rtree.impurity.RegressorImpurity.ImpurityTuple; 035import org.tribuo.util.Util; 036 037import java.io.IOException; 038import java.io.NotSerializableException; 039import java.util.ArrayList; 040import java.util.Collections; 041import java.util.List; 042import java.util.logging.Logger; 043 044/** 045 * A decision tree node used at training time. 046 * Contains a list of the example indices currently found in this node, 047 * the current impurity and a bunch of other statistics. 048 */ 049public class RegressorTrainingNode extends AbstractTrainingNode<Regressor> { 050 private static final long serialVersionUID = 1L; 051 052 private static final Logger logger = Logger.getLogger(RegressorTrainingNode.class.getName()); 053 054 private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 055 private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 056 private static final ThreadLocal<IntArrayContainer> mergeBufferThree = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 057 private static final ThreadLocal<IntArrayContainer> mergeBufferFour = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 058 private static final ThreadLocal<IntArrayContainer> mergeBufferFive = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE)); 059 060 private transient ArrayList<TreeFeature> data; 061 062 private final ImmutableOutputInfo<Regressor> labelIDMap; 063 064 private final ImmutableFeatureMap featureIDMap; 065 066 private final RegressorImpurity impurity; 067 068 private final int[] indices; 069 070 private final float[] targets; 071 072 private final float[] weights; 073 074 private final String dimName; 075 076 public RegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int dimIndex, String dimName, int numExamples, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo) { 077 this(impurity,tuple.copyData(),tuple.indices,tuple.targets[dimIndex],tuple.weights,dimName,numExamples,0,featureIDMap,outputInfo); 078 } 079 080 private RegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, float[] targets, float[] weights, String dimName, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap) { 081 super(depth, numExamples); 082 this.data = data; 083 this.featureIDMap = featureIDMap; 084 this.labelIDMap = labelIDMap; 085 this.impurity = impurity; 086 this.indices = indices; 087 this.targets = targets; 088 this.weights = weights; 089 this.dimName = dimName; 090 } 091 092 @Override 093 public double getImpurity() { 094 return impurity.impurity(indices, targets, weights); 095 } 096 097 /** 098 * Builds a tree according to CART (as it does not do multi-way splits on categorical values like C4.5). 099 * @param featureIDs Indices of the features available in this split. 100 * @return A possibly empty list of TrainingNodes. 101 */ 102 @Override 103 public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs) { 104 int bestID = -1; 105 double bestSplitValue = 0.0; 106 double weightSum = Util.sum(indices,indices.length,weights); 107 double bestScore = getImpurity(); 108 //logger.info("Cur node score = " + bestScore); 109 List<int[]> curIndices = new ArrayList<>(); 110 List<int[]> bestLeftIndices = new ArrayList<>(); 111 List<int[]> bestRightIndices = new ArrayList<>(); 112 for (int i = 0; i < featureIDs.length; i++) { 113 List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature(); 114 115 curIndices.clear(); 116 for (int j = 0; j < feature.size(); j++) { 117 InvertedFeature f = feature.get(j); 118 int[] curFeatureIndices = f.indices(); 119 curIndices.add(curFeatureIndices); 120 } 121 122 // searching for the intervals between features. 123 for (int j = 0; j < feature.size()-1; j++) { 124 List<int[]> curLeftIndices = curIndices.subList(0,j+1); 125 List<int[]> curRightIndices = curIndices.subList(j+1,feature.size()); 126 ImpurityTuple lessThanScore = impurity.impurityTuple(curLeftIndices,targets,weights); 127 ImpurityTuple greaterThanScore = impurity.impurityTuple(curRightIndices,targets,weights); 128 double score = (lessThanScore.impurity*lessThanScore.weight + greaterThanScore.impurity*greaterThanScore.weight) / weightSum; 129 if (score < bestScore) { 130 bestID = i; 131 bestScore = score; 132 bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0; 133 // Clear out the old best indices before storing the new ones. 134 bestLeftIndices.clear(); 135 bestLeftIndices.addAll(curLeftIndices); 136 bestRightIndices.clear(); 137 bestRightIndices.addAll(curRightIndices); 138 //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score); 139 //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size); 140 } 141 } 142 } 143 List<AbstractTrainingNode<Regressor>> output; 144 // If we found a split better than the current impurity. 145 if (bestID != -1) { 146 splitID = featureIDs[bestID]; 147 split = true; 148 splitValue = bestSplitValue; 149 IntArrayContainer firstBuffer = mergeBufferOne.get(); 150 firstBuffer.size = 0; 151 firstBuffer.grow(indices.length); 152 IntArrayContainer secondBuffer = mergeBufferTwo.get(); 153 secondBuffer.size = 0; 154 secondBuffer.grow(indices.length); 155 int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer); 156 int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer); 157 //logger.info("Splitting on feature " + bestID + " with value " + bestSplitValue + " at depth " + depth + ", " + numExamples + " examples in node."); 158 //logger.info("left indices length = " + leftIndices.length); 159 ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size()); 160 ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size()); 161 for (TreeFeature feature : data) { 162 Pair<TreeFeature,TreeFeature> split = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer); 163 lessThanData.add(split.getA()); 164 greaterThanData.add(split.getB()); 165 } 166 lessThanOrEqual = new RegressorTrainingNode(impurity, lessThanData, leftIndices, targets, weights, dimName, leftIndices.length, depth + 1, featureIDMap, labelIDMap); 167 greaterThan = new RegressorTrainingNode(impurity, greaterThanData, rightIndices, targets, weights, dimName, rightIndices.length, depth + 1, featureIDMap, labelIDMap); 168 output = new ArrayList<>(); 169 output.add(lessThanOrEqual); 170 output.add(greaterThan); 171 } else { 172 output = Collections.emptyList(); 173 } 174 data = null; 175 return output; 176 } 177 178 /** 179 * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node. 180 * @return A subtree using the SplitNode and LeafNode classes. 181 */ 182 @Override 183 public Node<Regressor> convertTree() { 184 if (split) { 185 // split node 186 Node<Regressor> newGreaterThan = greaterThan.convertTree(); 187 Node<Regressor> newLessThan = lessThanOrEqual.convertTree(); 188 return new SplitNode<>(splitValue,splitID,getImpurity(),newGreaterThan,newLessThan); 189 } else { 190 double mean = 0.0; 191 double weightSum = 0.0; 192 double variance = 0.0; 193 for (int i = 0; i < indices.length; i++) { 194 int idx = indices[i]; 195 float value = targets[idx]; 196 float weight = weights[idx]; 197 198 weightSum += weight; 199 double oldMean = mean; 200 mean += (weight / weightSum) * (value - oldMean); 201 variance += weight * (value - oldMean) * (value - mean); 202 } 203 variance = indices.length > 1 ? variance / (weightSum-1) : 0; 204 DimensionTuple leafPred = new DimensionTuple(dimName,mean,variance); 205 return new LeafNode<>(getImpurity(),leafPred,Collections.emptyMap(),false); 206 } 207 } 208 209 /** 210 * Inverts a training dataset from row major to column major. This partially de-sparsifies the dataset 211 * so it's very expensive in terms of memory. 212 * @param examples An input dataset. 213 * @return A list of TreeFeatures which contain {@link InvertedFeature}s. 214 */ 215 public static InvertedData invertData(Dataset<Regressor> examples) { 216 ImmutableFeatureMap featureInfos = examples.getFeatureIDMap(); 217 ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo(); 218 int numLabels = labelInfo.size(); 219 int numFeatures = featureInfos.size(); 220 int[] indices = new int[examples.size()]; 221 float[][] targets = new float[numLabels][examples.size()]; 222 float[] weights = new float[examples.size()]; 223 224 logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs"); 225 ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size()); 226 227 for (int i = 0; i < featureInfos.size(); i++) { 228 data.add(new TreeFeature(i)); 229 } 230 231 for (int i = 0; i < examples.size(); i++) { 232 Example<Regressor> e = examples.getExample(i); 233 indices[i] = i; 234 weights[i] = e.getWeight(); 235 double[] output = e.getOutput().getValues(); 236 for (int j = 0; j < output.length; j++) { 237 targets[j][i] = (float) output[j]; 238 } 239 SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false); 240 int lastID = 0; 241 for (VectorTuple f : vec) { 242 int curID = f.index; 243 for (int j = lastID; j < curID; j++) { 244 data.get(j).observeValue(0.0,i); 245 } 246 data.get(curID).observeValue(f.value,i); 247 // 248 // These two checks should never occur as SparseVector deals with 249 // collisions, and Dataset prevents repeated features. 250 // They are left in just to make sure. 251 if (lastID > curID) { 252 logger.severe("Example = " + e.toString()); 253 throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID); 254 } else if (lastID-1 == curID) { 255 logger.severe("Example = " + e.toString()); 256 throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID); 257 } 258 lastID = curID + 1; 259 } 260 for (int j = lastID; j < numFeatures; j++) { 261 data.get(j).observeValue(0.0,i); 262 } 263 if (i % 1000 == 0) { 264 logger.fine("Processed example " + i); 265 } 266 } 267 268 logger.fine("Sorting features"); 269 270 data.forEach(TreeFeature::sort); 271 272 logger.fine("Fixing InvertedFeature sizes"); 273 274 data.forEach(TreeFeature::fixSize); 275 276 logger.fine("Built initial List<TreeFeature>"); 277 278 return new InvertedData(data,indices,targets,weights); 279 } 280 281 /** 282 * Tuple containing an inverted dataset (i.e., feature-wise not exmaple-wise). 283 */ 284 public static class InvertedData { 285 final ArrayList<TreeFeature> data; 286 final int[] indices; 287 final float[][] targets; 288 final float[] weights; 289 290 InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) { 291 this.data = data; 292 this.indices = indices; 293 this.targets = targets; 294 this.weights = weights; 295 } 296 297 ArrayList<TreeFeature> copyData() { 298 ArrayList<TreeFeature> newData = new ArrayList<>(); 299 300 for (TreeFeature f : data) { 301 newData.add(f.deepCopy()); 302 } 303 304 return newData; 305 } 306 } 307 308 private void writeObject(java.io.ObjectOutputStream stream) 309 throws IOException { 310 throw new NotSerializableException("RegressorTrainingNode is a runtime class only, and should not be serialized."); 311 } 312}