Class AbstractCARTTrainer<T extends Output<T>>

java.lang.Object
org.tribuo.common.tree.AbstractCARTTrainer<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, DecisionTreeTrainer<T>, SparseTrainer<T>, Trainer<T>, WeightedExamples
Direct Known Subclasses:
CARTClassificationTrainer, CARTJointRegressionTrainer, CARTRegressionTrainer

public abstract class AbstractCARTTrainer<T extends Output<T>> extends Object implements DecisionTreeTrainer<T>
Base class for Trainers that use an approximation of the CART algorithm to build a decision tree.

See:

 J. Friedman, T. Hastie, & R. Tibshirani.
 "The Elements of Statistical Learning"
 Springer 2001.
 
  • Field Details

    • MIN_EXAMPLES

      public static final int MIN_EXAMPLES
      Default minimum weight of examples allowed in a leaf node.
    • minChildWeight

      @Config(description="The minimum weight allowed in a child node.") protected float minChildWeight
      Minimum weight of examples allowed in a leaf.
    • maxDepth

      @Config(description="The maximum depth of the tree.") protected int maxDepth
      Maximum tree depth. Integer.MAX_VALUE indicates the depth is unlimited.
    • fractionFeaturesInSplit

      @Config(description="The fraction of features to consider in each split. 1.0f indicates all features are considered.") protected float fractionFeaturesInSplit
      Fraction of features to sample per split. 1.0 indicates all features are considered.
    • seed

      @Config(description="The RNG seed to use when sampling features in a split.") protected long seed
    • rng

      protected SplittableRandom rng
    • trainInvocationCounter

      protected int trainInvocationCounter
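
The interaction between fractionFeaturesInSplit, seed and rng can be illustrated with a small standalone sketch. It is not Tribuo's implementation: the class FeatureSubsampleSketch, the method sampleFeatures and the ceil-based rounding rule are all assumptions made for illustration, and only java.util classes are used.

```java
import java.util.Arrays;
import java.util.SplittableRandom;

/** Hypothetical helper, not part of Tribuo: picks the feature indices considered at one split. */
public final class FeatureSubsampleSketch {

    private FeatureSubsampleSketch() {}

    /**
     * Returns a sorted array of feature indices to consider for a split.
     * A fraction of 1.0f keeps every feature; a smaller fraction keeps
     * ceil(fraction * numFeatures) randomly chosen features (assumed rounding rule).
     */
    public static int[] sampleFeatures(int numFeatures, float fraction, SplittableRandom rng) {
        if (fraction <= 0.0f || fraction > 1.0f) {
            throw new IllegalArgumentException("fraction must be in (0, 1], found " + fraction);
        }
        int[] indices = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            indices[i] = i;
        }
        if (fraction == 1.0f) {
            // All features are considered, so no random draws are needed.
            return indices;
        }
        int numToKeep = (int) Math.ceil(fraction * numFeatures);
        // Partial Fisher-Yates shuffle: only the first numToKeep slots are randomised.
        for (int i = 0; i < numToKeep; i++) {
            int j = i + rng.nextInt(numFeatures - i);
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
        int[] chosen = Arrays.copyOf(indices, numToKeep);
        Arrays.sort(chosen);
        return chosen;
    }

    public static void main(String[] args) {
        // The same seed yields the same feature subset on every run.
        SplittableRandom rng = new SplittableRandom(12345L);
        System.out.println(Arrays.toString(sampleFeatures(10, 0.5f, rng)));
        System.out.println(Arrays.toString(sampleFeatures(10, 1.0f, rng)));
    }
}
```

Because the generator is seeded explicitly, two runs with the same seed and the same sequence of calls select the same features, which is the property the seed field exists to provide.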
  • Constructor Details

    • AbstractCARTTrainer

      protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float fractionFeaturesInSplit, long seed)
      After calling this superconstructor, subclasses must call postConfig().
      Parameters:
      maxDepth - The maximum depth of the tree.
      minChildWeight - The minimum child weight allowed.
      fractionFeaturesInSplit - The fraction of features to consider at each split.
      seed - The seed for the feature subsampling RNG.
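
The constructor contract above (call the superconstructor, then postConfig()) can be sketched without the Tribuo classes. Everything in the block is a hypothetical stand-in: BaseTrainerSketch, ConcreteTrainerSketch and isConfigured are invented names, and the assumption that postConfig() validates the fields and creates the RNG is one plausible reading, not something this page states.

```java
import java.util.SplittableRandom;

/** Standalone sketch of the "super(...) then postConfig()" pattern; not Tribuo code. */
public final class PostConfigSketch {

    private PostConfigSketch() {}

    /** Hypothetical stand-in for the abstract base trainer. */
    public abstract static class BaseTrainerSketch {
        protected int maxDepth;
        protected float minChildWeight;
        protected float fractionFeaturesInSplit;
        protected long seed;
        protected SplittableRandom rng;

        protected BaseTrainerSketch(int maxDepth, float minChildWeight,
                                    float fractionFeaturesInSplit, long seed) {
            this.maxDepth = maxDepth;
            this.minChildWeight = minChildWeight;
            this.fractionFeaturesInSplit = fractionFeaturesInSplit;
            this.seed = seed;
            // The RNG is deliberately not created here; postConfig() does that (assumed behaviour).
        }

        /** Validates the configured fields and initialises derived state. */
        public void postConfig() {
            if (fractionFeaturesInSplit <= 0.0f || fractionFeaturesInSplit > 1.0f) {
                throw new IllegalArgumentException("fractionFeaturesInSplit must be in (0, 1]");
            }
            if (minChildWeight <= 0.0f) {
                throw new IllegalArgumentException("minChildWeight must be positive");
            }
            this.rng = new SplittableRandom(seed);
        }

        /** True once postConfig() has run successfully. */
        public boolean isConfigured() {
            return rng != null;
        }
    }

    /** Hypothetical concrete subclass that honours the contract. */
    public static final class ConcreteTrainerSketch extends BaseTrainerSketch {
        public ConcreteTrainerSketch(int maxDepth, float minChildWeight,
                                     float fractionFeaturesInSplit, long seed) {
            super(maxDepth, minChildWeight, fractionFeaturesInSplit, seed);
            postConfig(); // required after the superconstructor call
        }
    }

    public static void main(String[] args) {
        ConcreteTrainerSketch trainer = new ConcreteTrainerSketch(6, 5.0f, 0.5f, 42L);
        System.out.println("configured = " + trainer.isConfigured());
    }
}
```

Keeping validation in postConfig() rather than the constructor means the same checks run whether the object is built programmatically or populated by a configuration system that sets the annotated fields directly.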
  • Method Details

    • postConfig

      public void postConfig()
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
    • getInvocationCount

      public int getInvocationCount()
      Description copied from interface: Trainer
      The number of times this trainer instance has had its train method invoked.

      This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.

      Specified by:
      getInvocationCount in interface Trainer<T extends Output<T>>
      Returns:
      The number of train invocations.
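
How an invocation count supports replicability can be shown with a minimal standalone sketch. It assumes, purely for illustration, that each train call splits a fresh generator off a seeded SplittableRandom; the class InvocationCountSketch and its replay method are hypothetical and not part of Tribuo.

```java
import java.util.SplittableRandom;

/** Hypothetical illustration of replaying an RNG stream from a seed and an invocation count. */
public final class InvocationCountSketch {
    private final SplittableRandom rng;
    private int trainInvocationCounter = 0;

    public InvocationCountSketch(long seed) {
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train: splits a per-call RNG and returns its first draw. */
    public synchronized long train() {
        SplittableRandom localRng = rng.split();
        trainInvocationCounter++;
        return localRng.nextLong();
    }

    /** The number of times train has been called on this instance. */
    public synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    /** Rebuilds an instance whose RNG is in the state reached after invocationCount train calls. */
    public static InvocationCountSketch replay(long seed, int invocationCount) {
        InvocationCountSketch trainer = new InvocationCountSketch(seed);
        for (int i = 0; i < invocationCount; i++) {
            trainer.train();
        }
        return trainer;
    }

    public static void main(String[] args) {
        InvocationCountSketch original = new InvocationCountSketch(7L);
        original.train();
        original.train();
        long third = original.train();

        // A fresh instance advanced by two calls produces the same third result.
        InvocationCountSketch replayed = replay(7L, 2);
        System.out.println(third == replayed.train());
    }
}
```

Recording the seed together with the invocation count is enough to reconstruct the generator state for any given training run in this sketch, without storing the generator itself.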
    • getFractionFeaturesInSplit

      public float getFractionFeaturesInSplit()
      Description copied from interface: DecisionTreeTrainer
      Returns the feature subsampling rate.
      Specified by:
      getFractionFeaturesInSplit in interface DecisionTreeTrainer<T extends Output<T>>
      Returns:
      The feature subsampling rate.
    • train

      public TreeModel<T> train(Dataset<T> examples)
      Description copied from interface: SparseTrainer
      Trains a sparse predictive model using the examples in the given data set.
      Specified by:
      train in interface SparseTrainer<T extends Output<T>>
      Specified by:
      train in interface Trainer<T extends Output<T>>
      Parameters:
      examples - The data set containing the examples.
      Returns:
      A sparse predictive model that can be used to generate predictions for new examples.
    • train

      public TreeModel<T> train(Dataset<T> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
      Description copied from interface: SparseTrainer
      Trains a sparse predictive model using the examples in the given data set.
      Specified by:
      train in interface SparseTrainer<T extends Output<T>>
      Specified by:
      train in interface Trainer<T extends Output<T>>
      Parameters:
      examples - The data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      Returns:
      A predictive model that can be used to generate predictions for new examples.
    • mkTrainingNode

      protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples)