001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.tree; 018 019import org.tribuo.Output; 020import org.tribuo.math.la.SparseVector; 021 022/** 023 * An immutable {@link Node} with a split and two child nodes. 024 */ 025public class SplitNode<T extends Output<T>> implements Node<T> { 026 private static final long serialVersionUID = 3L; 027 028 private final Node<T> greaterThan; 029 030 private final Node<T> lessThanOrEqual; 031 032 private final int splitFeature; 033 034 private final double splitValue; 035 036 private final double impurity; 037 038 /** 039 * Constructs a split node with the specified split value, feature id, impurity and child nodes. 040 * @param splitValue The feature value to split on. 041 * @param featureID The feature id number. 042 * @param impurity The impurity of this node at training time. 043 * @param greaterThan The node to take if the feature value is greater than the split value. 044 * @param lessThanOrEqual The node to take if the feature value is less than or equal to the split value. 045 */ 046 public SplitNode(double splitValue, int featureID, double impurity, Node<T> greaterThan, Node<T> lessThanOrEqual) { 047 this.splitValue = splitValue; 048 this.splitFeature = featureID; 049 this.impurity = impurity; 050 this.greaterThan = greaterThan; 051 this.lessThanOrEqual = lessThanOrEqual; 052 } 053 054 /** 055 * Return the appropriate child node. If the splitFeature is not present in 056 * the example it's value is treated as zero. 057 * 058 * @param e The example to inspect. 059 * @return The corresponding child node. 060 */ 061 @Override 062 public Node<T> getNextNode(SparseVector e) { 063 double feature = e.get(splitFeature); 064 if (feature > splitValue) { 065 return greaterThan; 066 } else { 067 return lessThanOrEqual; 068 } 069 } 070 071 @Override 072 public boolean isLeaf() { 073 return false; 074 } 075 076 @Override 077 public double getImpurity() { 078 return impurity; 079 } 080 081 @Override 082 public Node<T> copy() { 083 return new SplitNode<>(splitValue,splitFeature,impurity,greaterThan.copy(),lessThanOrEqual.copy()); 084 } 085 086 /** 087 * Gets the feature ID that this node uses for splitting. 088 * @return The feature ID. 089 */ 090 public int getFeatureID() { 091 return splitFeature; 092 } 093 094 /** 095 * The threshold value. 096 * @return The threshold value. 097 */ 098 public double splitValue() { 099 return splitValue; 100 } 101 102 /** 103 * The node used if the value is greater than the splitValue. 104 * @return The greater than node. 105 */ 106 public Node<T> getGreaterThan() { 107 return greaterThan; 108 } 109 110 /** 111 * The node used if the value is less than or equal to the splitValue. 112 * @return The less than or equal to node. 113 */ 114 public Node<T> getLessThanOrEqual() { 115 return lessThanOrEqual; 116 } 117 118 @Override 119 public String toString() { 120 return "SplitNode(feature="+splitFeature+",value="+splitValue+",impurity="+impurity+",\n\t\tleft="+lessThanOrEqual.toString()+",\n\t\tright="+greaterThan.toString()+")"; 121 } 122 123} 124