Class TreeModel<T extends Output<T>>

java.lang.Object
org.tribuo.Model<T>
org.tribuo.SparseModel<T>
org.tribuo.common.tree.TreeModel<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
Direct Known Subclasses:
IndependentRegressionTreeModel

public class TreeModel<T extends Output<T>> extends SparseModel<T>
A Model wrapped around a decision tree root Node.
  • Field Details

    • CURRENT_VERSION

      public static final int CURRENT_VERSION
      Protobuf serialization version.
  • Constructor Details

    • TreeModel

      protected TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String,List<String>> activeFeatures)
      Constructs a trained decision tree model.

      Only used when the tree has multiple roots. It should only be called from subclasses when *all* other methods are overridden.

      Parameters:
      name - The model name.
      description - The model provenance.
      featureIDMap - The feature id map.
      outputIDInfo - The output info.
      generatesProbabilities - Whether this model emits probabilities.
      activeFeatures - The active feature set of the model.
  • Method Details

    • deserializeFromProto

      public static TreeModel<?> deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
    • deserializeFromProtos

      protected static <U extends Output<U>> List<Node<U>> deserializeFromProtos(List<org.tribuo.common.tree.protos.TreeNodeProto> nodeProtos, Class<U> outputClass) throws com.google.protobuf.InvalidProtocolBufferException
      Deserializes a list of node protos into a list of nodes.

      Deserialization starts with a list of node builders, which are replaced item by item with the nodes they build. Leaf nodes are built first, and split nodes are added as they become ready to build. The method thus travels up the tree and only attempts to build a split node once both of its child nodes are available. This may seem a bit tortured, but it preserves the immutability of the built nodes (the split nodes in particular). A split node builder only knows the indices of its children when it is deserialized, but it must be given the actual nodes before its build method can be called. Split node builders are only added to the queue once they can be built, that is, once both children have been created and provided to the builder.

      This approach should traverse the entire tree in the correct order, and the method performs a final consistency check before returning.

      Type Parameters:
      U - The output type of the nodes.
      Parameters:
      nodeProtos - The node protos to deserialize.
      outputClass - The output type.
      Returns:
      The nodes.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If an unexpected proto is found.
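      The bottom-up ordering described for deserializeFromProtos can be sketched with plain Java types. This is a minimal illustration under stated assumptions, not Tribuo's implementation: FlatNode, BuiltNode, and build are hypothetical stand-ins for the node protos, the immutable nodes, and the deserialization routine, and the parent index stored on each flat node is assumed for the sketch.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;

// Hypothetical sketch of bottom-up construction of an immutable tree.
// None of these types are part of Tribuo.
final class BottomUpBuildSketch {

    // Flat, index-based description of a node (stand-in for a node proto).
    static final class FlatNode {
        final int parentIdx; // -1 for the root
        final int leftIdx;   // -1 for a leaf
        final int rightIdx;  // -1 for a leaf

        private FlatNode(int parentIdx, int leftIdx, int rightIdx) {
            this.parentIdx = parentIdx;
            this.leftIdx = leftIdx;
            this.rightIdx = rightIdx;
        }

        static FlatNode leaf(int parentIdx) {
            return new FlatNode(parentIdx, -1, -1);
        }

        static FlatNode split(int parentIdx, int leftIdx, int rightIdx) {
            return new FlatNode(parentIdx, leftIdx, rightIdx);
        }

        boolean isLeaf() {
            return leftIdx < 0 && rightIdx < 0;
        }
    }

    // Immutable node: both children are fixed at construction time.
    static final class BuiltNode {
        final BuiltNode left;
        final BuiltNode right;

        BuiltNode(BuiltNode left, BuiltNode right) {
            this.left = left;
            this.right = right;
        }
    }

    // Builds every node, leaves first; a split is queued only once both children exist.
    static List<BuiltNode> build(List<FlatNode> flat) {
        BuiltNode[] built = new BuiltNode[flat.size()];
        Deque<Integer> ready = new ArrayDeque<>();
        for (int i = 0; i < flat.size(); i++) {
            if (flat.get(i).isLeaf()) {
                ready.add(i); // leaves have no dependencies
            }
        }
        int builtCount = 0;
        while (!ready.isEmpty()) {
            int idx = ready.poll();
            FlatNode cur = flat.get(idx);
            built[idx] = cur.isLeaf()
                    ? new BuiltNode(null, null)
                    : new BuiltNode(built[cur.leftIdx], built[cur.rightIdx]);
            builtCount++;
            if (cur.parentIdx >= 0) {
                FlatNode parent = flat.get(cur.parentIdx);
                // Queue the parent only when both of its children have been built.
                if (!parent.isLeaf()
                        && built[parent.leftIdx] != null
                        && built[parent.rightIdx] != null) {
                    ready.add(cur.parentIdx);
                }
            }
        }
        // Final check: every description must have produced a node.
        if (builtCount != flat.size()) {
            throw new IllegalStateException(
                    "Built " + builtCount + " of " + flat.size() + " nodes");
        }
        return Arrays.asList(built);
    }
}
```

      Because a parent is queued only by whichever of its children is built second, each split node is constructed exactly once and never needs to be mutated afterwards.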
    • getDepth

      public int getDepth()
      Probes the tree to find the depth.
      Returns:
      The depth of the tree.
    • computeDepth

      protected static <T extends Output<T>> int computeDepth(int initialDepth, Node<T> root)
      Computes the depth of the tree.
      Type Parameters:
      T - The output type of the tree.
      Parameters:
      initialDepth - The current depth.
      root - The root to probe.
      Returns:
      The tree depth.
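      A recursive depth computation of this shape might look like the sketch below. It is an assumption-laden illustration rather than Tribuo's code: DepthSketch.Node is a hypothetical binary node type, split nodes are assumed to always have two children, and the convention that a lone leaf root has depth 0 when called with initialDepth 0 is chosen for the example.

```java
// Hypothetical sketch; DepthSketch.Node is not Tribuo's Node interface.
final class DepthSketch {

    static final class Node {
        final Node left;  // null for a leaf
        final Node right; // null for a leaf

        Node(Node left, Node right) {
            this.left = left;
            this.right = right;
        }

        boolean isLeaf() {
            return left == null && right == null;
        }
    }

    // Returns the depth of the deepest leaf, where the supplied node sits at initialDepth.
    static int computeDepth(int initialDepth, Node node) {
        if (node.isLeaf()) {
            return initialDepth;
        }
        int leftDepth = computeDepth(initialDepth + 1, node.left);
        int rightDepth = computeDepth(initialDepth + 1, node.right);
        return Math.max(leftDepth, rightDepth);
    }

    public static void main(String[] args) {
        Node leaf = new Node(null, null);
        Node tree = new Node(leaf, new Node(leaf, leaf));
        System.out.println(computeDepth(0, tree)); // prints 2
    }
}
```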
    • predict

      public Prediction<T> predict(Example<T> example)
      Description copied from class: Model
      Uses the model to predict the output for a single example.

      predict does not mutate the example.

      Throws IllegalArgumentException if the example has no features or no feature overlap with the model.

      Specified by:
      predict in class Model<T extends Output<T>>
      Parameters:
      example - the example to predict.
      Returns:
      the result of the prediction.
    • getTopFeatures

      public Map<String,List<com.oracle.labs.mlrg.olcut.util.Pair<String,Double>>> getTopFeatures(int n)
      Description copied from class: Model
      Gets the top n features associated with this model.

      If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.

      If the model cannot describe its top features then it returns Collections.emptyMap().

      Specified by:
      getTopFeatures in class Model<T extends Output<T>>
      Parameters:
      n - the number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
      Returns:
      a map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the model
    • getExcuse

      public Optional<Excuse<T>> getExcuse(Example<T> example)
      Description copied from class: Model
      Generates an excuse for an example.

      This attempts to explain a classification result. Generating an excuse may be quite an expensive operation.

      This excuse either contains per-class information or an entry with key Model.ALL_OUTPUTS.

      The optional is empty if the model does not provide excuses.

      Specified by:
      getExcuse in class Model<T extends Output<T>>
      Parameters:
      example - The input example.
      Returns:
      An optional excuse object. The optional is empty if this model does not provide excuses.
    • copy

      protected TreeModel<T> copy(String newName, ModelProvenance newProvenance)
      Description copied from class: Model
      Copies a model, replacing its provenance and name with the supplied values.

      Used to provide the provenance removal functionality.

      Specified by:
      copy in class Model<T extends Output<T>>
      Parameters:
      newName - The new name.
      newProvenance - The new provenance.
      Returns:
      A copy of the model.
    • getFeatures

      public Set<String> getFeatures()
      Returns the set of features which are split on in this tree.
      Returns:
      The feature names used by this tree.
    • countNodes

      public int countNodes(Node<T> root)
      Counts the number of nodes in the tree rooted at the supplied node, including that node.
      Parameters:
      root - The tree root.
      Returns:
      The number of nodes.
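      Counting the nodes in a subtree, including the supplied node, can be sketched as a simple recursion. CountSketch.Node below is a hypothetical binary node type used only for illustration, and split nodes are assumed to always have two children; it is not the Tribuo Node interface.

```java
// Hypothetical sketch; CountSketch.Node is not Tribuo's Node interface.
final class CountSketch {

    static final class Node {
        final Node left;  // null for a leaf
        final Node right; // null for a leaf

        Node(Node left, Node right) {
            this.left = left;
            this.right = right;
        }

        boolean isLeaf() {
            return left == null && right == null;
        }
    }

    // Counts the supplied node plus every node beneath it.
    static int countNodes(Node root) {
        if (root.isLeaf()) {
            return 1;
        }
        return 1 + countNodes(root.left) + countNodes(root.right);
    }

    public static void main(String[] args) {
        Node leaf = new Node(null, null);
        Node tree = new Node(leaf, new Node(leaf, leaf));
        System.out.println(countNodes(tree)); // prints 5
    }
}
```

      The shared leaf instance is counted once per position in the tree, which is why the example reports five nodes.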
    • toString

      public String toString()
      Overrides:
      toString in class Model<T extends Output<T>>
    • getRoot

      public Node<T> getRoot()
      Returns the root node of this tree.
      Returns:
      The root node.
    • serialize

      public org.tribuo.protos.core.ModelProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<T extends Output<T>>
      Overrides:
      serialize in class Model<T extends Output<T>>
      Returns:
      The protobuf.
    • serializeToNodes

      protected List<org.tribuo.common.tree.protos.TreeNodeProto> serializeToNodes(Node<T> root)
      Serializes the supplied node tree into a list of protobufs.
      Parameters:
      root - The root of the tree to serialize.
      Returns:
      The protobuf list.
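      Flattening a node tree into an index-based list, the inverse of rebuilding nodes from such a list, could be sketched as follows. Everything here is hypothetical: TreeNode and FlatNode stand in for the tree nodes and node protos, and the pre-order index assignment is an assumption made for the example, since the traversal order Tribuo uses is not stated on this page.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of flattening a binary tree; not Tribuo's implementation.
final class FlattenSketch {

    static final class TreeNode {
        final String label;
        final TreeNode left;  // null for a leaf
        final TreeNode right; // null for a leaf

        TreeNode(String label, TreeNode left, TreeNode right) {
            this.label = label;
            this.left = left;
            this.right = right;
        }
    }

    // Flat record of one node; child indices are filled in once the children are visited.
    static final class FlatNode {
        final String label;
        final int parentIdx; // -1 for the root
        final int curIdx;
        int leftIdx = -1;    // stays -1 for a leaf
        int rightIdx = -1;   // stays -1 for a leaf

        FlatNode(String label, int parentIdx, int curIdx) {
            this.label = label;
            this.parentIdx = parentIdx;
            this.curIdx = curIdx;
        }
    }

    static List<FlatNode> flatten(TreeNode root) {
        List<FlatNode> out = new ArrayList<>();
        visit(root, -1, out);
        return out;
    }

    // Pre-order visit: a node's index is its position in the output list.
    private static int visit(TreeNode node, int parentIdx, List<FlatNode> out) {
        int curIdx = out.size();
        FlatNode flat = new FlatNode(node.label, parentIdx, curIdx);
        out.add(flat);
        if (node.left != null && node.right != null) {
            flat.leftIdx = visit(node.left, curIdx, out);
            flat.rightIdx = visit(node.right, curIdx, out);
        }
        return curIdx;
    }

    public static void main(String[] args) {
        TreeNode tree = new TreeNode("A",
                new TreeNode("B", null, null),
                new TreeNode("C", new TreeNode("D", null, null), new TreeNode("E", null, null)));
        for (FlatNode f : flatten(tree)) {
            System.out.println(f.curIdx + " " + f.label + " parent=" + f.parentIdx
                    + " left=" + f.leftIdx + " right=" + f.rightIdx);
        }
    }
}
```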