/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
016
017package org.tribuo.common.tree;
018
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.math.la.SparseVector;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
026
027/**
028 * An immutable leaf {@link Node} that can create a prediction.
029 */
030public class LeafNode<T extends Output<T>> implements Node<T> {
031    private static final long serialVersionUID = 4L;
032
033    private final double impurity;
034
035    private final T output;
036    private final Map<String,T> scores;
037    private final boolean generatesProbabilities;
038
039    /**
040     * Constructs a leaf node.
041     * @param impurity The impurity value calculated at training time.
042     * @param output The output value from this node.
043     * @param scores The score map for the other outputs.
044     * @param generatesProbabilities If the scores are probabilities.
045     */
046    public LeafNode(double impurity, T output, Map<String,T> scores, boolean generatesProbabilities) {
047        this.impurity = impurity;
048        this.output = output;
049        this.scores = scores;
050        this.generatesProbabilities = generatesProbabilities;
051    }
052
053    @Override
054    public Node<T> getNextNode(SparseVector e) {
055        return null;
056    }
057    
058    @Override
059    public boolean isLeaf() {
060        return true;
061    }
062
063    @Override
064    public double getImpurity() {
065        return impurity;
066    }
067
068    @Override
069    public LeafNode<T> copy() {
070        return new LeafNode<>(impurity,output.copy(),new HashMap<>(scores),generatesProbabilities);
071    }
072
073    /**
074     * Gets the output in this node.
075     * @return The output.
076     */
077    public T getOutput() {
078        return output;
079    }
080
081    /**
082     * Gets the distribution over scores in this node.
083     * @return The score distribution.
084     */
085    public Map<String,T> getDistribution() {
086        return scores;
087    }
088
089    /**
090     * Constructs a new prediction object based on this node's scores.
091     * @param numUsed The number of features used.
092     * @param example The example to be scored.
093     * @return The prediction for the supplied example.
094     */
095    public Prediction<T> getPrediction(int numUsed, Example<T> example) {
096        return new Prediction<>(output,scores,numUsed,example,generatesProbabilities);
097    }
098
099    @Override
100    public String toString() {
101        return "LeafNode(impurity="+impurity+",output="+output.toString()+",scores="+scores.toString()+",probability="+generatesProbabilities+")";
102    }
103
104}