Class AbstractCARTTrainer<T extends Output<T>>
java.lang.Object
org.tribuo.common.tree.AbstractCARTTrainer<T>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, DecisionTreeTrainer<T>, SparseTrainer<T>, Trainer<T>, WeightedExamples
- Direct Known Subclasses:
CARTClassificationTrainer, CARTJointRegressionTrainer, CARTRegressionTrainer
public abstract class AbstractCARTTrainer<T extends Output<T>>
extends Object
implements DecisionTreeTrainer<T>
Nested Class Summary
Nested Classes
- One protected static nested class (Deprecated).
Field Summary
Fields
- protected float fractionFeaturesInSplit: Fraction of features to sample per split.
- protected int maxDepth: Maximum tree depth.
- static final int MIN_EXAMPLES: Default minimum weight of examples allowed in a leaf node.
- protected float minChildWeight: Minimum weight of examples allowed in a leaf.
- protected SplittableRandom rng
- protected long seed
- protected int trainInvocationCounter
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
Constructor Summary
Constructors
- protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float fractionFeaturesInSplit, long seed): After calling this superconstructor, subclasses must call postConfig().
Method Summary
- float getFractionFeaturesInSplit(): Returns the feature subsampling rate.
- int getInvocationCount(): The number of times this trainer instance has had its train method invoked.
- protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples)
- void postConfig()
- train(Dataset<T> examples): Trains a sparse predictive model using the examples in the given data set.
- TreeModel<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a sparse predictive model using the examples in the given data set.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.provenance.Provenancable
getProvenance
Field Details
MIN_EXAMPLES
Default minimum weight of examples allowed in a leaf node.
minChildWeight
Minimum weight of examples allowed in a leaf.
maxDepth
Maximum tree depth. Integer.MAX_VALUE indicates the depth is unlimited.
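A minimal sketch of how maxDepth and minChildWeight could act as stopping rules, under stated assumptions: the root is at depth 0, a node may be split only while its depth is below maxDepth, and a candidate split is admissible only if each child keeps at least minChildWeight total example weight. SplitRules and its methods are hypothetical, and the exact comparisons Tribuo uses are not given above.

```java
// Hypothetical helper illustrating two stopping criteria; not Tribuo code.
final class SplitRules {
    private SplitRules() {}

    // Assumption: the root is at depth 0 and a node may be split only while
    // its depth is below maxDepth, so Integer.MAX_VALUE means "no limit".
    static boolean depthAllowsSplit(int depth, int maxDepth) {
        return depth < maxDepth;
    }

    // Assumption: a split is admissible only if both children keep at least
    // minChildWeight of summed example weight.
    static boolean childWeightsAllowSplit(float[] weights, boolean[] goesLeft, float minChildWeight) {
        float left = 0.0f;
        float right = 0.0f;
        for (int i = 0; i < weights.length; i++) {
            if (goesLeft[i]) {
                left += weights[i];
            } else {
                right += weights[i];
            }
        }
        return left >= minChildWeight && right >= minChildWeight;
    }
}
```

Because the threshold is on summed weight rather than on a count, it works with weighted examples; with every weight equal to 1 it reduces to a minimum number of examples per leaf.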
fractionFeaturesInSplit
@Config(description="The fraction of features to consider in each split. 1.0f indicates all features are considered.")
protected float fractionFeaturesInSplit
Fraction of features to sample per split. 1.0 indicates all features are considered.
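Since fractionFeaturesInSplit is a fraction rather than a count, each split has to convert it into a number of features and draw that many at random. The sketch below does this with a partial Fisher-Yates shuffle over a seeded SplittableRandom. FeatureSubsampler is a hypothetical name, and the rounding rule (ceiling, clamped to at least one feature) is an assumption, not necessarily what Tribuo does.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

// Hypothetical helper; Tribuo's own subsampling code may differ.
final class FeatureSubsampler {
    private FeatureSubsampler() {}

    // Assumed rounding rule: ceil(fraction * numFeatures), clamped to [1, numFeatures].
    static int numToSample(float fraction, int numFeatures) {
        int n = (int) Math.ceil(fraction * numFeatures);
        return Math.max(1, Math.min(numFeatures, n));
    }

    // Returns a sorted random subset of feature indices for one split,
    // using a partial Fisher-Yates shuffle so every subset of size k is equally likely.
    static int[] sample(float fraction, int numFeatures, SplittableRandom rng) {
        if (numFeatures <= 0) {
            return new int[0];
        }
        int[] ids = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            ids[i] = i;
        }
        int k = numToSample(fraction, numFeatures);
        for (int i = 0; i < k; i++) {
            int j = i + rng.nextInt(numFeatures - i); // draw from the unshuffled tail
            int tmp = ids[i];
            ids[i] = ids[j];
            ids[j] = tmp;
        }
        int[] chosen = Arrays.copyOf(ids, k);
        Arrays.sort(chosen);
        return chosen;
    }
}
```

With a fraction of 1.0 every feature index is returned, which matches the documented meaning of 1.0; with the same seed, the same subset is drawn.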
seed
rng
trainInvocationCounter
Constructor Details
AbstractCARTTrainer
protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float fractionFeaturesInSplit, long seed)
After calling this superconstructor, subclasses must call postConfig().
- Parameters:
maxDepth - The maximum depth of the tree.
minChildWeight - The minimum child weight allowed.
fractionFeaturesInSplit - The fraction of features to consider at each split.
seed - The seed for the feature subsampling RNG.
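The constructor contract (store the configuration, then have the subclass call postConfig()) can be sketched in plain Java. This is a minimal illustration under assumptions, not Tribuo's implementation: SketchCARTTrainer and SketchClassificationTrainer are hypothetical names, and the idea that postConfig() validates fractionFeaturesInSplit and creates the SplittableRandom from seed is inferred from the rng and seed fields rather than stated here.

```java
import java.util.SplittableRandom;

// Hypothetical stand-in for the abstract base; not Tribuo's class.
abstract class SketchCARTTrainer {
    protected int maxDepth;
    protected float minChildWeight;
    protected float fractionFeaturesInSplit;
    protected long seed;
    protected SplittableRandom rng; // assumed to be created in postConfig()

    protected SketchCARTTrainer(int maxDepth, float minChildWeight,
                                float fractionFeaturesInSplit, long seed) {
        this.maxDepth = maxDepth;
        this.minChildWeight = minChildWeight;
        this.fractionFeaturesInSplit = fractionFeaturesInSplit;
        this.seed = seed;
        // No RNG setup here: subclasses finish their own initialisation
        // and then call postConfig().
    }

    // Assumed behaviour: validate the fraction and create the RNG from the seed.
    public void postConfig() {
        if (fractionFeaturesInSplit <= 0.0f || fractionFeaturesInSplit > 1.0f) {
            throw new IllegalArgumentException("fractionFeaturesInSplit must be in (0, 1]");
        }
        rng = new SplittableRandom(seed);
    }
}

// Hypothetical concrete subclass following the documented contract.
class SketchClassificationTrainer extends SketchCARTTrainer {
    SketchClassificationTrainer(int maxDepth, float minChildWeight,
                                float fractionFeaturesInSplit, long seed) {
        super(maxDepth, minChildWeight, fractionFeaturesInSplit, seed);
        // ... subclass-specific fields would be set here ...
        postConfig(); // required once the superconstructor has returned
    }
}
```

A likely motivation for the split, also an assumption: OLCUT's configuration system can instantiate the trainer, inject the @Config fields, and then call postConfig() itself, so direct construction and configuration-driven construction end up in the same state.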
Method Details
postConfig
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
getInvocationCount
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount in interface Trainer<T extends Output<T>>
- Returns:
The number of train invocations.
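The replicability argument can be made concrete. In the sketch below, each train call consumes the trainer's RNG exactly once (by splitting off a per-run generator) and increments a counter, so a fresh trainer built with the same seed can be fast-forwarded to a given invocation count and then reproduce the next run's random stream. ReplicableTrainer and setInvocationCount are hypothetical names, and whether Tribuo consumes its RNG in exactly this way is an assumption.

```java
import java.util.SplittableRandom;

// Hypothetical trainer skeleton; only the RNG bookkeeping is modelled.
final class ReplicableTrainer {
    private final long seed;
    private SplittableRandom rng;
    private int trainInvocationCounter;

    ReplicableTrainer(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    // Stand-in for train(): derives a per-run generator with one split()
    // and returns the first value that run would use.
    synchronized long train() {
        SplittableRandom runRng = rng.split();
        trainInvocationCounter++;
        return runRng.nextLong();
    }

    int getInvocationCount() {
        return trainInvocationCounter;
    }

    // Hypothetical: rebuild the RNG from the seed and replay `count` splits, so the
    // next train() call sees the same stream it would have seen after `count` runs.
    synchronized void setInvocationCount(int count) {
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split();
        }
        trainInvocationCounter = count;
    }
}
```

Usage: train one instance three times, then build a second instance with the same seed, fast-forward it to a count of 2, and its next train() call reproduces the first instance's third run.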
getFractionFeaturesInSplit
Description copied from interface: DecisionTreeTrainer
Returns the feature subsampling rate.
- Specified by:
getFractionFeaturesInSplit in interface DecisionTreeTrainer<T extends Output<T>>
- Returns:
The feature subsampling rate.
train
Description copied from interface: SparseTrainer
Trains a sparse predictive model using the examples in the given data set.
train
public TreeModel<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SparseTrainer
Trains a sparse predictive model using the examples in the given data set.
- Specified by:
train in interface SparseTrainer<T extends Output<T>>
- Specified by:
train in interface Trainer<T extends Output<T>>
- Parameters:
examples - The data set containing the examples.
runProvenance - Training-run-specific provenance (e.g., fold number).
- Returns:
A predictive model that can be used to generate predictions for new examples.
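The two train overloads differ only in the runProvenance map. A common way to relate them, and an assumption here rather than something stated above, is for the one-argument form to delegate with an empty map. The sketch uses plain String values in place of OLCUT Provenance objects, and the hypothetical SketchModel and OverloadSketch classes, so that it compiles with the JDK alone.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical model type that just records what it was trained with.
final class SketchModel {
    final int numExamples;
    final Map<String, String> provenance;

    SketchModel(int numExamples, Map<String, String> provenance) {
        this.numExamples = numExamples;
        this.provenance = Collections.unmodifiableMap(provenance);
    }
}

// Hypothetical trainer showing only the relationship between the overloads.
final class OverloadSketch {
    // One-argument form: assumed to delegate with no run-specific provenance.
    SketchModel train(List<double[]> examples) {
        return train(examples, Collections.emptyMap());
    }

    // Two-argument form: run-specific entries (e.g. a fold number) are copied
    // into the model's provenance alongside a trainer-level entry.
    SketchModel train(List<double[]> examples, Map<String, String> runProvenance) {
        Map<String, String> prov = new HashMap<>();
        prov.put("trainer", "OverloadSketch");
        prov.putAll(runProvenance);
        // ... tree growing would happen here ...
        return new SketchModel(examples.size(), prov);
    }
}
```

In a cross-validation loop, passing a map such as {"fold" -> "3"} would let each fold's model record which fold produced it, while ordinary callers use the one-argument form.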
mkTrainingNode
protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples)
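mkTrainingNode is the hook that lets one abstract trainer back the classification, regression and joint regression subclasses: the base class drives training and asks the subclass for an output-specific root node. The template-method sketch below is an assumption about that structure; SketchNode, SketchTreeTrainer, MeanRegressionTrainer and MajorityClassTrainer are hypothetical, and a real node would be split recursively rather than just summarising its targets.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical node abstraction: here it only produces a leaf prediction.
interface SketchNode<T> {
    T prediction();
}

// Hypothetical template: the base class owns train(), subclasses supply nodes.
abstract class SketchTreeTrainer<T> {
    protected abstract SketchNode<T> mkTrainingNode(List<T> targets);

    public final T train(List<T> targets) {
        SketchNode<T> root = mkTrainingNode(targets);
        // ... a real trainer would split the root recursively here ...
        return root.prediction();
    }
}

// Regression flavour: the leaf predicts the mean target.
final class MeanRegressionTrainer extends SketchTreeTrainer<Double> {
    @Override
    protected SketchNode<Double> mkTrainingNode(List<Double> targets) {
        double sum = 0.0;
        for (double t : targets) {
            sum += t;
        }
        final double mean = sum / targets.size();
        return () -> mean;
    }
}

// Classification flavour: the leaf predicts the most frequent label
// (ties go to the label that first reached the highest count).
final class MajorityClassTrainer extends SketchTreeTrainer<String> {
    @Override
    protected SketchNode<String> mkTrainingNode(List<String> targets) {
        Map<String, Integer> counts = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (String label : targets) {
            int c = counts.merge(label, 1, Integer::sum);
            if (c > bestCount) {
                best = label;
                bestCount = c;
            }
        }
        final String majority = best;
        return () -> majority;
    }
}
```

The design keeps the tree-growing loop, depth limit and feature subsampling in one place, while each subclass only decides how a node summarises and scores its own kind of output.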