Class TreeModel<T extends Output<T>>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.provenance.Provenancable<ModelProvenance>, Serializable, ProtoSerializable<org.tribuo.protos.core.ModelProto>
- Direct Known Subclasses:
IndependentRegressionTreeModel
-
Field Summary
| Modifier and Type | Field | Description |
| --- | --- | --- |
| static final int | CURRENT_VERSION | Protobuf serialization version. |

Fields inherited from class org.tribuo.Model
ALL_OUTPUTS, BIAS_FEATURE, featureIDMap, generatesProbabilities, name, outputIDInfo, provenance, provenanceOutput
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
-
Constructor Summary
| Modifier | Constructor | Description |
| --- | --- | --- |
| protected | TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String, List<String>> activeFeatures) | Constructs a trained decision tree model. |
-
Method Summary
| Modifier and Type | Method | Description |
| --- | --- | --- |
| protected static <T extends Output<T>> int | computeDepth(int initialDepth, Node<T> root) | Computes the depth of the tree. |
| TreeModel<T> | copy(String newName, ModelProvenance newProvenance) | Copies a model, replacing its provenance and name with the supplied values. |
| int | countNodes(Node<T> root) | Counts the number of nodes in the tree rooted at the supplied node, including that node. |
| static TreeModel<?> | deserializeFromProto(int version, String className, com.google.protobuf.Any message) | Deserialization factory. |
| protected static <U extends Output<U>> List<Node<U>> | deserializeFromProtos(List<org.tribuo.common.tree.protos.TreeNodeProto> nodeProtos, Class<U> outputClass) | We will start off with a list of node builders that we will replace item-by-item with the nodes that they built. |
| int | getDepth() | Probes the tree to find the depth. |
| Optional<Excuse<T>> | getExcuse(Example<T> example) | Generates an excuse for an example. |
| Set<String> | getFeatures() | Returns the set of features which are split on in this tree. |
| Node<T> | getRoot() | Returns the root node of this tree. |
| Map<String, List<Pair<String,Double>>> | getTopFeatures(int n) | Gets the top n features associated with this model. |
| Prediction<T> | predict(Example<T> example) | Uses the model to predict the output for a single example. |
| org.tribuo.protos.core.ModelProto | serialize() | Serializes this object to a protobuf. |
| protected List<org.tribuo.common.tree.protos.TreeNodeProto> | serializeToNodes(Node<T> root) | Serializes the supplied node tree into a list of protobufs. |
| String | toString() | |

Methods inherited from class org.tribuo.SparseModel
copy, getActiveFeatures
Methods inherited from class org.tribuo.Model
castModel, createDataCarrier, deserialize, deserializeFromFile, deserializeFromStream, generatesProbabilities, getExcuses, getFeatureIDMap, getName, getOutputIDInfo, getProvenance, innerPredict, predict, predict, serializeToFile, serializeToStream, setName, validate
-
Field Details
-
CURRENT_VERSION
public static final int CURRENT_VERSION
Protobuf serialization version.
-
-
Constructor Details
-
TreeModel
protected TreeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, Map<String, List<String>> activeFeatures)
Constructs a trained decision tree model.
Only used when the tree has multiple roots. It should only be called from subclasses when *all* other methods are overridden.
- Parameters:
name - The model name.
description - The model provenance.
featureIDMap - The feature id map.
outputIDInfo - The output info.
generatesProbabilities - Whether this model emits probabilities.
activeFeatures - The active feature set of the model.
-
-
Method Details
-
deserializeFromProto
public static TreeModel<?> deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
Deserialization factory.
- Parameters:
version - The serialized object version.
className - The class name.
message - The serialized data.
- Returns:
The deserialized object.
- Throws:
com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
-
deserializeFromProtos
protected static <U extends Output<U>> List<Node<U>> deserializeFromProtos(List<org.tribuo.common.tree.protos.TreeNodeProto> nodeProtos, Class<U> outputClass) throws com.google.protobuf.InvalidProtocolBufferException
We will start off with a list of node builders that we will replace item-by-item with the nodes that they built. We start with the leaf nodes and add split nodes as they become ready to build. In this way we travel up the tree and only attempt to build a split node when both of its child nodes are available.
This may seem a bit tortured, but it preserves the immutability of the built nodes (the split nodes in particular). A split node builder only knows the indices of its children when it is deserialized, but it must be given the actual nodes before its build method can be called. Split node builders are therefore only added to the queue once both children have been created and provided to the builder.
This approach should traverse the entire tree in the correct order, and the method checks the result at the end.
- Type Parameters:
U - The output type of the nodes.
- Parameters:
nodeProtos - The node protos to deserialize.
outputClass - The output type.
- Returns:
The nodes.
- Throws:
com.google.protobuf.InvalidProtocolBufferException - If an unexpected proto is found.
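The bottom-up build order described above can be illustrated with a minimal sketch. It is not Tribuo's implementation: `FlatNode`, `Node`, and `build` are hypothetical stand-ins for the node protos, the built nodes, and the deserialization method, and the sketch assumes a well-formed binary tree in which every entry records its parent and child indices.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical stand-ins for illustration only; these are not Tribuo's
// Node, builder, or TreeNodeProto types.
final class BottomUpBuildSketch {

    /** Flat, index-based form of a node, playing the role of a deserialized builder. */
    static final class FlatNode {
        final int parentIdx; // -1 for the root
        final int leftIdx;   // -1 for a leaf
        final int rightIdx;  // -1 for a leaf
        final double value;  // leaf output or split threshold

        FlatNode(int parentIdx, int leftIdx, int rightIdx, double value) {
            this.parentIdx = parentIdx;
            this.leftIdx = leftIdx;
            this.rightIdx = rightIdx;
            this.value = value;
        }

        boolean isLeaf() {
            return leftIdx < 0;
        }
    }

    /** Immutable built node: its children are fixed at construction time. */
    static final class Node {
        final double value;
        final Node left;
        final Node right;

        Node(double value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    /** Builds all nodes, assuming the input describes a well-formed binary tree. */
    static List<Node> build(List<FlatNode> flat) {
        Node[] built = new Node[flat.size()];
        Deque<Integer> ready = new ArrayDeque<>();
        // Leaves have no dependencies, so they seed the queue.
        for (int i = 0; i < flat.size(); i++) {
            if (flat.get(i).isLeaf()) {
                ready.add(i);
            }
        }
        while (!ready.isEmpty()) {
            int idx = ready.poll();
            FlatNode f = flat.get(idx);
            built[idx] = f.isLeaf()
                    ? new Node(f.value, null, null)
                    : new Node(f.value, built[f.leftIdx], built[f.rightIdx]);
            // A parent is queued only once both of its children exist.
            if (f.parentIdx >= 0) {
                FlatNode parent = flat.get(f.parentIdx);
                if (built[parent.leftIdx] != null && built[parent.rightIdx] != null) {
                    ready.add(f.parentIdx);
                }
            }
        }
        // Final check: every slot must have been replaced by a built node.
        List<Node> nodes = new ArrayList<>();
        for (int i = 0; i < built.length; i++) {
            if (built[i] == null) {
                throw new IllegalStateException("Node " + i + " was never built");
            }
            nodes.add(built[i]);
        }
        return nodes;
    }
}
```

Because a split node is only constructed after both children exist, no node needs a setter for its children, which is how this ordering keeps the built nodes immutable.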
-
getDepth
public int getDepth()
Probes the tree to find the depth.
- Returns:
The depth of the tree.
-
computeDepth
protected static <T extends Output<T>> int computeDepth(int initialDepth, Node<T> root)
Computes the depth of the tree.
- Type Parameters:
T - The output type of the tree.
- Parameters:
initialDepth - The current depth.
root - The root to probe.
- Returns:
The tree depth.
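A recursive depth computation with this parameter shape might look like the sketch below. The `Node` class is a hypothetical binary node, not Tribuo's `Node<T>`, and the convention that a lone leaf probed from an initial depth of 0 has depth 1 is an assumption of this sketch.

```java
// Hypothetical binary tree node for illustration; not Tribuo's Node<T>.
final class DepthSketch {

    static final class Node {
        final Node left;
        final Node right;

        Node(Node left, Node right) {
            this.left = left;
            this.right = right;
        }

        boolean isLeaf() {
            return left == null && right == null;
        }
    }

    /**
     * Returns the depth of the deepest leaf below root, where initialDepth
     * is the depth already accumulated above root. Assumes every non-leaf
     * node has both children.
     */
    static int computeDepth(int initialDepth, Node root) {
        int depth = initialDepth + 1; // count the node being probed
        if (root.isLeaf()) {
            return depth;
        }
        return Math.max(computeDepth(depth, root.left), computeDepth(depth, root.right));
    }
}
```

Under these assumptions a caller would start the probe with `computeDepth(0, root)`.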
-
predict
Description copied from class: Model
Uses the model to predict the output for a single example.
predict does not mutate the example.
Throws IllegalArgumentException if the example has no features or no feature overlap with the model.
-
getTopFeatures
Description copied from class: Model
Gets the top n features associated with this model.
If the model does not produce per-output feature lists, it returns a map with a single element with key Model.ALL_OUTPUTS.
If the model cannot describe its top features then it returns Collections.emptyMap().
- Specified by:
getTopFeatures in class Model<T extends Output<T>>
- Parameters:
n - The number of features to return. If this value is less than 0, all features should be returned for each class, unless the model cannot score its features.
- Returns:
A map from string outputs to an ordered list of pairs of feature names and weights associated with that feature in the model.
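The contract above (a single map entry when there are no per-output lists, and all features when n is negative) can be sketched as follows. This page does not say how TreeModel scores features, so the scores are a hypothetical input here. The `ALL_OUTPUTS` constant and `topFeatures` method are illustrative stand-ins, and `Map.Entry` stands in for the pair type.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class TopFeaturesSketch {

    // Hypothetical stand-in for the Model.ALL_OUTPUTS key.
    static final String ALL_OUTPUTS = "ALL_OUTPUTS";

    /**
     * Ranks features by descending score and keeps the first n.
     * A negative n keeps every feature.
     */
    static Map<String, List<Map.Entry<String, Double>>> topFeatures(Map<String, Double> scores, int n) {
        List<Map.Entry<String, Double>> ranked = new ArrayList<>();
        for (Map.Entry<String, Double> e : scores.entrySet()) {
            ranked.add(new AbstractMap.SimpleImmutableEntry<>(e.getKey(), e.getValue()));
        }
        // Highest score first; ties are broken by feature name for determinism.
        ranked.sort((a, b) -> {
            int c = Double.compare(b.getValue(), a.getValue());
            return c != 0 ? c : a.getKey().compareTo(b.getKey());
        });
        int limit = (n < 0) ? ranked.size() : Math.min(n, ranked.size());
        Map<String, List<Map.Entry<String, Double>>> result = new HashMap<>();
        // No per-output lists, so everything goes under the single shared key.
        result.put(ALL_OUTPUTS, new ArrayList<>(ranked.subList(0, limit)));
        return result;
    }
}
```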
-
getExcuse
Description copied from class: Model
Generates an excuse for an example.
This attempts to explain a classification result. Generating an excuse may be quite an expensive operation.
This excuse either contains per class information or an entry with key Model.ALL_OUTPUTS.
The optional is empty if the model does not provide excuses.
-
copy
Description copied from class: Model
Copies a model, replacing its provenance and name with the supplied values.
Used to provide the provenance removal functionality.
-
getFeatures
Returns the set of features which are split on in this tree.
- Returns:
The feature names used by this tree.
-
countNodes
public int countNodes(Node<T> root)
Counts the number of nodes in the tree rooted at the supplied node, including that node.
- Parameters:
root - The tree root.
- Returns:
The number of nodes.
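As an illustration of counting a subtree including its root, here is a minimal iterative sketch. The `Node` class is a hypothetical stand-in rather than Tribuo's `Node<T>`, and the explicit stack is just one possible traversal, not necessarily the one the library uses.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical binary tree node for illustration; not Tribuo's Node<T>.
final class CountNodesSketch {

    static final class Node {
        final Node left;
        final Node right;

        Node(Node left, Node right) {
            this.left = left;
            this.right = right;
        }
    }

    /** Counts every node reachable from root, including root itself. */
    static int countNodes(Node root) {
        Deque<Node> pending = new ArrayDeque<>();
        pending.push(root);
        int count = 0;
        while (!pending.isEmpty()) {
            Node current = pending.pop();
            count++;
            // ArrayDeque rejects nulls, so absent children are skipped.
            if (current.left != null) {
                pending.push(current.left);
            }
            if (current.right != null) {
                pending.push(current.right);
            }
        }
        return count;
    }
}
```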
-
toString
-
getRoot
Returns the root node of this tree.
- Returns:
The root node.
-
serialize
public org.tribuo.protos.core.ModelProto serialize()
Description copied from interface: ProtoSerializable
Serializes this object to a protobuf.
-
serializeToNodes
protected List<org.tribuo.common.tree.protos.TreeNodeProto> serializeToNodes(Node<T> root)
Serializes the supplied node tree into a list of protobufs.
- Parameters:
root - The root of the tree to serialize.
- Returns:
The protobuf list.
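Flattening a linked tree into an index-based list can be sketched as below. `Node`, `FlatNode`, and `flatten` are hypothetical names for illustration, and the pre-order index assignment is an assumption of this sketch. The ordering and fields of the real TreeNodeProto list are not described on this page.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for illustration; not Tribuo's Node<T> or TreeNodeProto.
final class FlattenSketch {

    static final class Node {
        final double value;
        final Node left;
        final Node right;

        Node(double value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }

        boolean isLeaf() {
            return left == null && right == null;
        }
    }

    /** Index-based record of one node; -1 marks a missing parent or child. */
    static final class FlatNode {
        final int curIdx;
        final int parentIdx;
        final int leftIdx;
        final int rightIdx;
        final double value;

        FlatNode(int curIdx, int parentIdx, int leftIdx, int rightIdx, double value) {
            this.curIdx = curIdx;
            this.parentIdx = parentIdx;
            this.leftIdx = leftIdx;
            this.rightIdx = rightIdx;
            this.value = value;
        }
    }

    /** Flattens the tree in pre-order, so the root always lands at index 0. */
    static List<FlatNode> flatten(Node root) {
        List<FlatNode> out = new ArrayList<>();
        visit(root, -1, out);
        return out;
    }

    private static int visit(Node node, int parentIdx, List<FlatNode> out) {
        int cur = out.size();
        out.add(null); // reserve this node's slot before visiting its children
        int left = -1;
        int right = -1;
        if (!node.isLeaf()) {
            left = visit(node.left, cur, out);
            right = visit(node.right, cur, out);
        }
        out.set(cur, new FlatNode(cur, parentIdx, left, right, node.value));
        return cur;
    }
}
```

Storing indices instead of object references gives each list entry enough information to reconnect the tree later without nesting messages inside one another.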
-