Class AbstractCARTTrainer<T extends Output<T>>

java.lang.Object
org.tribuo.common.tree.AbstractCARTTrainer<T>
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, DecisionTreeTrainer<T>, SparseTrainer<T>, Trainer<T>, WeightedExamples
Direct Known Subclasses:
CARTClassificationTrainer, CARTJointRegressionTrainer, CARTRegressionTrainer

public abstract class AbstractCARTTrainer<T extends Output<T>> extends Object implements DecisionTreeTrainer<T>
Base class for Trainers that use an approximation of the CART algorithm to build a decision tree. A minimal usage sketch is shown after the reference below.

See:

 J. Friedman, T. Hastie, and R. Tibshirani.
 "The Elements of Statistical Learning".
 Springer, 2001.
 
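A minimal usage sketch in Java. It assumes a training Dataset<Label> is already available and uses the CARTClassificationTrainer subclass (listed above) with its default constructor; the class name CARTUsageSketch and the variable trainData are invented for illustration.

import org.tribuo.Dataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.dtree.CARTClassificationTrainer;
import org.tribuo.common.tree.TreeModel;

public class CARTUsageSketch {
    // Trains a classification decision tree using the default CART hyperparameters.
    public static TreeModel<Label> trainTree(Dataset<Label> trainData) {
        CARTClassificationTrainer trainer = new CARTClassificationTrainer();
        return trainer.train(trainData);
    }
}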
  • Field Details

    • MIN_EXAMPLES

      public static final int MIN_EXAMPLES
      Default minimum weight of examples allowed in a leaf node.
    • minChildWeight

      @Config(description="The minimum weight allowed in a child node.") protected float minChildWeight
      Minimum weight of examples allowed in a leaf.
    • maxDepth

      @Config(description="The maximum depth of the tree.") protected int maxDepth
      Maximum tree depth. Integer.MAX_VALUE indicates the depth is unlimited.
    • minImpurityDecrease

      @Config(description="The decrease in impurity needed in order to split the node.") protected float minImpurityDecrease
      Minimum impurity decrease. The decrease in impurity needed in order to split the node.
    • fractionFeaturesInSplit

      @Config(description="The fraction of features to consider in each split. 1.0f indicates all features are considered.") protected float fractionFeaturesInSplit
      Fraction of features to consider at each split. 1.0f indicates all features are considered.
    • useRandomSplitPoints

      @Config(description="Whether to choose split points for features at random.") protected boolean useRandomSplitPoints
      Whether to choose split points for features at random.
    • seed

      @Config(description="The RNG seed to use when sampling features in a split.") protected long seed
    • rng

      protected SplittableRandom rng
      The RNG used when sampling features in a split, initialised from the seed.
    • trainInvocationCounter

      protected int trainInvocationCounter
      The number of times this trainer's train method has been invoked.
  • Constructor Details

    • AbstractCARTTrainer

      protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, long seed)
      After calling this superconstructor, subclasses must call postConfig(), as sketched below.
      Parameters:
      maxDepth - The maximum depth of the tree.
      minChildWeight - The minimum child weight allowed.
      minImpurityDecrease - The minimum decrease in impurity necessary to split a node.
      fractionFeaturesInSplit - The fraction of features to consider at each split.
      useRandomSplitPoints - Whether to choose split points for features at random.
      seed - The seed for the feature subsampling RNG.
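      A hypothetical skeleton illustrating the constructor contract above. The class name MyCARTTrainer is invented for illustration, the mkTrainingNode body is a placeholder rather than a working implementation, and the getProvenance override using TrainerProvenanceImpl reflects the common Tribuo pattern rather than anything documented on this page.

      import org.tribuo.Dataset;
      import org.tribuo.classification.Label;
      import org.tribuo.common.tree.AbstractCARTTrainer;
      import org.tribuo.common.tree.AbstractTrainingNode;
      import org.tribuo.provenance.TrainerProvenance;
      import org.tribuo.provenance.impl.TrainerProvenanceImpl;

      public final class MyCARTTrainer extends AbstractCARTTrainer<Label> {

          public MyCARTTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease,
                               float fractionFeaturesInSplit, boolean useRandomSplitPoints, long seed) {
              super(maxDepth, minChildWeight, minImpurityDecrease,
                    fractionFeaturesInSplit, useRandomSplitPoints, seed);
              postConfig(); // required after the superconstructor, per the note above
          }

          @Override
          protected AbstractTrainingNode<Label> mkTrainingNode(Dataset<Label> examples,
                  AbstractTrainingNode.LeafDeterminer leafDeterminer) {
              // A real subclass builds its output-specific training node here.
              throw new UnsupportedOperationException("Illustrative skeleton only");
          }

          @Override
          public TrainerProvenance getProvenance() {
              return new TrainerProvenanceImpl(this);
          }
      }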
  • Method Details

    • postConfig

      public void postConfig()
      Used by the OLCUT configuration system, and should not be called by external code.
      Specified by:
      postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
    • getInvocationCount

      public int getInvocationCount()
      Description copied from interface: Trainer
      The number of times this trainer instance has had its train method invoked.

      This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.

      Specified by:
      getInvocationCount in interface Trainer<T extends Output<T>>
      Returns:
      The number of train invocations.
    • setInvocationCount

      public void setInvocationCount(int invocationCount)
      Description copied from interface: Trainer
      Set the internal state of the trainer to the provided number of invocations of the train method.

      This is used when reproducing a Tribuo-trained model: the trainer's RNG is restored to the state it was in when the original model was trained, by simulating the corresponding number of invocations of the train method. This method should ALWAYS be overridden; the default implementation exists purely for compatibility.

      In a future major release this default implementation will be removed.

      Specified by:
      setInvocationCount in interface Trainer<T extends Output<T>>
      Parameters:
      invocationCount - the number of invocations of the train method to simulate
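      A hedged sketch of the reproducibility use case described above. It assumes a Dataset<Label> named trainData and uses CARTClassificationTrainer; the recorded count of 3 is invented for illustration (in practice it would come from the original model's provenance).

      import org.tribuo.Dataset;
      import org.tribuo.classification.Label;
      import org.tribuo.classification.dtree.CARTClassificationTrainer;
      import org.tribuo.common.tree.TreeModel;

      public class InvocationCountSketch {
          public static TreeModel<Label> reproduce(Dataset<Label> trainData) {
              CARTClassificationTrainer trainer = new CARTClassificationTrainer();
              trainer.setInvocationCount(3);   // advance the RNG as if train had already run 3 times
              // getInvocationCount() now reports 3
              return trainer.train(trainData); // the 4th invocation, matching the original run
          }
      }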
    • getFractionFeaturesInSplit

      public float getFractionFeaturesInSplit()
      Description copied from interface: DecisionTreeTrainer
      Returns the feature subsampling rate.
      Specified by:
      getFractionFeaturesInSplit in interface DecisionTreeTrainer<T extends Output<T>>
      Returns:
      The feature subsampling rate.
    • getUseRandomSplitPoints

      public boolean getUseRandomSplitPoints()
      Description copied from interface: DecisionTreeTrainer
      Returns whether to choose split points for features at random.
      Specified by:
      getUseRandomSplitPoints in interface DecisionTreeTrainer<T extends Output<T>>
      Returns:
      Whether to choose split points for features at random.
    • getMinImpurityDecrease

      public float getMinImpurityDecrease()
      Description copied from interface: DecisionTreeTrainer
      Returns the minimum decrease in impurity necessary to split a node.
      Specified by:
      getMinImpurityDecrease in interface DecisionTreeTrainer<T extends Output<T>>
      Returns:
      The minimum decrease in impurity necessary to split a node.
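      A short sketch reading the split-related hyperparameters back through the three getters above; the default-constructed CARTClassificationTrainer is an assumption of the example.

      import org.tribuo.classification.dtree.CARTClassificationTrainer;

      public class GetterSketch {
          public static void main(String[] args) {
              CARTClassificationTrainer trainer = new CARTClassificationTrainer();
              float fraction = trainer.getFractionFeaturesInSplit();   // 1.0f means all features are considered
              boolean randomSplits = trainer.getUseRandomSplitPoints();
              float minDecrease = trainer.getMinImpurityDecrease();
              System.out.printf("fraction=%f, randomSplits=%b, minImpurityDecrease=%f%n",
                      fraction, randomSplits, minDecrease);
          }
      }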
    • train

      public TreeModel<T> train(Dataset<T> examples)
      Description copied from interface: SparseTrainer
      Trains a sparse predictive model using the examples in the given data set.
      Specified by:
      train in interface SparseTrainer<T extends Output<T>>
      Specified by:
      train in interface Trainer<T extends Output<T>>
      Parameters:
      examples - The data set containing the examples.
      Returns:
      A sparse predictive model that can be used to generate predictions for new examples.
    • train

      public TreeModel<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
      Description copied from interface: SparseTrainer
      Trains a sparse predictive model using the examples in the given data set.
      Specified by:
      train in interface SparseTrainer<T extends Output<T>>
      Specified by:
      train in interface Trainer<T extends Output<T>>
      Parameters:
      examples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      Returns:
      a predictive model that can be used to generate predictions for new examples.
    • train

      public TreeModel<T> train(Dataset<T> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
      Description copied from interface: SparseTrainer
      Trains a predictive model using the examples in the given data set.
      Specified by:
      train in interface SparseTrainer<T extends Output<T>>
      Specified by:
      train in interface Trainer<T extends Output<T>>
      Parameters:
      examples - the data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      invocationCount - The number of train invocations to set before training, which determines the state of the trainer's RNG.
      Returns:
      a predictive model that can be used to generate predictions for new examples.
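      A sketch of this three-argument overload, assuming a Dataset<Label> named trainData; the empty provenance map and the invocation count of 5 are invented for illustration.

      import java.util.HashMap;
      import java.util.Map;

      import com.oracle.labs.mlrg.olcut.provenance.Provenance;
      import org.tribuo.Dataset;
      import org.tribuo.classification.Label;
      import org.tribuo.classification.dtree.CARTClassificationTrainer;
      import org.tribuo.common.tree.TreeModel;

      public class ProvenanceTrainSketch {
          public static TreeModel<Label> train(Dataset<Label> trainData) {
              CARTClassificationTrainer trainer = new CARTClassificationTrainer();
              Map<String, Provenance> runProvenance = new HashMap<>(); // run-specific provenance, e.g. fold number
              // Sets the trainer to invocation count 5 before training, then trains.
              return trainer.train(trainData, runProvenance, 5);
          }
      }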
    • mkTrainingNode

      protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> examples, AbstractTrainingNode.LeafDeterminer leafDeterminer)
      Makes the initial training node.
      Parameters:
      examples - The dataset to use.
      leafDeterminer - The leaf determination function.
      Returns:
      The initial training node.