public abstract class AbstractCARTTrainer<T extends Output<T>> extends Object implements DecisionTreeTrainer<T>
Modifier and Type | Class and Description |
---|---|
protected static class | AbstractCARTTrainer.AbstractCARTTrainerProvenance - Deprecated. |
Modifier and Type | Field and Description |
---|---|
protected float | fractionFeaturesInSplit - Fraction of features to sample per split. |
protected int | maxDepth - Maximum tree depth. |
static int | MIN_EXAMPLES - Default minimum weight of examples allowed in a leaf node. |
protected float | minChildWeight - Minimum weight of examples allowed in a leaf. |
protected SplittableRandom | rng |
protected long | seed |
protected int | trainInvocationCounter |

Fields inherited from interface Trainer: DEFAULT_SEED
Modifier | Constructor and Description |
---|---|
protected | AbstractCARTTrainer(int maxDepth, float minChildWeight, float fractionFeaturesInSplit, long seed) - After calling this superconstructor, subclasses must call postConfig(). |
Modifier and Type | Method and Description |
---|---|
float | getFractionFeaturesInSplit() - Returns the feature subsampling rate. |
int | getInvocationCount() - The number of times this trainer instance has had its train method invoked. |
protected abstract AbstractTrainingNode<T> | mkTrainingNode(Dataset<T> examples) |
void | postConfig() |
TreeModel<T> | train(Dataset<T> examples) - Trains a sparse predictive model using the examples in the given data set. |
TreeModel<T> | train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance) - Trains a sparse predictive model using the examples in the given data set. |
public static final int MIN_EXAMPLES
@Config(description="The minimum weight allowed in a child node.") protected float minChildWeight
@Config(description="The maximum depth of the tree.") protected int maxDepth
@Config(description="The fraction of features to consider in each split. 1.0f indicates all features are considered.") protected float fractionFeaturesInSplit
@Config(description="The RNG seed to use when sampling features in a split.") protected long seed
protected SplittableRandom rng
protected int trainInvocationCounter
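The configured fields `maxDepth` and `minChildWeight` bound how far a tree can grow. This page does not spell out the exact checks, so the sketch below shows only one plausible reading: the root sits at depth 0, a node at `maxDepth` becomes a leaf, and a split is kept only when each child's summed example weight reaches `minChildWeight`. `SplitGateSketch` and its methods are hypothetical names, not part of the library.

```java
// Hypothetical helper; it only illustrates how the two limits could gate a split.
final class SplitGateSketch {

    // Assumption: the root is depth 0 and a node at depth >= maxDepth is not split further.
    static boolean depthAllowsSplit(int depth, int maxDepth) {
        return depth < maxDepth;
    }

    // Assumption: a split is kept only if both children carry at least minChildWeight.
    static boolean weightsAllowSplit(float[] leftWeights, float[] rightWeights, float minChildWeight) {
        return sum(leftWeights) >= minChildWeight && sum(rightWeights) >= minChildWeight;
    }

    // Sums the per-example weights that would land in one child.
    private static float sum(float[] weights) {
        float total = 0.0f;
        for (float w : weights) {
            total += w;
        }
        return total;
    }
}
```

With unweighted examples every weight is 1, so under this reading `minChildWeight` acts as a minimum example count per child.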
protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float fractionFeaturesInSplit, long seed)
Parameters:
maxDepth - The maximum depth of the tree.
minChildWeight - The minimum child weight allowed.
fractionFeaturesInSplit - The fraction of features to consider at each split.
seed - The seed for the feature subsampling RNG.
public void postConfig()
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
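The constructor contract (subclasses call `postConfig()` after the superconstructor) can be sketched without the library. The classes below are hypothetical stand-ins. The sketch assumes `postConfig()` is where state derived from the configured fields, such as `rng` from `seed`, gets initialised, so that the same hook serves both programmatic construction and configuration-driven construction.

```java
import java.util.SplittableRandom;

// Hypothetical base class: the constructor stores configuration but does not finish setup.
abstract class ConfiguredTrainerSketch {
    protected final long seed;
    protected SplittableRandom rng; // stays null until postConfig() runs

    protected ConfiguredTrainerSketch(long seed) {
        this.seed = seed;
    }

    // Assumption: derived state is built here from the configured fields.
    public void postConfig() {
        this.rng = new SplittableRandom(seed);
    }
}

// Hypothetical subclass following the documented contract.
final class ConcreteTrainerSketch extends ConfiguredTrainerSketch {
    ConcreteTrainerSketch(long seed) {
        super(seed);
        postConfig(); // required after the superconstructor call
    }

    boolean isReady() {
        return rng != null;
    }
}
```

Omitting the `postConfig()` call in the subclass constructor would leave `rng` unset in this sketch, which is the failure the contract guards against.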
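public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by: getInvocationCount in interface Trainer<T extends Output<T>>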
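The replicability note can be illustrated with a small sketch. It assumes, without confirmation from this page, that each `train` call draws exactly one child generator from `rng` via `split()`. Under that assumption the random stream used by the n-th call can be rebuilt from `seed` and the invocation count alone. `InvocationReplaySketch` and its methods are hypothetical names.

```java
import java.util.SplittableRandom;

// Hypothetical stand-in for a trainer that owns a seeded RNG and counts train calls.
final class InvocationReplaySketch {
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    InvocationReplaySketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    // Assumption: every train call splits off one child RNG and increments the counter.
    long train() {
        SplittableRandom localRng = rng.split();
        trainInvocationCounter++;
        return localRng.nextLong(); // stands in for the random choices made during training
    }

    int getInvocationCount() {
        return trainInvocationCounter;
    }

    // Rebuilds the value the given (1-based) train call produced, starting from the seed.
    static long replay(long seed, int invocation) {
        if (invocation < 1) {
            throw new IllegalArgumentException("invocation must be at least 1");
        }
        SplittableRandom fresh = new SplittableRandom(seed);
        SplittableRandom local = fresh.split();
        for (int i = 1; i < invocation; i++) {
            local = fresh.split();
        }
        return local.nextLong();
    }
}
```

Recording the seed together with the invocation count is therefore enough, in this sketch, to reproduce any single training run without replaying the training itself.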
public float getFractionFeaturesInSplit()
Description copied from interface: DecisionTreeTrainer
Returns the feature subsampling rate.
Specified by: getFractionFeaturesInSplit in interface DecisionTreeTrainer<T extends Output<T>>
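To make the subsampling rate concrete, the sketch below turns a fraction into a per-split feature count and draws that many distinct feature indices. The rounding and clamping rule and the partial Fisher-Yates shuffle are assumptions chosen for illustration. `FeatureSubsampleSketch` is a hypothetical name and this is not necessarily how the trainer samples features. It expects `numFeatures` to be at least 1.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

// Hypothetical helper illustrating per-split feature subsampling.
final class FeatureSubsampleSketch {

    // Assumption: round fraction * numFeatures, then clamp to the range [1, numFeatures].
    static int featuresPerSplit(float fraction, int numFeatures) {
        int n = Math.round(fraction * numFeatures);
        return Math.max(1, Math.min(n, numFeatures));
    }

    // Partial Fisher-Yates shuffle: returns k distinct indices drawn from 0..numFeatures-1.
    static int[] sampleFeatures(int numFeatures, int k, SplittableRandom rng) {
        int[] indices = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            indices[i] = i;
        }
        for (int i = 0; i < k; i++) {
            int j = i + rng.nextInt(numFeatures - i); // pick from the not-yet-chosen tail
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
        return Arrays.copyOf(indices, k);
    }
}
```

A fraction of 1.0f selects every feature, matching the field description; smaller fractions give each split a random subset.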
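public TreeModel<T> train(Dataset<T> examples)
Description copied from interface: SparseTrainer
Trains a sparse predictive model using the examples in the given data set.
public TreeModel<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SparseTrainer
Trains a sparse predictive model using the examples in the given data set.
Specified by: train in interface SparseTrainer<T extends Output<T>>
Specified by: train in interface Trainer<T extends Output<T>>
Parameters:
examples - the data set containing the examples.
runProvenance - Training run specific provenance (e.g., fold number).
protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples)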
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.