public abstract class AbstractCARTTrainer<T extends Output<T>> extends Object implements DecisionTreeTrainer<T>
Modifier and Type | Class and Description |
---|---|
protected static class | AbstractCARTTrainer.AbstractCARTTrainerProvenance: Deprecated. |
Modifier and Type | Field and Description |
---|---|
protected float | fractionFeaturesInSplit: Fraction of features to sample per split. |
protected int | maxDepth: Maximum tree depth. |
static int | MIN_EXAMPLES: Default minimum weight of examples allowed in a leaf node. |
protected float | minChildWeight: Minimum weight of examples allowed in a leaf. |
protected float | minImpurityDecrease: Minimum impurity decrease required to split a node. |
protected SplittableRandom | rng |
protected long | seed |
protected int | trainInvocationCounter |
protected boolean | useRandomSplitPoints: Whether to choose split points for features at random. |
Fields inherited from interface Trainer: DEFAULT_SEED
Modifier | Constructor and Description |
---|---|
protected | AbstractCARTTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, long seed): After calls to this superconstructor, subclasses must call postConfig(). |
Modifier and Type | Method and Description |
---|---|
float | getFractionFeaturesInSplit(): Returns the feature subsampling rate. |
int | getInvocationCount(): The number of times this trainer instance has had its train method invoked. |
float | getMinImpurityDecrease(): Returns the minimum decrease in impurity necessary to split a node. |
boolean | getUseRandomSplitPoints(): Returns whether to choose split points for features at random. |
protected abstract AbstractTrainingNode<T> | mkTrainingNode(Dataset<T> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer) |
void | postConfig(): Used by the OLCUT configuration system and should not be called by external code. |
TreeModel<T> | train(Dataset<T> examples): Trains a sparse predictive model using the examples in the given data set. |
TreeModel<T> | train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a sparse predictive model using the examples in the given data set. |
public static final int MIN_EXAMPLES
@Config(description="The minimum weight allowed in a child node.") protected float minChildWeight
@Config(description="The maximum depth of the tree.") protected int maxDepth
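The maxDepth and minChildWeight fields bound how far a tree can grow. A minimal, dependency-free sketch of how such limits could be checked follows. GrowthLimits and its methods are hypothetical names, and the use of non-strict comparisons is an assumption rather than Tribuo's documented behavior.

```java
// Hypothetical sketch: how maxDepth and minChildWeight might constrain tree growth.
// Class and method names, and the exact inequalities, are assumptions.
final class GrowthLimits {
    private final int maxDepth;
    private final float minChildWeight;

    GrowthLimits(int maxDepth, float minChildWeight) {
        this.maxDepth = maxDepth;
        this.minChildWeight = minChildWeight;
    }

    /** A node at or beyond the maximum depth is not split further. */
    boolean depthExhausted(int depth) {
        return depth >= maxDepth;
    }

    /** A candidate split is admissible only if both children keep enough example weight. */
    boolean splitAllowed(float leftWeight, float rightWeight) {
        return leftWeight >= minChildWeight && rightWeight >= minChildWeight;
    }

    public static void main(String[] args) {
        GrowthLimits limits = new GrowthLimits(3, 5.0f);
        System.out.println(limits.depthExhausted(3));         // true
        System.out.println(limits.splitAllowed(4.0f, 10.0f)); // false: left child too light
        System.out.println(limits.splitAllowed(5.0f, 10.0f)); // true
    }
}
```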
@Config(description="The decrease in impurity needed in order to split the node.") protected float minImpurityDecrease
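The minImpurityDecrease threshold compares the impurity of a node with the impurity of its children. The sketch below uses weighted Gini impurity purely for illustration; which impurity measure a concrete subclass uses, and how Tribuo weights the children, are assumptions here. ImpurityDecrease and its methods are hypothetical.

```java
import java.util.Arrays;

// Illustrative sketch: accept a split only if it reduces weighted Gini impurity
// by at least minImpurityDecrease. Gini and the weighting scheme are assumptions.
final class ImpurityDecrease {

    /** Gini impurity 1 - sum(p_k^2) for a vector of per-class weights. */
    static double gini(double[] classWeights) {
        double total = Arrays.stream(classWeights).sum();
        if (total == 0.0) {
            return 0.0;
        }
        double sumSq = 0.0;
        for (double w : classWeights) {
            double p = w / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /** Parent impurity minus the weight-averaged impurity of the two children. */
    static double decrease(double[] left, double[] right) {
        double[] parent = new double[left.length];
        for (int k = 0; k < left.length; k++) {
            parent[k] = left[k] + right[k];
        }
        double wL = Arrays.stream(left).sum();
        double wR = Arrays.stream(right).sum();
        double w = wL + wR;
        if (w == 0.0) {
            return 0.0; // empty node: nothing to gain
        }
        return gini(parent) - (wL / w) * gini(left) - (wR / w) * gini(right);
    }

    static boolean acceptSplit(double[] left, double[] right, float minImpurityDecrease) {
        return decrease(left, right) >= minImpurityDecrease;
    }

    public static void main(String[] args) {
        // A perfectly separating split of a balanced two-class node.
        double[] left = {4, 0};
        double[] right = {0, 4};
        System.out.println(decrease(left, right));          // 0.5
        System.out.println(acceptSplit(left, right, 0.1f)); // true
    }
}
```

A split that leaves both children with the parent's class mix, such as {2, 2} and {2, 2}, yields a decrease of zero and is rejected for any positive threshold.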
@Config(description="The fraction of features to consider in each split. 1.0f indicates all features are considered.") protected float fractionFeaturesInSplit
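With fractionFeaturesInSplit below 1.0f, only a random subset of features is examined at each split. A minimal sketch of such subsampling with java.util.SplittableRandom follows. The rounding rule (ceiling, clamped to at least one feature) is an assumption, not Tribuo's documented behavior, and FeatureSubsampler is a hypothetical class.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

// Hypothetical sketch of per-split feature subsampling. The rounding rule
// (ceil, clamped to [1, numFeatures]) is an assumption.
final class FeatureSubsampler {

    static int numFeaturesToSample(int numFeatures, float fraction) {
        int n = (int) Math.ceil(numFeatures * (double) fraction);
        return Math.max(1, Math.min(numFeatures, n));
    }

    /** Draws k distinct feature indices using a partial Fisher-Yates shuffle. */
    static int[] sampleFeatures(int numFeatures, float fraction, SplittableRandom rng) {
        int k = numFeaturesToSample(numFeatures, fraction);
        int[] indices = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            indices[i] = i;
        }
        for (int i = 0; i < k; i++) {
            int j = i + rng.nextInt(numFeatures - i); // uniform in [i, numFeatures)
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
        return Arrays.copyOf(indices, k);
    }

    public static void main(String[] args) {
        SplittableRandom rng = new SplittableRandom(12345L);
        int[] chosen = sampleFeatures(8, 0.25f, rng);
        System.out.println(Arrays.toString(chosen)); // two distinct indices in [0, 8)
    }
}
```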
@Config(description="Whether to choose split points for features at random.") protected boolean useRandomSplitPoints
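When useRandomSplitPoints is true, a threshold is chosen at random instead of by scanning every candidate. The sketch below draws it uniformly between the smallest and largest observed value, as in extremely randomized trees. Whether Tribuo draws exactly this way is an assumption, and RandomSplitPoint is a hypothetical class.

```java
import java.util.SplittableRandom;

// Sketch of random split-point selection in the style of extremely randomized trees.
// The uniform draw over [min, max] is an assumption about the exact scheme.
final class RandomSplitPoint {

    static double drawThreshold(double[] featureValues, SplittableRandom rng) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : featureValues) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        if (min == max) {
            return min; // constant feature: no useful split exists
        }
        return min + rng.nextDouble() * (max - min); // lies in [min, max]
    }

    public static void main(String[] args) {
        double[] values = {1.5, 3.0, 2.2, 7.9};
        double threshold = drawThreshold(values, new SplittableRandom(42L));
        System.out.println(threshold); // some value between 1.5 and 7.9
    }
}
```

Drawing one threshold per feature costs a single pass over the values, instead of sorting them and evaluating every candidate split.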
@Config(description="The RNG seed to use when sampling features in a split.") protected long seed
protected SplittableRandom rng
protected int trainInvocationCounter
protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, long seed)
Parameters:
maxDepth - The maximum depth of the tree.
minChildWeight - The minimum child weight allowed.
minImpurityDecrease - The minimum decrease in impurity necessary to split a node.
fractionFeaturesInSplit - The fraction of features to consider at each split.
useRandomSplitPoints - Whether to choose split points for features at random.
seed - The seed for the feature subsampling RNG.
public void postConfig()
Used by the OLCUT configuration system and should not be called by external code.
Specified by: postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
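The constructor must be followed by a call to postConfig(). One plausible reason is that derived state such as the RNG is built in postConfig() rather than in the constructor. A configuration system that injects fields reflectively and then calls postConfig() will then reach the same state as a hand-written constructor. The sketch below illustrates that contract without depending on OLCUT. BaseTrainer and ExampleTrainer are hypothetical names.

```java
import java.util.SplittableRandom;

// Hypothetical, dependency-free sketch of the constructor/postConfig() contract.
// BaseTrainer and ExampleTrainer are illustrative names, not Tribuo classes.
abstract class BaseTrainer {
    protected int maxDepth;
    protected long seed;
    protected SplittableRandom rng; // derived state, built by postConfig()

    /** For a configuration system that sets fields reflectively, then calls postConfig(). */
    protected BaseTrainer() { }

    /** Stores configuration only; derived state is left to postConfig(). */
    protected BaseTrainer(int maxDepth, long seed) {
        this.maxDepth = maxDepth;
        this.seed = seed;
    }

    /** Builds state derived from the configured fields. */
    public void postConfig() {
        this.rng = new SplittableRandom(seed);
    }
}

final class ExampleTrainer extends BaseTrainer {
    ExampleTrainer(int maxDepth, long seed) {
        super(maxDepth, seed);
        postConfig(); // required: the superconstructor does not initialise rng
    }

    public static void main(String[] args) {
        ExampleTrainer trainer = new ExampleTrainer(5, 1L);
        System.out.println(trainer.rng != null); // true
    }
}
```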
public int getInvocationCount()
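Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
Specified by: getInvocationCount in interface Trainer<T extends Output<T>>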
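The replicability described above can be illustrated without Tribuo. The sketch assumes a pattern in which each train call splits a private generator off a shared, seeded SplittableRandom under a lock and then increments the counter. CountingTrainer is a hypothetical class, and this is not claimed to be AbstractCARTTrainer's exact code.

```java
import java.util.SplittableRandom;

// Hypothetical sketch: a seeded SplittableRandom plus an invocation counter, so each
// train() call gets its own reproducible RNG stream. Not Tribuo API.
final class CountingTrainer {
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    CountingTrainer(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stands in for a train call: takes a private RNG for this run and counts the call. */
    long train() {
        SplittableRandom localRng;
        synchronized (this) {
            localRng = rng.split();
            trainInvocationCounter++;
        }
        return localRng.nextLong(); // stands in for all randomness used during training
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    public static void main(String[] args) {
        CountingTrainer a = new CountingTrainer(12345L);
        a.train();
        long secondRun = a.train();

        // Replay: a fresh trainer with the same seed, advanced by the same number of
        // earlier calls, produces the same stream on its next call.
        CountingTrainer b = new CountingTrainer(12345L);
        for (int i = 0; i < a.getInvocationCount() - 1; i++) {
            b.train();
        }
        System.out.println(b.train() == secondRun); // true
        System.out.println(a.getInvocationCount()); // 2
    }
}
```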
public float getFractionFeaturesInSplit()
Description copied from interface: DecisionTreeTrainer
Returns the feature subsampling rate.
Specified by: getFractionFeaturesInSplit in interface DecisionTreeTrainer<T extends Output<T>>
public boolean getUseRandomSplitPoints()
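Description copied from interface: DecisionTreeTrainer
Returns whether to choose split points for features at random.
Specified by: getUseRandomSplitPoints in interface DecisionTreeTrainer<T extends Output<T>>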
public float getMinImpurityDecrease()
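Description copied from interface: DecisionTreeTrainer
Returns the minimum decrease in impurity necessary to split a node.
Specified by: getMinImpurityDecrease in interface DecisionTreeTrainer<T extends Output<T>>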
public TreeModel<T> train(Dataset<T> examples)
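Description copied from interface: SparseTrainer
Trains a sparse predictive model using the examples in the given data set.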
public TreeModel<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Description copied from interface: SparseTrainer
Trains a sparse predictive model using the examples in the given data set.
Specified by: train in interface SparseTrainer<T extends Output<T>>
Specified by: train in interface Trainer<T extends Output<T>>
Parameters:
examples - the data set containing the examples.
runProvenance - Training run specific provenance (e.g., fold number).
protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer)
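mkTrainingNode is the abstract hook that subclasses implement. It suggests a template-method arrangement: the base class owns the growth loop and asks the subclass only for the initial node. A hedged sketch of that arrangement follows. Node, TemplateTrainer, FullBinaryTrainer and the breadth-first expansion are assumptions for illustration. The real method also receives an AbstractTrainingNode.LeafDeterminer, which is omitted here.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

// Hypothetical template-method sketch: the abstract base owns the growth loop and
// delegates initial-node construction to subclasses. Not Tribuo's code.
interface Node {
    int depth();
    List<Node> split(); // empty list means this node stays a leaf
}

abstract class TemplateTrainer<D> {
    protected final int maxDepth;

    protected TemplateTrainer(int maxDepth) {
        this.maxDepth = maxDepth;
    }

    /** Subclasses decide how the initial node is built from the data. */
    protected abstract Node mkTrainingNode(D examples);

    /** Grows the tree breadth-first and returns the number of nodes created. */
    public int train(D examples) {
        Deque<Node> queue = new ArrayDeque<>();
        queue.add(mkTrainingNode(examples));
        int created = 1;
        while (!queue.isEmpty()) {
            Node node = queue.poll();
            if (node.depth() >= maxDepth) {
                continue; // depth limit reached: keep as a leaf
            }
            for (Node child : node.split()) {
                queue.add(child);
                created++;
            }
        }
        return created;
    }
}

final class FullBinaryTrainer extends TemplateTrainer<Integer> {
    FullBinaryTrainer(int maxDepth) {
        super(maxDepth);
    }

    @Override
    protected Node mkTrainingNode(Integer ignored) {
        return new SplittingNode(0);
    }

    /** A toy node that always splits into two children. */
    private static final class SplittingNode implements Node {
        private final int depth;

        SplittingNode(int depth) {
            this.depth = depth;
        }

        @Override
        public int depth() {
            return depth;
        }

        @Override
        public List<Node> split() {
            return List.of(new SplittingNode(depth + 1), new SplittingNode(depth + 1));
        }
    }

    public static void main(String[] args) {
        System.out.println(new FullBinaryTrainer(2).train(0)); // 7
    }
}
```

Keeping the loop in the base class means every subclass shares the same depth handling. Only the node type, and hence the impurity and output logic, varies.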
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.