001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree.impl;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Dataset;
021import org.tribuo.Example;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.common.tree.AbstractTrainingNode;
025import org.tribuo.common.tree.LeafNode;
026import org.tribuo.common.tree.Node;
027import org.tribuo.common.tree.SplitNode;
028import org.tribuo.common.tree.impl.IntArrayContainer;
029import org.tribuo.math.la.SparseVector;
030import org.tribuo.math.la.VectorTuple;
031import org.tribuo.regression.Regressor;
032import org.tribuo.regression.Regressor.DimensionTuple;
033import org.tribuo.regression.rtree.impurity.RegressorImpurity;
034import org.tribuo.regression.rtree.impurity.RegressorImpurity.ImpurityTuple;
035import org.tribuo.util.Util;
036
037import java.io.IOException;
038import java.io.NotSerializableException;
039import java.util.ArrayList;
040import java.util.Collections;
041import java.util.List;
042import java.util.logging.Logger;
043
044/**
045 * A decision tree node used at training time.
046 * Contains a list of the example indices currently found in this node,
047 * the current impurity and a bunch of other statistics.
048 */
049public class RegressorTrainingNode extends AbstractTrainingNode<Regressor> {
050    private static final long serialVersionUID = 1L;
051
052    private static final Logger logger = Logger.getLogger(RegressorTrainingNode.class.getName());
053
054    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
055    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
056    private static final ThreadLocal<IntArrayContainer> mergeBufferThree = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
057    private static final ThreadLocal<IntArrayContainer> mergeBufferFour = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
058    private static final ThreadLocal<IntArrayContainer> mergeBufferFive = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
059
060    private transient ArrayList<TreeFeature> data;
061
062    private final ImmutableOutputInfo<Regressor> labelIDMap;
063
064    private final ImmutableFeatureMap featureIDMap;
065
066    private final RegressorImpurity impurity;
067
068    private final int[] indices;
069
070    private final float[] targets;
071
072    private final float[] weights;
073
074    private final String dimName;
075
076    public RegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int dimIndex, String dimName, int numExamples, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo) {
077        this(impurity,tuple.copyData(),tuple.indices,tuple.targets[dimIndex],tuple.weights,dimName,numExamples,0,featureIDMap,outputInfo);
078    }
079
080    private RegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, float[] targets, float[] weights, String dimName, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap) {
081        super(depth, numExamples);
082        this.data = data;
083        this.featureIDMap = featureIDMap;
084        this.labelIDMap = labelIDMap;
085        this.impurity = impurity;
086        this.indices = indices;
087        this.targets = targets;
088        this.weights = weights;
089        this.dimName = dimName;
090    }
091
092    @Override
093    public double getImpurity() {
094        return impurity.impurity(indices, targets, weights);
095    }
096
097    /**
098     * Builds a tree according to CART (as it does not do multi-way splits on categorical values like C4.5).
099     * @param featureIDs Indices of the features available in this split.
100     * @return A possibly empty list of TrainingNodes.
101     */
102    @Override
103    public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs) {
104        int bestID = -1;
105        double bestSplitValue = 0.0;
106        double weightSum = Util.sum(indices,indices.length,weights);
107        double bestScore = getImpurity();
108        //logger.info("Cur node score = " + bestScore);
109        List<int[]> curIndices = new ArrayList<>();
110        List<int[]> bestLeftIndices = new ArrayList<>();
111        List<int[]> bestRightIndices = new ArrayList<>();
112        for (int i = 0; i < featureIDs.length; i++) {
113            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
114
115            curIndices.clear();
116            for (int j = 0; j < feature.size(); j++) {
117                InvertedFeature f = feature.get(j);
118                int[] curFeatureIndices = f.indices();
119                curIndices.add(curFeatureIndices);
120            }
121
122            // searching for the intervals between features.
123            for (int j = 0; j < feature.size()-1; j++) {
124                List<int[]> curLeftIndices = curIndices.subList(0,j+1);
125                List<int[]> curRightIndices = curIndices.subList(j+1,feature.size());
126                ImpurityTuple lessThanScore = impurity.impurityTuple(curLeftIndices,targets,weights);
127                ImpurityTuple greaterThanScore = impurity.impurityTuple(curRightIndices,targets,weights);
128                double score = (lessThanScore.impurity*lessThanScore.weight + greaterThanScore.impurity*greaterThanScore.weight) / weightSum;
129                if (score < bestScore) {
130                    bestID = i;
131                    bestScore = score;
132                    bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0;
133                    // Clear out the old best indices before storing the new ones.
134                    bestLeftIndices.clear();
135                    bestLeftIndices.addAll(curLeftIndices);
136                    bestRightIndices.clear();
137                    bestRightIndices.addAll(curRightIndices);
138                    //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score);
139                    //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size);
140                }
141            }
142        }
143        List<AbstractTrainingNode<Regressor>> output;
144        // If we found a split better than the current impurity.
145        if (bestID != -1) {
146            splitID = featureIDs[bestID];
147            split = true;
148            splitValue = bestSplitValue;
149            IntArrayContainer firstBuffer = mergeBufferOne.get();
150            firstBuffer.size = 0;
151            firstBuffer.grow(indices.length);
152            IntArrayContainer secondBuffer = mergeBufferTwo.get();
153            secondBuffer.size = 0;
154            secondBuffer.grow(indices.length);
155            int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer);
156            int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer);
157            //logger.info("Splitting on feature " + bestID + " with value " + bestSplitValue + " at depth " + depth + ", " + numExamples + " examples in node.");
158            //logger.info("left indices length = " + leftIndices.length);
159            ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size());
160            ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size());
161            for (TreeFeature feature : data) {
162                Pair<TreeFeature,TreeFeature> split = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer);
163                lessThanData.add(split.getA());
164                greaterThanData.add(split.getB());
165            }
166            lessThanOrEqual = new RegressorTrainingNode(impurity, lessThanData, leftIndices, targets, weights, dimName, leftIndices.length, depth + 1, featureIDMap, labelIDMap);
167            greaterThan = new RegressorTrainingNode(impurity, greaterThanData, rightIndices, targets, weights, dimName, rightIndices.length, depth + 1, featureIDMap, labelIDMap);
168            output = new ArrayList<>();
169            output.add(lessThanOrEqual);
170            output.add(greaterThan);
171        } else {
172            output = Collections.emptyList();
173        }
174        data = null;
175        return output;
176    }
177
178    /**
179     * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node.
180     * @return A subtree using the SplitNode and LeafNode classes.
181     */
182    @Override
183    public Node<Regressor> convertTree() {
184        if (split) {
185            // split node
186            Node<Regressor> newGreaterThan = greaterThan.convertTree();
187            Node<Regressor> newLessThan = lessThanOrEqual.convertTree();
188            return new SplitNode<>(splitValue,splitID,getImpurity(),newGreaterThan,newLessThan);
189        } else {
190            double mean = 0.0;
191            double weightSum = 0.0;
192            double variance = 0.0;
193            for (int i = 0; i < indices.length; i++) {
194                int idx = indices[i];
195                float value = targets[idx];
196                float weight = weights[idx];
197
198                weightSum += weight;
199                double oldMean = mean;
200                mean += (weight / weightSum) * (value - oldMean);
201                variance += weight * (value - oldMean) * (value - mean);
202            }
203            variance = indices.length > 1 ? variance / (weightSum-1) : 0;
204            DimensionTuple leafPred = new DimensionTuple(dimName,mean,variance);
205            return new LeafNode<>(getImpurity(),leafPred,Collections.emptyMap(),false);
206        }
207    }
208
209    /**
210     * Inverts a training dataset from row major to column major. This partially de-sparsifies the dataset
211     * so it's very expensive in terms of memory.
212     * @param examples An input dataset.
213     * @return A list of TreeFeatures which contain {@link InvertedFeature}s.
214     */
215    public static InvertedData invertData(Dataset<Regressor> examples) {
216        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
217        ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo();
218        int numLabels = labelInfo.size();
219        int numFeatures = featureInfos.size();
220        int[] indices = new int[examples.size()];
221        float[][] targets = new float[numLabels][examples.size()];
222        float[] weights = new float[examples.size()];
223
224        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs");
225        ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size());
226
227        for (int i = 0; i < featureInfos.size(); i++) {
228            data.add(new TreeFeature(i));
229        }
230
231        for (int i = 0; i < examples.size(); i++) {
232            Example<Regressor> e = examples.getExample(i);
233            indices[i] = i;
234            weights[i] = e.getWeight();
235            double[] output = e.getOutput().getValues();
236            for (int j = 0; j < output.length; j++) {
237                targets[j][i] = (float) output[j];
238            }
239            SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false);
240            int lastID = 0;
241            for (VectorTuple f : vec) {
242                int curID = f.index;
243                for (int j = lastID; j < curID; j++) {
244                    data.get(j).observeValue(0.0,i);
245                }
246                data.get(curID).observeValue(f.value,i);
247                //
248                // These two checks should never occur as SparseVector deals with
249                // collisions, and Dataset prevents repeated features.
250                // They are left in just to make sure.
251                if (lastID > curID) {
252                    logger.severe("Example = " + e.toString());
253                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
254                } else if (lastID-1 == curID) {
255                    logger.severe("Example = " + e.toString());
256                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
257                }
258                lastID = curID + 1;
259            }
260            for (int j = lastID; j < numFeatures; j++) {
261                data.get(j).observeValue(0.0,i);
262            }
263            if (i % 1000 == 0) {
264                logger.fine("Processed example " + i);
265            }
266        }
267
268        logger.fine("Sorting features");
269
270        data.forEach(TreeFeature::sort);
271
272        logger.fine("Fixing InvertedFeature sizes");
273
274        data.forEach(TreeFeature::fixSize);
275
276        logger.fine("Built initial List<TreeFeature>");
277
278        return new InvertedData(data,indices,targets,weights);
279    }
280
281    /**
282     * Tuple containing an inverted dataset (i.e., feature-wise not exmaple-wise).
283     */
284    public static class InvertedData {
285        final ArrayList<TreeFeature> data;
286        final int[] indices;
287        final float[][] targets;
288        final float[] weights;
289
290        InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) {
291            this.data = data;
292            this.indices = indices;
293            this.targets = targets;
294            this.weights = weights;
295        }
296
297        ArrayList<TreeFeature> copyData() {
298            ArrayList<TreeFeature> newData = new ArrayList<>();
299
300            for (TreeFeature f : data) {
301                newData.add(f.deepCopy());
302            }
303
304            return newData;
305        }
306    }
307
308    private void writeObject(java.io.ObjectOutputStream stream)
309            throws IOException {
310        throw new NotSerializableException("RegressorTrainingNode is a runtime class only, and should not be serialized.");
311    }
312}