001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.tree;
018
019import org.tribuo.Output;
020import org.tribuo.math.la.SparseVector;
021
022import java.util.List;
023
024/**
025 * Base class for decision tree nodes used at training time.
026 */
027public abstract class AbstractTrainingNode<T extends Output<T>> implements Node<T> {
028
029    /**
030     * Default buffer size used in the split operation.
031     */
032    protected static final int DEFAULT_SIZE = 16;
033
034    protected final int depth;
035
036    protected final int numExamples;
037
038    protected boolean split;
039
040    protected int splitID;
041
042    protected double splitValue;
043    
044    protected AbstractTrainingNode<T> greaterThan;
045    
046    protected AbstractTrainingNode<T> lessThanOrEqual;
047
048    /**
049     * Builds an abstract training node.
050     * @param depth The depth of this node.
051     * @param numExamples The number of examples in this node.
052     */
053    protected AbstractTrainingNode(int depth, int numExamples) {
054        this.depth = depth;
055        this.numExamples = numExamples;
056    }
057
058    public abstract List<AbstractTrainingNode<T>> buildTree(int[] indices);
059
060    /**
061     * Converts a tree from a training representation to the final inference time representation.
062     * @return The converted subtree.
063     */
064    public abstract Node<T> convertTree();
065
066    /**
067     * The depth of this node in the tree.
068     * @return The depth.
069     */
070    public int getDepth() {
071        return depth;
072    }
073
074    @Override
075    public Node<T> getNextNode(SparseVector example) {
076        if (split) {
077            double feature = example.get(splitID);
078            if (feature > splitValue) {
079                return greaterThan;
080            } else {
081                return lessThanOrEqual;
082            }
083        } else {
084            return null;
085        }
086    }
087
088    /**
089     * The number of training examples in this node.
090     * @return The number of training examples in this node.
091     */
092    public int getNumExamples() {
093        return numExamples;
094    }
095
096    @Override
097    public boolean isLeaf() {
098        return !split;
099    }
100
101    @Override
102    public Node<T> copy() {
103        throw new UnsupportedOperationException("Copy is not supported on training nodes.");
104    }
105}