/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.dtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.impurity.LabelImpurity;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.util.Util;

import java.io.IOException;
import java.io.NotSerializableException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * A decision tree node used at training time.
 * Contains a list of the example indices currently found in this node,
 * the current impurity, and other statistics such as the per-label counts.
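 *
 * <p>A minimal usage sketch, assuming a pre-built {@code Dataset<Label>} called {@code dataset}
 * and the {@code GiniIndex} impurity (the names and the hand-rolled growth loop are illustrative;
 * trees are normally grown by a CART trainer rather than driven by hand like this):
 * <pre>{@code
 * LabelImpurity impurity = new GiniIndex();
 * ClassifierTrainingNode root = new ClassifierTrainingNode(impurity, dataset);
 * int[] featureIDs = new int[dataset.getFeatureIDMap().size()];
 * for (int i = 0; i < featureIDs.length; i++) {
 *     featureIDs[i] = i;
 * }
 * // Grow the tree breadth-first; buildTree returns the new children (or an empty list).
 * Deque<AbstractTrainingNode<Label>> queue = new ArrayDeque<>();
 * queue.add(root);
 * while (!queue.isEmpty()) {
 *     queue.addAll(queue.poll().buildTree(featureIDs));
 * }
 * Node<Label> tree = root.convertTree(); // prediction-time tree
 * }</pre>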
 */
public class ClassifierTrainingNode extends AbstractTrainingNode<Label> {
    private static final long serialVersionUID = 1L;

    private static final Logger logger = Logger.getLogger(ClassifierTrainingNode.class.getName());

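    // Per-thread scratch buffers for the index merges performed in buildTree, reused across
    // calls to avoid repeated allocation. Thread locals keep concurrent tree builds
    // (e.g., in a forest trainer) from contending on shared state.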
    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferThree = ThreadLocal.withInitial(() -> new IntArrayContainer(DEFAULT_SIZE));

    private transient ArrayList<TreeFeature> data;

    private final ImmutableOutputInfo<Label> labelIDMap;

    private final ImmutableFeatureMap featureIDMap;

    private final LabelImpurity impurity;

    private final float[] labelCounts;

    /**
     * Constructor which creates the inverted file.
     * @param impurity The impurity function to use.
     * @param examples The training data.
     */
    public ClassifierTrainingNode(LabelImpurity impurity, Dataset<Label> examples) {
        this(impurity, invertData(examples), examples.size(), 0, examples.getFeatureIDMap(), examples.getOutputIDInfo());
    }

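    /**
     * Constructs a training node over a pre-inverted dataset; used internally when splitting.
     * @param impurity The impurity function to use.
     * @param data The inverted (column major) training data.
     * @param numExamples The number of examples contained in this node.
     * @param depth The depth of this node in the tree.
     * @param featureIDMap The feature id map.
     * @param labelIDMap The label id info.
     */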
    private ClassifierTrainingNode(LabelImpurity impurity, ArrayList<TreeFeature> data, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
        super(depth, numExamples);
        this.data = data;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        this.labelCounts = data.get(0).getLabelCounts();
    }

    /**
     * Builds a tree according to CART: it makes binary splits on numeric feature values,
     * rather than the multi-way splits on categorical values that C4.5 makes.
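     *
     * <p>A worked example of the split score (the numbers are illustrative): with Gini impurity
     * and parent label counts {@code [4,4]} (impurity 0.5), a candidate split yielding partitions
     * {@code [3,1]} and {@code [1,3]} scores {@code (4*0.375 + 4*0.375) / 8 = 0.375}, where each
     * term is the partition's weight times its impurity (as computed by
     * {@code LabelImpurity.impurityWeighted}). Since 0.375 &lt; 0.5 the split is an improvement
     * and is kept if no better one is found.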
     * @param featureIDs Indices of the features available in this split.
     * @return A possibly empty list of TrainingNodes.
     */
    @Override
    public List<AbstractTrainingNode<Label>> buildTree(int[] featureIDs) {
        int bestID = -1;
        double bestSplitValue = 0.0;
        double bestScore = impurity.impurity(labelCounts);
        float[] lessThanCounts = new float[labelCounts.length];
        float[] greaterThanCounts = new float[labelCounts.length];
        double countsSum = Util.sum(labelCounts);
        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
            Arrays.fill(lessThanCounts,0.0f);
            System.arraycopy(labelCounts, 0, greaterThanCounts, 0, labelCounts.length);
            // Search the intervals between adjacent feature values for the best split point.
            for (int j = 0; j < feature.size()-1; j++) {
                InvertedFeature f = feature.get(j);
                float[] featureCounts = f.getLabelCounts();
                Util.inPlaceAdd(lessThanCounts,featureCounts);
                Util.inPlaceSubtract(greaterThanCounts,featureCounts);
                double lessThanScore = impurity.impurityWeighted(lessThanCounts);
                double greaterThanScore = impurity.impurityWeighted(greaterThanCounts);
                if ((lessThanScore > 1e-10) && (greaterThanScore > 1e-10)) {
                    double score = (lessThanScore + greaterThanScore) / countsSum;
                    if (score < bestScore) {
                        bestID = i;
                        bestScore = score;
                        bestSplitValue = (f.value + feature.get(j + 1).value) / 2.0;
                    }
                }
            }
        }
        List<AbstractTrainingNode<Label>> output;
        // If we found a split better than the current impurity.
        if (bestID != -1) {
            splitID = featureIDs[bestID];
            split = true;
            splitValue = bestSplitValue;
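            // Accumulate the example indices of the left (less-than) partition by merging each
            // InvertedFeature's sorted index array; the merge keeps the result sorted so it can
            // be used to split every TreeFeature below.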
            IntArrayContainer lessThanIndices = mergeBufferOne.get();
            lessThanIndices.size = 0;
            IntArrayContainer buffer = mergeBufferTwo.get();
            buffer.size = 0;
            for (InvertedFeature f : data.get(splitID)) {
                if (f.value < splitValue) {
                    int[] indices = f.indices();
                    IntArrayContainer.merge(lessThanIndices,indices,buffer);
                    // Swap the buffers
                    IntArrayContainer tmp = lessThanIndices;
                    lessThanIndices = buffer;
                    buffer = tmp;
                } else {
                    break;
                }
            }
            IntArrayContainer secondBuffer = mergeBufferThree.get();
            secondBuffer.grow(lessThanIndices.size);
            ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size());
            ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size());
            for (TreeFeature feature : data) {
                Pair<TreeFeature,TreeFeature> splitPair = feature.split(lessThanIndices,buffer,secondBuffer);
                lessThanData.add(splitPair.getA());
                greaterThanData.add(splitPair.getB());
            }

            lessThanOrEqual = new ClassifierTrainingNode(impurity, lessThanData, lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
            greaterThan = new ClassifierTrainingNode(impurity, greaterThanData, numExamples - lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
            output = new ArrayList<>();
            output.add(lessThanOrEqual);
            output.add(greaterThan);
        } else {
            output = Collections.emptyList();
        }
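        // Drop the inverted data; either it has been split into the children, or this node
        // is a leaf and no longer needs it.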
        data = null;
        return output;
    }

    /**
     * Generates a test time tree (made of {@link SplitNode} and {@link LeafNode}) from the tree rooted at this node.
     * @return A subtree using the SplitNode and LeafNode classes.
     */
    @Override
    public Node<Label> convertTree() {
        if (split) {
            // split node
            Node<Label> newGreaterThan = greaterThan.convertTree();
            Node<Label> newLessThan = lessThanOrEqual.convertTree();
            return new SplitNode<>(splitValue,splitID,impurity.impurity(labelCounts),newGreaterThan,newLessThan);
        } else {
            // leaf node
            double[] normedCounts = Util.normalizeToDistribution(labelCounts);
            double maxScore = Double.NEGATIVE_INFINITY;
            Label maxLabel = null;
            Map<String,Label> counts = new LinkedHashMap<>();
            for (int i = 0; i < labelCounts.length; i++) {
                String name = labelIDMap.getOutput(i).getLabel();
                Label label = new Label(name,normedCounts[i]);
                counts.put(name, label);
                if (label.getScore() > maxScore) {
                    maxScore = label.getScore();
                    maxLabel = label;
                }
            }
            return new LeafNode<>(impurity.impurity(labelCounts),maxLabel,counts,true);
        }
    }

    @Override
    public double getImpurity() {
        return impurity.impurity(labelCounts);
    }

    /**
     * Inverts a training dataset from row major to column major. This partially de-sparsifies the dataset
     * so it's very expensive in terms of memory.
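     *
     * <p>For example (an illustration of the layout, not taken from the source), a dataset of three
     * examples {@code {f0=1.0}}, {@code {f0=1.0, f1=2.0}}, {@code {f1=2.0}} inverts to one
     * {@code TreeFeature} per feature, each mapping distinct values (in ascending order after
     * sorting) to the example indices where they occur: {@code f0 -> {0.0: [2], 1.0: [0, 1]}} and
     * {@code f1 -> {0.0: [0], 2.0: [1, 2]}}. The implicit zeros are materialized, which is why
     * the inversion is memory hungry.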
     * @param examples An input dataset.
     * @return A list of TreeFeatures which contain {@link InvertedFeature}s.
     */
    private static ArrayList<TreeFeature> invertData(Dataset<Label> examples) {
        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
        ImmutableOutputInfo<Label> labelInfo = examples.getOutputIDInfo();
        int numLabels = labelInfo.size();
        int numFeatures = featureInfos.size();
        int numExamples = examples.size();

        int[] labels = new int[numExamples];
        float[] weights = new float[numExamples];

        int k = 0;
        for (Example<Label> e : examples) {
            weights[k] = e.getWeight();
            labels[k] = labelInfo.getID(e.getOutput());
            k++;
        }

        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " classes");
        ArrayList<TreeFeature> data = new ArrayList<>(featureInfos.size());

        for (int i = 0; i < featureInfos.size(); i++) {
            data.add(new TreeFeature(i,numLabels,labels,weights));
        }

        for (int i = 0; i < examples.size(); i++) {
            Example<Label> e = examples.getExample(i);
            SparseVector vec = SparseVector.createSparseVector(e,featureInfos,false);
            int lastID = 0;
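            // SparseVector iterates in ascending feature-index order; any gap between the
            // previous index and the current one is an implicit zero, which must be observed
            // explicitly because the inverted representation is dense over examples.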
            for (VectorTuple f : vec) {
                int curID = f.index;
                for (int j = lastID; j < curID; j++) {
                    data.get(j).observeValue(0.0,i);
                }
                data.get(curID).observeValue(f.value,i);
                //
                // These two checks should never trigger, as SparseVector deals with collisions and
                // Dataset prevents repeated features; they are left in just to make sure.
                // The repeated-feature check must come first: a repeat implies lastID-1 == curID,
                // which also satisfies lastID > curID, so the ordering check would mask it.
                if (lastID-1 == curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                } else if (lastID > curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                }
                lastID = curID + 1;
            }
            for (int j = lastID; j < numFeatures; j++) {
                data.get(j).observeValue(0.0,i);
            }
            if (i % 1000 == 0) {
                logger.fine("Processed example " + i);
            }
        }

        logger.fine("Sorting features");

        data.forEach(TreeFeature::sort);

        logger.fine("Fixing InvertedFeature sizes");

        data.forEach(TreeFeature::fixSize);

        logger.fine("Built initial List<TreeFeature>");

        return data;
    }

    private void writeObject(java.io.ObjectOutputStream stream)
            throws IOException {
        throw new NotSerializableException("ClassifierTrainingNode is a runtime class only, and should not be serialized.");
    }
}