/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.rtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.regression.rtree.impurity.RegressorImpurity.ImpurityTuple;
import org.tribuo.util.Util;

import java.io.IOException;
import java.io.NotSerializableException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;

/**
 * A decision tree node used at training time.
 * Contains a list of the example indices currently found in this node,
 * the current impurity, and various other statistics.
 */
public class JointRegressorTrainingNode extends AbstractTrainingNode<Regressor> {
    private static final long serialVersionUID = 1L;

    private static final Logger logger = Logger.getLogger(JointRegressorTrainingNode.class.getName());

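    // Thread-local scratch buffers, reused across all the nodes built on a given thread,
    // so merging and splitting example index lists doesn't allocate fresh space per split.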
    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferThree = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferFour = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferFive = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));

    private transient ArrayList<TreeFeature> data;

    private final boolean normalize;

    private final ImmutableOutputInfo<Regressor> labelIDMap;

    private final ImmutableFeatureMap featureIDMap;

    private final RegressorImpurity impurity;

    private final int[] indices;

    private final float[][] targets;

    private final float[] weights;

    /**
     * Constructor which creates the inverted file.
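     * <p>
     * A minimal usage sketch (hedged: {@code trainData} is a hypothetical
     * {@code Dataset<Regressor>}, and {@code MeanSquaredError} is assumed as the
     * impurity; real trainers recurse over the returned children rather than
     * splitting only once):
     * <pre>{@code
     * JointRegressorTrainingNode root = new JointRegressorTrainingNode(
     *         new MeanSquaredError(), trainData, false);
     * int[] featureIDs = new int[trainData.getFeatureIDMap().size()];
     * for (int i = 0; i < featureIDs.length; i++) {
     *     featureIDs[i] = i; // consider every feature at this split
     * }
     * List<AbstractTrainingNode<Regressor>> children = root.buildTree(featureIDs);
     * Node<Regressor> tree = root.convertTree();
     * }</pre>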
     * @param impurity The impurity function to use.
     * @param examples The training data.
     * @param normalize Normalizes the leaves so each leaf has a distribution which sums to 1.0.
     */
    public JointRegressorTrainingNode(RegressorImpurity impurity, Dataset<Regressor> examples, boolean normalize) {
        this(impurity, invertData(examples), examples.size(), examples.getFeatureIDMap(), examples.getOutputIDInfo(), normalize);
    }

    private JointRegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int numExamples, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo, boolean normalize) {
        this(impurity,tuple.data,tuple.indices,tuple.targets,tuple.weights,numExamples,0,featureIDMap,outputInfo,normalize);
    }

    private JointRegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, boolean normalize) {
        super(depth, numExamples);
        this.data = data;
        this.normalize = normalize;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        this.indices = indices;
        this.targets = targets;
        this.weights = weights;
    }

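    /**
     * Calculates the impurity of this node, averaged across all the output dimensions.
     */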
    @Override
    public double getImpurity() {
        double tmp = 0.0;
        for (int i = 0; i < targets.length; i++) {
            tmp += impurity.impurity(indices, targets[i], weights);
        }
        return tmp / targets.length;
    }

    /**
     * Builds a tree according to CART, i.e., it does not perform the multi-way splits
     * on categorical values which C4.5 does.
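     * <p>
     * Candidate split points are the midpoints between adjacent observed values of a
     * feature (e.g., observed values 1.0 and 3.0 yield the candidate split 2.0); the
     * candidate which most reduces the mean impurity across the output dimensions is kept.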
     * @param featureIDs Indices of the features available in this split.
     * @return A possibly empty list of TrainingNodes.
     */
    @Override
    public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs) {
        int bestID = -1;
        double bestSplitValue = 0.0;
        double weightSum = Util.sum(indices,indices.length,weights);
        double bestScore = getImpurity();
        //logger.info("Cur node score = " + bestScore);
        List<int[]> curIndices = new ArrayList<>();
        List<int[]> bestLeftIndices = new ArrayList<>();
        List<int[]> bestRightIndices = new ArrayList<>();
        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();

            curIndices.clear();
            for (int j = 0; j < feature.size(); j++) {
                InvertedFeature f = feature.get(j);
                int[] curFeatureIndices = f.indices();
                curIndices.add(curFeatureIndices);
            }

            // Search over the candidate split points between adjacent feature values.
            for (int j = 0; j < feature.size()-1; j++) {
                List<int[]> curLeftIndices = curIndices.subList(0,j+1);
                List<int[]> curRightIndices = curIndices.subList(j+1,feature.size());
                double lessThanScore = 0.0;
                double greaterThanScore = 0.0;
                for (int k = 0; k < targets.length; k++) {
                    ImpurityTuple left = impurity.impurityTuple(curLeftIndices,targets[k],weights);
                    lessThanScore += left.impurity * left.weight;
                    ImpurityTuple right = impurity.impurityTuple(curRightIndices,targets[k],weights);
                    greaterThanScore += right.impurity * right.weight;
                }
                double score = (lessThanScore + greaterThanScore) / (targets.length * weightSum);
                if (score < bestScore) {
                    bestID = i;
                    bestScore = score;
                    bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0;
                    // Clear out the old best indices before storing the new ones.
                    bestLeftIndices.clear();
                    bestLeftIndices.addAll(curLeftIndices);
                    bestRightIndices.clear();
                    bestRightIndices.addAll(curRightIndices);
                    //logger.info("id = " + featureIDs[i] + ", split = " + bestSplitValue + ", score = " + score);
                    //logger.info("less score = " +lessThanScore+", less size = "+lessThanIndices.size+", greater score = " + greaterThanScore+", greater size = "+greaterThanIndices.size);
                }
            }
        }
        List<AbstractTrainingNode<Regressor>> output;
        // If we found a split better than the current impurity.
        if (bestID != -1) {
            splitID = featureIDs[bestID];
            split = true;
            splitValue = bestSplitValue;
            IntArrayContainer firstBuffer = mergeBufferOne.get();
            firstBuffer.size = 0;
            firstBuffer.grow(indices.length);
            IntArrayContainer secondBuffer = mergeBufferTwo.get();
            secondBuffer.size = 0;
            secondBuffer.grow(indices.length);
            int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer);
            int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer);
            //logger.info("Splitting on feature " + bestID + " with value " + bestSplitValue + " at depth " + depth + ", " + numExamples + " examples in node.");
            //logger.info("left indices length = " + leftIndices.length);
            ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size());
            ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size());
            for (TreeFeature feature : data) {
                Pair<TreeFeature,TreeFeature> split = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer);
                lessThanData.add(split.getA());
                greaterThanData.add(split.getB());
            }
            lessThanOrEqual = new JointRegressorTrainingNode(impurity, lessThanData, leftIndices, targets, weights, leftIndices.length, depth + 1, featureIDMap, labelIDMap, normalize);
            greaterThan = new JointRegressorTrainingNode(impurity, greaterThanData, rightIndices, targets, weights, rightIndices.length, depth + 1, featureIDMap, labelIDMap, normalize);
            output = new ArrayList<>();
            output.add(lessThanOrEqual);
            output.add(greaterThan);
        } else {
            output = Collections.emptyList();
        }
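        // The inverted data has either been split between the children or is no longer
        // needed in what will become a leaf, so release it for garbage collection.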
        data = null;
        return output;
    }

    /**
     * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node.
     * @return A subtree using the SplitNode and LeafNode classes.
     */
    @Override
    public Node<Regressor> convertTree() {
        if (split) {
            // split node
            Node<Regressor> newGreaterThan = greaterThan.convertTree();
            Node<Regressor> newLessThan = lessThanOrEqual.convertTree();
            return new SplitNode<>(splitValue,splitID,getImpurity(),newGreaterThan,newLessThan);
        } else {
            double weightSum = 0.0;
            double[] mean = new double[targets.length];
            Regressor leafPred;
            if (normalize) {
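                // Weighted incremental mean update: mean += (w / cumulativeWeight) * (x - oldMean),
                // which avoids accumulating a large sum before dividing.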
                for (int i = 0; i < indices.length; i++) {
                    int idx = indices[i];
                    float weight = weights[idx];
                    weightSum += weight;
                    for (int j = 0; j < targets.length; j++) {
                        float value = targets[j][idx];

                        double oldMean = mean[j];
                        mean[j] += (weight / weightSum) * (value - oldMean);
                    }
                }
                String[] names = new String[targets.length];
                double sum = 0.0;
                for (int i = 0; i < targets.length; i++) {
                    names[i] = labelIDMap.getOutput(i).getNames()[0];
                    sum += mean[i];
                }
                // Normalize all the outputs so they sum to 1.0.
                for (int i = 0; i < targets.length; i++) {
                    mean[i] /= sum;
                }
                leafPred = new Regressor(names, mean);
            } else {
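                // As above, but additionally accumulate the weighted sum of squared deviations
                // via a Welford-style update, so the leaf can report a per-dimension variance.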
                double[] variance = new double[targets.length];
                for (int i = 0; i < indices.length; i++) {
                    int idx = indices[i];
                    float weight = weights[idx];
                    weightSum += weight;
                    for (int j = 0; j < targets.length; j++) {
                        float value = targets[j][idx];

                        double oldMean = mean[j];
                        mean[j] += (weight / weightSum) * (value - oldMean);
                        variance[j] += weight * (value - oldMean) * (value - mean[j]);
                    }
                }
                String[] names = new String[targets.length];
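                // Convert the accumulated squared deviations into a weighted sample variance
                // (n-1 style denominator); single element leaves get zero variance.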
                for (int i = 0; i < targets.length; i++) {
                    names[i] = labelIDMap.getOutput(i).getNames()[0];
                    variance[i] = indices.length > 1 ? variance[i] / (weightSum - 1) : 0;
                }
                leafPred = new Regressor(names, mean, variance);
            }
            return new LeafNode<>(getImpurity(),leafPred,Collections.emptyMap(),false);
        }
    }

    /**
     * Inverts a training dataset from row major to column major. This partially de-sparsifies
     * the dataset, so it is very expensive in terms of memory.
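     * <p>
     * For example, two sparse rows {@code [f0=1.0]} and {@code [f1=2.0]} invert to the
     * per-feature lists {@code f0 -> [(1.0, [0]), (0.0, [1])]} and
     * {@code f1 -> [(0.0, [0]), (2.0, [1])]}; materialising the implicit zeros in this
     * way is what costs the memory.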
     * @param examples An input dataset.
     * @return A list of TreeFeatures which contain {@link InvertedFeature}s.
     */
    private static InvertedData invertData(Dataset<Regressor> examples) {
        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo();
        int numLabels = labelInfo.size();
        int numFeatures = featureInfos.size();
        int[] indices = new int[examples.size()];
        float[][] targets = new float[labelInfo.size()][examples.size()];
        float[] weights = new float[examples.size()];

        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs");
        ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size());

        for (int i = 0; i < featureInfos.size(); i++) {
            data.add(new TreeFeature(i));
        }

        for (int i = 0; i < examples.size(); i++) {
            Example<Regressor> e = examples.getExample(i);
            indices[i] = i;
            weights[i] = e.getWeight();
            double[] newTargets = e.getOutput().getValues();
            for (int j = 0; j < targets.length; j++) {
                targets[j][i] = (float) newTargets[j];
            }
            SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false);
            int lastID = 0;
            for (VectorTuple f : vec) {
                int curID = f.index;
                for (int j = lastID; j < curID; j++) {
                    data.get(j).observeValue(0.0,i);
                }
                data.get(curID).observeValue(f.value,i);
                //
                // These two checks should never trigger as SparseVector deals with
                // collisions, and Dataset prevents repeated features.
                // They are left in just to make sure.
                if (lastID > curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                } else if (lastID-1 == curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                }
                lastID = curID + 1;
            }
            for (int j = lastID; j < numFeatures; j++) {
                data.get(j).observeValue(0.0,i);
            }
            if (i % 1000 == 0) {
                logger.fine("Processed example " + i);
            }
        }

        logger.fine("Sorting features");

        data.forEach(TreeFeature::sort);

        logger.fine("Fixing InvertedFeature sizes");

        data.forEach(TreeFeature::fixSize);

        logger.fine("Built initial List<TreeFeature>");

        return new InvertedData(data,indices,targets,weights);
    }

    private static class InvertedData {
        final ArrayList<TreeFeature> data;
        final int[] indices;
        final float[][] targets;
        final float[] weights;

        InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) {
            this.data = data;
            this.indices = indices;
            this.targets = targets;
            this.weights = weights;
        }
    }

    private void writeObject(java.io.ObjectOutputStream stream)
            throws IOException {
        throw new NotSerializableException("JointRegressorTrainingNode is a runtime class only, and should not be serialized.");
    }
}