Package org.tribuo.common.tree
Class AbstractCARTTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.common.tree.AbstractCARTTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, DecisionTreeTrainer<T>, SparseTrainer<T>, Trainer<T>, WeightedExamples
- Direct Known Subclasses:
CARTClassificationTrainer, CARTJointRegressionTrainer, CARTRegressionTrainer
public abstract class AbstractCARTTrainer<T extends Output<T>>
extends Object
implements DecisionTreeTrainer<T>
-
Nested Class Summary
protected static class
- Deprecated.
-
Field Summary
protected float fractionFeaturesInSplit
- Fraction of features to sample per split.
protected int maxDepth
- Maximum tree depth.
static final int MIN_EXAMPLES
- Default minimum weight of examples allowed in a leaf node.
protected float minChildWeight
- Minimum weight of examples allowed in a leaf.
protected float minImpurityDecrease
- Minimum impurity decrease.
protected SplittableRandom rng
protected long seed
protected int trainInvocationCounter
protected boolean useRandomSplitPoints
- Whether to choose split points for features at random.
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
-
Constructor Summary
protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, long seed)
- Subclasses must call postConfig() after calling this superconstructor.
-
Method Summary
float getFractionFeaturesInSplit()
- Returns the feature subsampling rate.
int getInvocationCount()
- The number of times this trainer instance has had its train method invoked.
float getMinImpurityDecrease()
- Returns the minimum decrease in impurity necessary to split a node.
boolean getUseRandomSplitPoints()
- Returns whether to choose split points for features at random.
protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer)
- Makes the initial training node.
void postConfig()
- Used by the OLCUT configuration system, and should not be called by external code.
void setInvocationCount(int invocationCount)
- Set the internal state of the trainer to the provided number of invocations of the train method.
TreeModel<T> train(Dataset<T> examples)
- Trains a sparse predictive model using the examples in the given data set.
TreeModel<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
- Trains a sparse predictive model using the examples in the given data set.
TreeModel<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
- Trains a predictive model using the examples in the given data set.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.provenance.Provenancable
getProvenance
-
Field Details
-
MIN_EXAMPLES
public static final int MIN_EXAMPLES
Default minimum weight of examples allowed in a leaf node.
-
minChildWeight
@Config(description="The minimum weight allowed in a child node.") protected float minChildWeight
Minimum weight of examples allowed in a leaf.
-
maxDepth
@Config(description="The maximum depth of the tree.") protected int maxDepth
Maximum tree depth. Integer.MAX_VALUE indicates the depth is unlimited.
-
minImpurityDecrease
@Config(description="The decrease in impurity needed in order to split the node.") protected float minImpurityDecrease
Minimum impurity decrease. The decrease in impurity needed in order to split the node.
-
fractionFeaturesInSplit
@Config(description="The fraction of features to consider in each split. 1.0f indicates all features are considered.") protected float fractionFeaturesInSplit
Fraction of features to sample per split. 1.0 indicates all features are considered.
-
useRandomSplitPoints
@Config(description="Whether to choose split points for features at random.") protected boolean useRandomSplitPoints
Whether to choose split points for features at random.
-
seed
@Config(description="The RNG seed to use when sampling features in a split.") protected long seed
-
rng
protected SplittableRandom rng
-
trainInvocationCounter
protected int trainInvocationCounter
-
-
Constructor Details
-
AbstractCARTTrainer
protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, long seed)
Subclasses must call postConfig() after calling this superconstructor.
- Parameters:
maxDepth
- The maximum depth of the tree.
minChildWeight
- The minimum child weight allowed.
minImpurityDecrease
- The minimum decrease in impurity necessary to split a node.
fractionFeaturesInSplit
- The fraction of features to consider at each split.
useRandomSplitPoints
- Whether to choose split points for features at random.
seed
- The seed for the feature subsampling RNG.
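The required call order (superconstructor first, then postConfig()) can be sketched with stand-in classes. This sketch does not use Tribuo types: SketchBaseTrainer and SketchTrainer are hypothetical names, and the checks inside postConfig() are an assumption about the kind of validation and setup such a hook might perform, not the library's actual code.

```java
import java.util.SplittableRandom;

// Hypothetical stand-in for an abstract trainer; not a Tribuo class.
abstract class SketchBaseTrainer {
    protected final int maxDepth;
    protected final float fractionFeaturesInSplit;
    protected final long seed;
    protected SplittableRandom rng; // stays null until postConfig() runs

    protected SketchBaseTrainer(int maxDepth, float fractionFeaturesInSplit, long seed) {
        this.maxDepth = maxDepth;
        this.fractionFeaturesInSplit = fractionFeaturesInSplit;
        this.seed = seed;
        // No validation here: the subclass is responsible for calling postConfig().
    }

    // Assumed behaviour: validate the configured fields and build derived state.
    public void postConfig() {
        if (maxDepth < 1) {
            throw new IllegalArgumentException("maxDepth must be >= 1");
        }
        if (fractionFeaturesInSplit <= 0.0f || fractionFeaturesInSplit > 1.0f) {
            throw new IllegalArgumentException("fractionFeaturesInSplit must be in (0, 1]");
        }
        rng = new SplittableRandom(seed);
    }

    public boolean isConfigured() {
        return rng != null;
    }
}

// Hypothetical concrete subclass showing the required call order.
class SketchTrainer extends SketchBaseTrainer {
    SketchTrainer(int maxDepth, float fractionFeaturesInSplit, long seed) {
        super(maxDepth, fractionFeaturesInSplit, seed);
        postConfig(); // must follow the superconstructor call
    }
}

class PostConfigSketch {
    public static void main(String[] args) {
        SketchTrainer trainer = new SketchTrainer(5, 0.5f, 42L);
        System.out.println(trainer.isConfigured());
    }
}
```

If the subclass constructor omitted the postConfig() call, the stand-in trainer would be left with unvalidated parameters and a null rng.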
-
-
Method Details
-
postConfig
public void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
-
getInvocationCount
public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
- The number of train invocations.
-
setInvocationCount
public void setInvocationCount(int invocationCount)
Description copied from interface: Trainer
Set the internal state of the trainer to the provided number of invocations of the train method.
This is used when reproducing a Tribuo-trained model: simulating invocations of the train method sets the state of the RNG to what it was when Tribuo trained the original model. This method should ALWAYS be overridden, and the default method is purely for compatibility.
In a future major release this default implementation will be removed.
- Specified by:
setInvocationCount in interface Trainer<T extends Output<T>>
- Parameters:
invocationCount
- the number of invocations of the train method to simulate
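The relationship between the invocation count and the RNG state can be sketched as follows. This is a minimal illustration under stated assumptions, not Tribuo's implementation: it assumes each train call takes exactly one split() from a java.util.SplittableRandom, and that setting the count rebuilds the RNG from the seed and repeats that many splits. InvocationCountSketch and its train() stand-in are hypothetical.

```java
import java.util.SplittableRandom;

// Hypothetical sketch of invocation-count replay; not a Tribuo class.
class InvocationCountSketch {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount;

    InvocationCountSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
        this.invocationCount = 0;
    }

    // Stand-in for train(): consumes one split of the RNG and draws a value from it.
    synchronized long train() {
        SplittableRandom localRng = rng.split();
        invocationCount++;
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() {
        return invocationCount;
    }

    // Resets the RNG to the seed, then simulates the given number of train calls.
    synchronized void setInvocationCount(int newCount) {
        if (newCount < 0) {
            throw new IllegalArgumentException("The invocation count must be non-negative");
        }
        rng = new SplittableRandom(seed);
        for (int i = 0; i < newCount; i++) {
            rng.split();
        }
        invocationCount = newCount;
    }

    public static void main(String[] args) {
        InvocationCountSketch original = new InvocationCountSketch(12345L);
        original.train();
        original.train();
        long third = original.train();

        // A fresh instance fast-forwarded by two invocations reproduces the third call.
        InvocationCountSketch replay = new InvocationCountSketch(12345L);
        replay.setInvocationCount(2);
        System.out.println(third == replay.train()); // prints true
    }
}
```

Because SplittableRandom is deterministic for a given seed, replaying the same number of splits leaves the stand-in trainer in the same RNG state as the original.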
-
getFractionFeaturesInSplit
public float getFractionFeaturesInSplit()
Description copied from interface: DecisionTreeTrainer
Returns the feature subsampling rate.
- Specified by:
getFractionFeaturesInSplit in interface DecisionTreeTrainer<T extends Output<T>>
- Returns:
- The feature subsampling rate.
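One way a feature subsampling rate could be applied at a split is sketched here. The rounding rule, the lower bound of one feature, and the partial Fisher-Yates shuffle are all assumptions made for illustration; FeatureSubsampleSketch, featuresPerSplit and sampleFeatures are hypothetical names and are not part of Tribuo.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

// Hypothetical sketch of per-split feature subsampling; not a Tribuo class.
class FeatureSubsampleSketch {

    // Number of features to consider, rounded and clamped to [1, numFeatures].
    static int featuresPerSplit(float fraction, int numFeatures) {
        if (numFeatures < 1) {
            throw new IllegalArgumentException("numFeatures must be >= 1");
        }
        int n = Math.round(fraction * numFeatures);
        return Math.max(1, Math.min(n, numFeatures));
    }

    // Draws distinct feature indices using a partial Fisher-Yates shuffle.
    static int[] sampleFeatures(int numFeatures, float fraction, SplittableRandom rng) {
        int k = featuresPerSplit(fraction, numFeatures);
        int[] ids = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            ids[i] = i;
        }
        for (int i = 0; i < k; i++) {
            // Swap position i with a random position in [i, numFeatures).
            int j = i + rng.nextInt(numFeatures - i);
            int tmp = ids[i];
            ids[i] = ids[j];
            ids[j] = tmp;
        }
        return Arrays.copyOf(ids, k);
    }

    public static void main(String[] args) {
        SplittableRandom rng = new SplittableRandom(1L);
        // With a rate of 0.5 and 10 features, 5 distinct indices are drawn.
        System.out.println(Arrays.toString(sampleFeatures(10, 0.5f, rng)));
    }
}
```

A rate of 1.0 in this sketch returns every feature index, which matches the documented meaning that 1.0 considers all features.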
-
getUseRandomSplitPoints
public boolean getUseRandomSplitPoints()
Description copied from interface: DecisionTreeTrainer
Returns whether to choose split points for features at random.
- Specified by:
getUseRandomSplitPoints in interface DecisionTreeTrainer<T extends Output<T>>
- Returns:
- Whether to choose split points for features at random.
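What "choosing a split point at random" could look like for a single feature is sketched below. It assumes the split value is drawn uniformly between the smallest and largest observed value of the feature, which is a common randomized-tree strategy and is not confirmed by this page as Tribuo's behaviour. RandomSplitPointSketch is a hypothetical name.

```java
import java.util.SplittableRandom;

// Hypothetical sketch of a random split point for one feature; not a Tribuo class.
class RandomSplitPointSketch {

    // Returns a value drawn uniformly from [min, max) of the observed feature values.
    static double randomSplitPoint(double[] values, SplittableRandom rng) {
        if (values.length == 0) {
            throw new IllegalArgumentException("At least one feature value is required");
        }
        double min = values[0];
        double max = values[0];
        for (double v : values) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        if (min == max) {
            return min; // constant feature: nothing to choose between
        }
        return rng.nextDouble(min, max);
    }

    public static void main(String[] args) {
        double[] featureValues = {1.0, 4.0, 9.0, 2.5};
        System.out.println(randomSplitPoint(featureValues, new SplittableRandom(3L)));
    }
}
```

The alternative to this, when random split points are disabled, is to evaluate candidate thresholds and keep the one that scores best under the impurity measure.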
-
getMinImpurityDecrease
public float getMinImpurityDecrease()
Description copied from interface: DecisionTreeTrainer
Returns the minimum decrease in impurity necessary to split a node.
- Specified by:
getMinImpurityDecrease in interface DecisionTreeTrainer<T extends Output<T>>
- Returns:
- The minimum decrease in impurity necessary to split a node.
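A worked example of an impurity-decrease threshold may help. The sketch assumes Gini impurity and the usual size-weighted definition of the decrease (parent impurity minus the weighted average of the child impurities), and it uses unweighted class counts for simplicity. How Tribuo scales or computes this quantity internally is not stated on this page, so ImpurityDecreaseSketch and its methods are hypothetical.

```java
// Hypothetical sketch of a minimum impurity decrease check; not a Tribuo class.
class ImpurityDecreaseSketch {

    static int sum(int[] counts) {
        int total = 0;
        for (int c : counts) {
            total += c;
        }
        return total;
    }

    // Gini impurity of a node given its per-class example counts.
    static double gini(int[] counts) {
        int total = sum(counts);
        if (total == 0) {
            return 0.0;
        }
        double sumSquares = 0.0;
        for (int c : counts) {
            double p = (double) c / total;
            sumSquares += p * p;
        }
        return 1.0 - sumSquares;
    }

    // Parent impurity minus the size-weighted average impurity of the two children.
    static double impurityDecrease(int[] parent, int[] left, int[] right) {
        int nLeft = sum(left);
        int nRight = sum(right);
        int n = nLeft + nRight;
        if (n == 0) {
            return 0.0;
        }
        double children = ((double) nLeft / n) * gini(left) + ((double) nRight / n) * gini(right);
        return gini(parent) - children;
    }

    // A split is accepted only if it reduces impurity by at least the threshold.
    static boolean shouldSplit(int[] parent, int[] left, int[] right, float minImpurityDecrease) {
        return impurityDecrease(parent, left, right) >= minImpurityDecrease;
    }

    public static void main(String[] args) {
        int[] parent = {5, 5}; // Gini 0.5
        int[] left = {4, 1};   // Gini 0.32
        int[] right = {1, 4};  // Gini 0.32
        // Decrease is 0.5 - 0.32 = 0.18 (up to floating point error).
        System.out.println(shouldSplit(parent, left, right, 0.1f)); // prints true
        System.out.println(shouldSplit(parent, left, right, 0.2f)); // prints false
    }
}
```

Raising the threshold makes the stand-in tree stop splitting earlier, which is the regularizing effect a minimum impurity decrease is meant to provide.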
-
train
public TreeModel<T> train(Dataset<T> examples)
Description copied from interface: SparseTrainer
Trains a sparse predictive model using the examples in the given data set. -
train
public TreeModel<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SparseTrainer
Trains a sparse predictive model using the examples in the given data set.
- Specified by:
train in interface SparseTrainer<T extends Output<T>>
- Specified by:
train in interface Trainer<T extends Output<T>>
- Parameters:
examples
- the data set containing the examples.
runProvenance
- Training run specific provenance (e.g., fold number).
- Returns:
- a predictive model that can be used to generate predictions for new examples.
-
train
public TreeModel<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
Description copied from interface: SparseTrainer
Trains a predictive model using the examples in the given data set.
- Specified by:
train in interface SparseTrainer<T extends Output<T>>
- Specified by:
train in interface Trainer<T extends Output<T>>
- Parameters:
examples
- the data set containing the examples.
runProvenance
- Training run specific provenance (e.g., fold number).
invocationCount
- The state of the RNG the trainer should be set to before training.
- Returns:
- a predictive model that can be used to generate predictions for new examples.
-
mkTrainingNode
protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer)
Makes the initial training node.
- Parameters:
examples
- The dataset to use.
leafDeterminer
- The leaf determination function.
- Returns:
- The initial training node.
-