Class CARTJointRegressionTrainer
java.lang.Object
org.tribuo.common.tree.AbstractCARTTrainer<Regressor>
org.tribuo.regression.rtree.CARTJointRegressionTrainer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, DecisionTreeTrainer<Regressor>, SparseTrainer<Regressor>, Trainer<Regressor>, WeightedExamples
Nested Class Summary
Nested classes/interfaces inherited from class org.tribuo.common.tree.AbstractCARTTrainer
AbstractCARTTrainer.AbstractCARTTrainerProvenance
Field Summary
Fields inherited from class org.tribuo.common.tree.AbstractCARTTrainer
fractionFeaturesInSplit, maxDepth, MIN_EXAMPLES, minChildWeight, minImpurityDecrease, rng, seed, trainInvocationCounter, useRandomSplitPoints
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
Constructor Summary
Constructors
- CARTJointRegressionTrainer(): Creates a CART Trainer.
- CARTJointRegressionTrainer(int maxDepth): Creates a CART Trainer.
- CARTJointRegressionTrainer(int maxDepth, boolean normalize): Creates a CART Trainer.
- CARTJointRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, RegressorImpurity impurity, boolean normalize, long seed): Creates a CART Trainer.
- CARTJointRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, RegressorImpurity impurity, boolean normalize, long seed): Creates a CART Trainer.
Method Summary
- protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer)
- TrainerProvenance getProvenance()
- String toString()
Methods inherited from class org.tribuo.common.tree.AbstractCARTTrainer
getFractionFeaturesInSplit, getInvocationCount, getMinImpurityDecrease, getUseRandomSplitPoints, postConfig, train, train
Constructor Details
CARTJointRegressionTrainer
public CARTJointRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, RegressorImpurity impurity, boolean normalize, long seed)
Creates a CART Trainer.
- Parameters:
  maxDepth - The maximum depth of the tree.
  minChildWeight - The minimum node weight to consider it for a split.
  minImpurityDecrease - The minimum decrease in impurity necessary to split a node.
  fractionFeaturesInSplit - The fraction of features available in each split.
  useRandomSplitPoints - Whether to choose split points for features at random.
  impurity - The impurity function to use to determine split quality.
  normalize - Normalizes the leaves so each leaf's outputs sum to one.
  seed - The seed to use for the RNG.
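To make fractionFeaturesInSplit, useRandomSplitPoints, and seed more concrete, here is a minimal plain-Java sketch. It assumes, since this page does not say so, that the given fraction of features is sampled without replacement for each split and that a random split point is drawn uniformly between a feature's observed minimum and maximum. RandomSplitSketch and its methods are hypothetical names, not Tribuo API.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

/**
 * Hypothetical sketch (not Tribuo code) of feature subsampling and random
 * split-point selection driven by a seeded RNG.
 */
final class RandomSplitSketch {
    private final SplittableRandom rng;

    RandomSplitSketch(long seed) {
        // Same seed => same sequence of sampled features and thresholds.
        this.rng = new SplittableRandom(seed);
    }

    /** Samples max(1, round(fraction * numFeatures)) distinct feature indices, returned sorted. */
    int[] sampleFeatures(int numFeatures, float fraction) {
        int k = Math.min(numFeatures, Math.max(1, Math.round(fraction * numFeatures)));
        int[] idx = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            idx[i] = i;
        }
        // Partial Fisher-Yates shuffle: the first k slots become a uniform sample without replacement.
        for (int i = 0; i < k; i++) {
            int j = i + rng.nextInt(numFeatures - i);
            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;
        }
        int[] chosen = Arrays.copyOf(idx, k);
        Arrays.sort(chosen);
        return chosen;
    }

    /** Draws a threshold uniformly from [min, max) of the observed feature values. */
    double randomSplitPoint(double[] featureValues) {
        double min = Arrays.stream(featureValues).min().getAsDouble();
        double max = Arrays.stream(featureValues).max().getAsDouble();
        if (min == max) {
            return min; // constant feature: nothing to split on
        }
        return rng.nextDouble(min, max);
    }
}
```

Drawing one threshold per feature skips the sort-and-scan of an exact search, which is why randomized split points are cheaper; the trade-off is that the chosen split is usually not the best one available for that feature.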
CARTJointRegressionTrainer
public CARTJointRegressionTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, RegressorImpurity impurity, boolean normalize, long seed)
Creates a CART Trainer. Split points are computed exactly rather than chosen at random.
- Parameters:
  maxDepth - The maximum depth of the tree.
  minChildWeight - The minimum node weight to consider it for a split.
  minImpurityDecrease - The minimum decrease in impurity necessary to split a node.
  fractionFeaturesInSplit - The fraction of features available in each split.
  impurity - The impurity function to use to determine split quality.
  normalize - Normalizes the leaves so each leaf's outputs sum to one.
  seed - The seed to use for the RNG.
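An exact split search on a single feature can be pictured with the minimal sketch below. It assumes that "joint" means one split is scored against all output dimensions at once, and it uses weighted sum of squared errors, summed across outputs, as a stand-in for the configured RegressorImpurity. It also assumes minChildWeight applies to each child's total weight and minImpurityDecrease to the absolute decrease. ExactJointSplitSketch, Split, and bestSplit are hypothetical names, not Tribuo API.

```java
import java.util.Arrays;
import java.util.Comparator;

/**
 * Hypothetical sketch (not Tribuo code) of an exact split search on one feature,
 * scored jointly over all output dimensions.
 */
final class ExactJointSplitSketch {

    /** Chosen threshold (NaN if no split qualifies) and the impurity decrease it achieves. */
    static final class Split {
        final double threshold;
        final double impurityDecrease;

        Split(double threshold, double impurityDecrease) {
            this.threshold = threshold;
            this.impurityDecrease = impurityDecrease;
        }
    }

    /** Weighted sum of squared errors of rows order[from..to), summed over output dimensions. */
    static double jointSse(double[][] y, double[] w, Integer[] order, int from, int to) {
        double total = 0.0;
        for (int d = 0; d < y[0].length; d++) {
            double sw = 0.0;
            double swy = 0.0;
            for (int i = from; i < to; i++) {
                sw += w[order[i]];
                swy += w[order[i]] * y[order[i]][d];
            }
            if (sw == 0.0) {
                continue;
            }
            double mean = swy / sw;
            for (int i = from; i < to; i++) {
                double diff = y[order[i]][d] - mean;
                total += w[order[i]] * diff * diff;
            }
        }
        return total;
    }

    /** Tries a threshold between every pair of adjacent distinct feature values. */
    static Split bestSplit(double[] x, double[][] y, double[] w,
                           double minChildWeight, double minImpurityDecrease) {
        int n = x.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble(i -> x[i]));
        double totalWeight = Arrays.stream(w).sum();
        double parent = jointSse(y, w, order, 0, n);
        double bestThreshold = Double.NaN;
        double bestDecrease = 0.0;
        double leftWeight = 0.0;
        for (int i = 1; i < n; i++) {
            leftWeight += w[order[i - 1]];
            double lo = x[order[i - 1]];
            double hi = x[order[i]];
            if (lo == hi) {
                continue; // equal values cannot be separated by a threshold
            }
            if (leftWeight < minChildWeight || totalWeight - leftWeight < minChildWeight) {
                continue; // one child would be too light
            }
            double decrease = parent - jointSse(y, w, order, 0, i) - jointSse(y, w, order, i, n);
            if (decrease > bestDecrease && decrease >= minImpurityDecrease) {
                bestDecrease = decrease;
                bestThreshold = (lo + hi) / 2.0; // midpoint between adjacent values
            }
        }
        return new Split(bestThreshold, bestDecrease);
    }
}
```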
CARTJointRegressionTrainer
public CARTJointRegressionTrainer()
Creates a CART Trainer. Sets the impurity to the MeanSquaredError, grows a tree with no depth limit using exact split points and all the features, and does not normalize the outputs.
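Because the default impurity is MeanSquaredError, here is a minimal sketch of a weighted mean-squared-error impurity for a node with several outputs. It assumes the per-output impurities are summed into one joint score; averaging them instead would only rescale the score by the number of outputs and would rank candidate splits the same way. JointMseSketch is a hypothetical name, and Tribuo's own MeanSquaredError class may compute or scale the value differently.

```java
/** Hypothetical sketch (not Tribuo's MeanSquaredError) of a joint, weighted MSE impurity. */
final class JointMseSketch {
    /**
     * @param y rows are examples, columns are output dimensions
     * @param w one non-negative weight per example
     * @return the sum over output dimensions of the weighted MSE around that dimension's weighted mean
     */
    static double jointMse(double[][] y, double[] w) {
        int dims = y[0].length;
        double totalWeight = 0.0;
        for (double wi : w) {
            totalWeight += wi;
        }
        double impurity = 0.0;
        for (int d = 0; d < dims; d++) {
            double mean = 0.0;
            for (int r = 0; r < y.length; r++) {
                mean += w[r] * y[r][d];
            }
            mean /= totalWeight;
            double weightedSq = 0.0;
            for (int r = 0; r < y.length; r++) {
                double diff = y[r][d] - mean;
                weightedSq += w[r] * diff * diff;
            }
            impurity += weightedSq / totalWeight; // this dimension's weighted MSE
        }
        return impurity;
    }
}
```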
CARTJointRegressionTrainer
public CARTJointRegressionTrainer(int maxDepth)
Creates a CART Trainer. Sets the impurity to the MeanSquaredError, computes the exact split points using all the features, and does not normalize the outputs.
- Parameters:
  maxDepth - The maximum depth of the tree.
CARTJointRegressionTrainer
public CARTJointRegressionTrainer(int maxDepth, boolean normalize)
Creates a CART Trainer. Sets the impurity to the MeanSquaredError.
- Parameters:
  maxDepth - The maximum depth of the tree.
  normalize - Normalizes the leaves so each leaf has a distribution which sums to 1.0.
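To show what the normalize flag changes, here is a minimal sketch of a leaf's prediction. It assumes a leaf predicts the weighted mean of its training targets in each output dimension (the value that minimizes squared error) and that normalization divides that vector by its sum; both points are assumptions, and the rescaling is only meaningful when the outputs are non-negative, such as proportions. LeafOutputSketch is a hypothetical name.

```java
/** Hypothetical sketch (not Tribuo code) of a leaf's multi-output prediction. */
final class LeafOutputSketch {
    /**
     * @param y         rows are the leaf's training examples, columns are output dimensions
     * @param w         one weight per example
     * @param normalize if true, rescale the mean vector so its entries sum to 1.0
     */
    static double[] leafOutput(double[][] y, double[] w, boolean normalize) {
        int dims = y[0].length;
        double[] mean = new double[dims];
        double totalWeight = 0.0;
        for (int r = 0; r < y.length; r++) {
            totalWeight += w[r];
            for (int d = 0; d < dims; d++) {
                mean[d] += w[r] * y[r][d];
            }
        }
        for (int d = 0; d < dims; d++) {
            mean[d] /= totalWeight; // weighted mean per output dimension
        }
        if (normalize) {
            double sum = 0.0;
            for (double v : mean) {
                sum += v;
            }
            if (sum != 0.0) { // leave an all-zero leaf unchanged
                for (int d = 0; d < dims; d++) {
                    mean[d] /= sum;
                }
            }
        }
        return mean;
    }
}
```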
Method Details
mkTrainingNode
protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer)
- Specified by:
  mkTrainingNode in class AbstractCARTTrainer<Regressor>
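The signature suggests a template-method design: shared training code in AbstractCARTTrainer delegates creation of the starting training node to the subclass and passes the stopping parameters along as a LeafDeterminer. The plain-Java sketch below mimics that shape with hypothetical classes (LeafDeterminerSketch, CartTrainerSketch, JointTrainerSketch); the stopping rule and the parameter values are assumptions for illustration, not Tribuo's actual logic.

```java
/** Hypothetical bundle of stopping criteria (not Tribuo's AbstractTrainingNode.LeafDeterminer). */
final class LeafDeterminerSketch {
    final int maxDepth;
    final float minChildWeight;
    final float minImpurityDecrease;

    LeafDeterminerSketch(int maxDepth, float minChildWeight, float minImpurityDecrease) {
        this.maxDepth = maxDepth;
        this.minChildWeight = minChildWeight;
        this.minImpurityDecrease = minImpurityDecrease;
    }

    /** Assumed rule: stop at the depth limit or when two children of minChildWeight cannot fit. */
    boolean stopBeforeSplit(int depth, double nodeWeight) {
        return depth >= maxDepth || nodeWeight < 2.0 * minChildWeight;
    }
}

/** Shared driver: subclasses only decide how the starting training node is built. */
abstract class CartTrainerSketch<N> {
    private final LeafDeterminerSketch leafDeterminer;

    CartTrainerSketch(int maxDepth, float minChildWeight, float minImpurityDecrease) {
        this.leafDeterminer = new LeafDeterminerSketch(maxDepth, minChildWeight, minImpurityDecrease);
    }

    final N train(double[][] targets) {
        return mkTrainingNode(targets, leafDeterminer);
    }

    protected abstract N mkTrainingNode(double[][] targets, LeafDeterminerSketch leafDeterminer);
}

/** Hypothetical joint trainer: one starting node covers every output dimension. */
final class JointTrainerSketch extends CartTrainerSketch<String> {
    JointTrainerSketch(int maxDepth) {
        super(maxDepth, 5.0f, 0.0f); // illustrative values, not necessarily Tribuo's defaults
    }

    @Override
    protected String mkTrainingNode(double[][] targets, LeafDeterminerSketch leafDeterminer) {
        int outputs = targets[0].length;
        boolean leaf = leafDeterminer.stopBeforeSplit(0, targets.length); // unit weight per example
        return "joint root over " + outputs + " outputs, leaf=" + leaf;
    }
}
```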
toString
public String toString()
getProvenance
public TrainerProvenance getProvenance()