001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.common.tree; 018 019import org.tribuo.Example; 020import org.tribuo.Output; 021import org.tribuo.Prediction; 022import org.tribuo.math.la.SparseVector; 023 024import java.util.HashMap; 025import java.util.Map; 026 027/** 028 * An immutable leaf {@link Node} that can create a prediction. 029 */ 030public class LeafNode<T extends Output<T>> implements Node<T> { 031 private static final long serialVersionUID = 4L; 032 033 private final double impurity; 034 035 private final T output; 036 private final Map<String,T> scores; 037 private final boolean generatesProbabilities; 038 039 /** 040 * Constructs a leaf node. 041 * @param impurity The impurity value calculated at training time. 042 * @param output The output value from this node. 043 * @param scores The score map for the other outputs. 044 * @param generatesProbabilities If the scores are probabilities. 045 */ 046 public LeafNode(double impurity, T output, Map<String,T> scores, boolean generatesProbabilities) { 047 this.impurity = impurity; 048 this.output = output; 049 this.scores = scores; 050 this.generatesProbabilities = generatesProbabilities; 051 } 052 053 @Override 054 public Node<T> getNextNode(SparseVector e) { 055 return null; 056 } 057 058 @Override 059 public boolean isLeaf() { 060 return true; 061 } 062 063 @Override 064 public double getImpurity() { 065 return impurity; 066 } 067 068 @Override 069 public LeafNode<T> copy() { 070 return new LeafNode<>(impurity,output.copy(),new HashMap<>(scores),generatesProbabilities); 071 } 072 073 /** 074 * Gets the output in this node. 075 * @return The output. 076 */ 077 public T getOutput() { 078 return output; 079 } 080 081 /** 082 * Gets the distribution over scores in this node. 083 * @return The score distribution. 084 */ 085 public Map<String,T> getDistribution() { 086 return scores; 087 } 088 089 /** 090 * Constructs a new prediction object based on this node's scores. 091 * @param numUsed The number of features used. 092 * @param example The example to be scored. 093 * @return The prediction for the supplied example. 094 */ 095 public Prediction<T> getPrediction(int numUsed, Example<T> example) { 096 return new Prediction<>(output,scores,numUsed,example,generatesProbabilities); 097 } 098 099 @Override 100 public String toString() { 101 return "LeafNode(impurity="+impurity+",output="+output.toString()+",scores="+scores.toString()+",probability="+generatesProbabilities+")"; 102 } 103 104}