public class CARTJointRegressionTrainer extends AbstractCARTTrainer<Regressor>
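Nested classes/interfaces inherited from class AbstractCARTTrainer: AbstractCARTTrainer.AbstractCARTTrainerProvenance
Fields inherited from class AbstractCARTTrainer: fractionFeaturesInSplit, maxDepth, MIN_EXAMPLES, minChildWeight, rng, seed, trainInvocationCounter
Fields inherited from interface Trainer: DEFAULT_SEED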
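The inherited fractionFeaturesInSplit, rng and seed fields suggest that each split examines only a seeded random subset of the features. The following is a minimal, hypothetical sketch of that idea in plain Java, not Tribuo's code: the class FeatureSubsampleSketch, the method sampleFeatures and the round-to-nearest rule for the subset size are all assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch, not Tribuo's implementation: picks the feature indices
// that one split is allowed to examine, given a fraction and a seeded RNG.
public class FeatureSubsampleSketch {

    // Assumption: the subset size is fraction * numFeatures rounded to the
    // nearest integer, then clamped to the range [1, numFeatures].
    static int[] sampleFeatures(int numFeatures, float fractionFeaturesInSplit, Random rng) {
        if (numFeatures <= 0) {
            throw new IllegalArgumentException("numFeatures must be positive");
        }
        if (fractionFeaturesInSplit <= 0.0f || fractionFeaturesInSplit > 1.0f) {
            throw new IllegalArgumentException("fractionFeaturesInSplit must be in (0, 1]");
        }
        int k = Math.round(fractionFeaturesInSplit * numFeatures);
        k = Math.max(1, Math.min(numFeatures, k));

        // Shuffle all indices with the seeded RNG and keep the first k.
        List<Integer> indices = new ArrayList<>(numFeatures);
        for (int i = 0; i < numFeatures; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, rng);

        int[] chosen = new int[k];
        for (int i = 0; i < k; i++) {
            chosen[i] = indices.get(i);
        }
        Arrays.sort(chosen); // sorted only to make the output easier to read
        return chosen;
    }

    public static void main(String[] args) {
        // Two RNGs created with the same seed yield the same feature subset.
        int[] first = sampleFeatures(10, 0.5f, new Random(12345L));
        int[] second = sampleFeatures(10, 0.5f, new Random(12345L));
        System.out.println(Arrays.toString(first));
        System.out.println(Arrays.equals(first, second)); // prints "true"
    }
}
```

Seeding the RNG makes the chosen subsets reproducible across runs, which is presumably why a seed is exposed alongside the fraction.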
Constructor | Description |
---|---|
CARTJointRegressionTrainer() | Creates a CART Trainer. |
CARTJointRegressionTrainer(int maxDepth) | Creates a CART Trainer. |
CARTJointRegressionTrainer(int maxDepth, boolean normalize) | Creates a CART Trainer. |
CARTJointRegressionTrainer(int maxDepth, float minChildWeight, float fractionFeaturesInSplit, RegressorImpurity impurity, boolean normalize, long seed) | Creates a CART Trainer. |
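Two of the constructors take a normalize flag, documented as normalizing the leaves so each output sums to one. A minimal sketch of that operation on a single leaf follows, assuming the leaf predicts the per-dimension mean of its training targets. LeafNormalizeSketch, leafMean and normalizeLeaf are hypothetical names, and returning a zero-sum leaf unchanged is an assumption rather than documented behavior.

```java
import java.util.Arrays;

// Hypothetical sketch, not Tribuo's implementation: what "normalize the leaves
// so each output sums to one" means for one leaf of a multi-output regression tree.
public class LeafNormalizeSketch {

    // Assumption: a leaf predicts the per-dimension mean of the targets that reach it.
    static double[] leafMean(double[][] targets) {
        if (targets.length == 0) {
            throw new IllegalArgumentException("a leaf needs at least one example");
        }
        int dims = targets[0].length;
        double[] mean = new double[dims];
        for (double[] row : targets) {
            for (int d = 0; d < dims; d++) {
                mean[d] += row[d];
            }
        }
        for (int d = 0; d < dims; d++) {
            mean[d] /= targets.length;
        }
        return mean;
    }

    // Divides every output by the total so the leaf's outputs sum to 1.0.
    // Assumption: a leaf whose outputs sum to zero is returned unchanged.
    static double[] normalizeLeaf(double[] leafOutputs) {
        double sum = 0.0;
        for (double v : leafOutputs) {
            sum += v;
        }
        double[] out = Arrays.copyOf(leafOutputs, leafOutputs.length);
        if (sum == 0.0) {
            return out;
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] targets = {{1.0, 2.0}, {1.0, 4.0}};
        double[] mean = leafMean(targets); // per-dimension means: 1.0 and 3.0
        System.out.println(Arrays.toString(normalizeLeaf(mean))); // prints "[0.25, 0.75]"
    }
}
```

Normalizing only makes sense when the outputs are non-negative quantities such as proportions; the sketch does not check for that.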
Modifier and Type | Method and Description |
---|---|
TrainerProvenance | getProvenance() |
protected AbstractTrainingNode<Regressor> | mkTrainingNode(Dataset<Regressor> examples) |
String | toString() |
Inherited methods: getFractionFeaturesInSplit, getInvocationCount, postConfig, train, train
public CARTJointRegressionTrainer(int maxDepth, float minChildWeight, float fractionFeaturesInSplit, RegressorImpurity impurity, boolean normalize, long seed)
Creates a CART Trainer.
Parameters:
maxDepth - The maximum depth of the tree.
minChildWeight - The minimum node weight to consider it for a split.
fractionFeaturesInSplit - The fraction of features available in each split.
impurity - The impurity function to use to determine split quality.
normalize - Normalize the leaves so each output sums to one.
seed - The seed to use for the RNG.
public CARTJointRegressionTrainer()
Creates a CART Trainer. Uses the MeanSquaredError impurity and does not normalize the outputs.
public CARTJointRegressionTrainer(int maxDepth)
Creates a CART Trainer. Uses the MeanSquaredError impurity and does not normalize the outputs.
Parameters:
maxDepth - The maximum depth of the tree.
public CARTJointRegressionTrainer(int maxDepth, boolean normalize)
Creates a CART Trainer. Uses the MeanSquaredError impurity.
Parameters:
maxDepth - The maximum depth of the tree.
normalize - Normalises the leaves so each leaf has a distribution which sums to 1.0.
protected AbstractTrainingNode<Regressor> mkTrainingNode(Dataset<Regressor> examples)
Specified by:
mkTrainingNode in class AbstractCARTTrainer<Regressor>
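The constructors that do not take an impurity argument score splits with MeanSquaredError. For a joint trainer that grows one tree over every output dimension, one natural reading is that the per-dimension errors are combined into a single split score. The sketch below assumes they are summed and that the two children are weighted by example count; neither detail is confirmed by this page, and JointMseSketch, jointMse and splitImpurity are hypothetical names.

```java
// Hypothetical sketch, not Tribuo's implementation: a joint mean squared error
// impurity over all output dimensions, and the score of a candidate binary split.
public class JointMseSketch {

    // Mean squared deviation from the column mean, summed over output dimensions.
    // Assumption: dimensions are combined by summation and examples are unweighted.
    static double jointMse(double[][] targets) {
        int n = targets.length;
        if (n == 0) {
            return 0.0;
        }
        int dims = targets[0].length;
        double total = 0.0;
        for (int d = 0; d < dims; d++) {
            double mean = 0.0;
            for (double[] row : targets) {
                mean += row[d];
            }
            mean /= n;
            double squared = 0.0;
            for (double[] row : targets) {
                double diff = row[d] - mean;
                squared += diff * diff;
            }
            total += squared / n;
        }
        return total;
    }

    // Child impurities weighted by the number of examples each child receives.
    static double splitImpurity(double[][] left, double[][] right) {
        int n = left.length + right.length;
        if (n == 0) {
            return 0.0;
        }
        return (left.length * jointMse(left) + right.length * jointMse(right)) / n;
    }

    public static void main(String[] args) {
        double[][] parent = {{0, 0}, {2, 2}, {10, 20}, {12, 22}};
        double[][] left = {{0, 0}, {2, 2}};
        double[][] right = {{10, 20}, {12, 22}};
        System.out.println(jointMse(parent));           // prints "127.0"
        System.out.println(splitImpurity(left, right)); // prints "2.0"
    }
}
```

A split whose weighted child impurity is well below the parent's is a good candidate; here separating the two clusters lowers the score from 127.0 to 2.0.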
public TrainerProvenance getProvenance()
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.