Class TreeModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.SparseModel<T>
org.tribuo.common.tree.TreeModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable
Direct Known Subclasses:
IndependentRegressionTreeModel

public class TreeModel<T extends Output<T>> extends SparseModel<T>
A Model wrapped around a decision tree root Node.
  • Constructor Details

    • TreeModel

      protected TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String,List<String>> activeFeatures)
      Constructs a trained decision tree model.

      Only used when the tree has multiple roots. It should only be called from subclasses when *all* other methods are overridden.

      Parameters:
      name - The model name.
      description - The model provenance.
      featureIDMap - The feature id map.
      outputIDInfo - The output info.
      generatesProbabilities - Whether this model emits probabilities.
      activeFeatures - The active feature set of the model.
  • Method Details

    • getDepth

      public int getDepth()
      Probes the tree to find the depth.
      Returns:
      The depth of the tree.
    • computeDepth

      protected static <T extends Output<T>> int computeDepth(int initialDepth, Node<T> root)
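The source gives no description for computeDepth. As a rough illustration of what a recursive depth probe with this parameter shape can look like, here is a minimal sketch. SketchNode and DepthSketch are hypothetical stand-ins, not Tribuo's Node<T> or the library's implementation, and the sketch assumes that initialDepth is the depth assigned to the node passed in.

```java
// Hypothetical stand-in for a binary tree node; this is NOT Tribuo's Node<T>.
final class SketchNode {
    final SketchNode left;   // null when there is no left child
    final SketchNode right;  // null when there is no right child

    SketchNode(SketchNode left, SketchNode right) {
        this.left = left;
        this.right = right;
    }
}

final class DepthSketch {
    // Assumption: initialDepth is the depth assigned to the node passed in,
    // so a single leaf probed with initialDepth 0 has depth 0.
    static int computeDepth(int initialDepth, SketchNode node) {
        int deepest = initialDepth;
        if (node.left != null) {
            deepest = Math.max(deepest, computeDepth(initialDepth + 1, node.left));
        }
        if (node.right != null) {
            deepest = Math.max(deepest, computeDepth(initialDepth + 1, node.right));
        }
        return deepest;
    }
}
```

Under these assumptions, a root whose left child is itself a split over two leaves has depth 2 when probed with initialDepth 0.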
    • predict

      public Prediction<T> predict(Example<T> example)
      Description copied from class: Model
      Uses the model to predict the output for a single example.

      predict does not mutate the example.

      Throws IllegalArgumentException if the example has no features or no feature overlap with the model.

      Specified by:
      predict in class Model<T extends Output<T>>
      Parameters:
      example - the example to predict.
      Returns:
      the result of the prediction.
    • getTopFeatures

      public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
      Description copied from class: Model
      Gets the top n features associated with this model.

      If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.

      If the model cannot describe its top features then it returns Collections.emptyMap().

      Specified by:
      getTopFeatures in class Model<T extends Output<T>>
      Parameters:
      n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
      Returns:
      a map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the model
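The contract above (a single-key map when there are no per-output lists, and a negative n meaning "all features") can be sketched with standard library types only. This is an illustration of the documented contract, not the library's implementation: TopFeaturesSketch, its local ALL_OUTPUTS key, and the input map of per-feature scores are all hypothetical, and Map.Entry stands in for the OLCUT Pair type.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

final class TopFeaturesSketch {
    // Hypothetical local key standing in for Model.ALL_OUTPUTS.
    static final String ALL_OUTPUTS = "ALL_OUTPUTS";

    // Ranks features by descending score and keeps the first n.
    // A negative n keeps every feature, mirroring the documented contract.
    static Map<String, List<Map.Entry<String, Double>>> topFeatures(Map<String, Double> scores, int n) {
        List<Map.Entry<String, Double>> ranked = new ArrayList<>();
        for (Map.Entry<String, Double> e : scores.entrySet()) {
            ranked.add(new AbstractMap.SimpleImmutableEntry<>(e.getKey(), e.getValue()));
        }
        // Highest score first; ties are broken by feature name for a stable order.
        ranked.sort((a, b) -> {
            int byScore = Double.compare(b.getValue(), a.getValue());
            return byScore != 0 ? byScore : a.getKey().compareTo(b.getKey());
        });
        int limit = n < 0 ? ranked.size() : Math.min(n, ranked.size());
        List<Map.Entry<String, Double>> top = new ArrayList<>(ranked.subList(0, limit));
        // No per-output lists in this sketch, so a single entry is returned.
        return Collections.singletonMap(ALL_OUTPUTS, top);
    }
}
```

A caller would read the single list back with the ALL_OUTPUTS key. How a real tree model scores its features is not stated in the source, so the scores here are simply supplied by the caller.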
    • getExcuse

      public Optional<Excuse<T>> getExcuse(Example<T> example)
      Description copied from class: Model
      Generates an excuse for an example.

      This attempts to explain a classification result. Generating an excuse may be quite an expensive operation.

      This excuse either contains per-class information or an entry with key Model.ALL_OUTPUTS.

      The optional is empty if the model does not provide excuses.

      Specified by:
      getExcuse in class Model<T extends Output<T>>
      Parameters:
      example - The input example.
      Returns:
      An optional excuse object. The optional is empty if this model does not provide excuses.
    • copy

      protected TreeModel<T> copy(String newName, ModelProvenance newProvenance)
      Description copied from class: Model
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Specified by:
      copy in class Model<T extends Output<T>>
      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
    • getFeatures

      public Set<String> getFeatures()
      Returns the set of features which are split on in this tree.
      Returns:
      The feature names used by this tree.
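Collecting the features a tree splits on amounts to a traversal that records the feature name at every split node and ignores leaves. The following is a minimal sketch of that idea. FeatureNode and FeatureSketch are hypothetical types invented for illustration, not Tribuo's Node<T> API, and the sketch assumes every split node has exactly two children.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

// Hypothetical stand-in for a tree node; this is NOT Tribuo's Node<T>.
final class FeatureNode {
    final String featureName; // null for a leaf
    final FeatureNode left;   // null for a leaf
    final FeatureNode right;  // null for a leaf

    private FeatureNode(String featureName, FeatureNode left, FeatureNode right) {
        this.featureName = featureName;
        this.left = left;
        this.right = right;
    }

    static FeatureNode leaf() {
        return new FeatureNode(null, null, null);
    }

    static FeatureNode split(String featureName, FeatureNode left, FeatureNode right) {
        return new FeatureNode(Objects.requireNonNull(featureName),
                Objects.requireNonNull(left), Objects.requireNonNull(right));
    }
}

final class FeatureSketch {
    // Iterative depth-first traversal; duplicates collapse because the result is a Set.
    static Set<String> splitFeatures(FeatureNode root) {
        Set<String> features = new HashSet<>();
        Deque<FeatureNode> stack = new ArrayDeque<>();
        stack.push(root);
        while (!stack.isEmpty()) {
            FeatureNode node = stack.pop();
            if (node.featureName != null) { // split node: record it and visit both children
                features.add(node.featureName);
                stack.push(node.left);
                stack.push(node.right);
            }
        }
        return features;
    }
}
```

A feature that is split on more than once appears a single time in the result, which matches the Set<String> return type of getFeatures.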
    • toString

      public String toString()
      Overrides:
      toString in class Model<T extends Output<T>>
    • getRoot

      public Node<T> getRoot()
      Returns the root node of this tree.
      Returns:
      The root node.