001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.tree;
018
019import org.tribuo.Output;
020import org.tribuo.math.la.SparseVector;
021
022/**
023 * An immutable {@link Node} with a split and two child nodes.
024 */
025public class SplitNode<T extends Output<T>> implements Node<T> {
026    private static final long serialVersionUID = 3L;
027
028    private final Node<T> greaterThan;
029
030    private final Node<T> lessThanOrEqual;
031
032    private final int splitFeature;
033
034    private final double splitValue;
035
036    private final double impurity;
037
038    /**
039     * Constructs a split node with the specified split value, feature id, impurity and child nodes.
040     * @param splitValue The feature value to split on.
041     * @param featureID The feature id number.
042     * @param impurity The impurity of this node at training time.
043     * @param greaterThan The node to take if the feature value is greater than the split value.
044     * @param lessThanOrEqual The node to take if the feature value is less than or equal to the split value.
045     */
046    public SplitNode(double splitValue, int featureID, double impurity, Node<T> greaterThan, Node<T> lessThanOrEqual) {
047        this.splitValue = splitValue;
048        this.splitFeature = featureID;
049        this.impurity = impurity;
050        this.greaterThan = greaterThan;
051        this.lessThanOrEqual = lessThanOrEqual;
052    }
053
054    /**
055     * Return the appropriate child node. If the splitFeature is not present in
056     * the example it's value is treated as zero.
057     *
058     * @param e The example to inspect.
059     * @return The corresponding child node.
060     */
061    @Override
062    public Node<T> getNextNode(SparseVector e) {
063        double feature = e.get(splitFeature);
064        if (feature > splitValue) {
065            return greaterThan;
066        } else {
067            return lessThanOrEqual;
068        }
069    }
070
071    @Override
072    public boolean isLeaf() { 
073        return false;
074    }
075
076    @Override
077    public double getImpurity() {
078        return impurity;
079    }
080
081    @Override
082    public Node<T> copy() {
083        return new SplitNode<>(splitValue,splitFeature,impurity,greaterThan.copy(),lessThanOrEqual.copy());
084    }
085
086    /**
087     * Gets the feature ID that this node uses for splitting.
088     * @return The feature ID.
089     */
090    public int getFeatureID() {
091        return splitFeature;
092    }
093
094    /**
095     * The threshold value.
096     * @return The threshold value.
097     */
098    public double splitValue() {
099        return splitValue;
100    }
101
102    /**
103     * The node used if the value is greater than the splitValue.
104     * @return The greater than node.
105     */
106    public Node<T> getGreaterThan() {
107        return greaterThan;
108    }
109
110    /**
111     * The node used if the value is less than or equal to the splitValue.
112     * @return The less than or equal to node.
113     */
114    public Node<T> getLessThanOrEqual() {
115        return lessThanOrEqual;
116    }
117
118    @Override
119    public String toString() {
120        return "SplitNode(feature="+splitFeature+",value="+splitValue+",impurity="+impurity+",\n\t\tleft="+lessThanOrEqual.toString()+",\n\t\tright="+greaterThan.toString()+")";
121    }
122    
123}
124