Class SLMTrainer

java.lang.Object
org.tribuo.regression.slm.SLMTrainer
All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SparseTrainer<Regressor>, Trainer<Regressor>, WeightedExamples
Direct Known Subclasses:
LARSLassoTrainer, LARSTrainer

public class SLMTrainer extends Object implements SparseTrainer<Regressor>, WeightedExamples
A trainer for a sparse linear regression model. Uses sequential forward selection to construct the model. Can optionally normalize the data first. Each output dimension is trained independently with no shared regularization.
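The selection loop described above can be sketched in plain Java. This is a minimal, hypothetical illustration, not Tribuo's implementation. It assumes a dense `double[][]` design matrix. At each step it adds the unselected feature whose column has the largest absolute correlation with the current residual, then refits the selected weights by ordinary least squares. The names `ForwardSelectionSketch`, `fit`, `fitEachOutput` and `leastSquares` are invented for this example. Passing `maxFeatures = -1` mirrors the constructor's "select all features" convention.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of sequential forward selection; not Tribuo's code.
final class ForwardSelectionSketch {

    /** Greedily selects up to maxFeatures columns of x (rows = examples); -1 selects all. */
    static double[] fit(double[][] x, double[] y, int maxFeatures) {
        int n = x.length, d = x[0].length;
        int limit = maxFeatures < 0 ? d : Math.min(maxFeatures, d);
        List<Integer> active = new ArrayList<>();
        double[] weights = new double[d];
        double[] residual = y.clone();
        while (active.size() < limit) {
            // Pick the inactive feature most correlated (in absolute value) with the residual.
            int best = -1;
            double bestScore = -1.0;
            for (int j = 0; j < d; j++) {
                if (active.contains(j)) continue;
                double corr = 0.0;
                for (int i = 0; i < n; i++) corr += x[i][j] * residual[i];
                if (Math.abs(corr) > bestScore) { best = j; bestScore = Math.abs(corr); }
            }
            // Residual is orthogonal to every remaining feature: they would all get weight 0.
            if (bestScore < 1e-12) break;
            active.add(best);
            // Refit all active weights by least squares, then recompute the residual.
            double[] activeWeights = leastSquares(x, y, active);
            Arrays.fill(weights, 0.0);
            for (int k = 0; k < active.size(); k++) weights[active.get(k)] = activeWeights[k];
            for (int i = 0; i < n; i++) {
                double pred = 0.0;
                for (int j : active) pred += x[i][j] * weights[j];
                residual[i] = y[i] - pred;
            }
        }
        return weights;
    }

    /** Trains each output independently; outputs[k] holds the k-th target for every example. */
    static double[][] fitEachOutput(double[][] x, double[][] outputs, int maxFeatures) {
        double[][] result = new double[outputs.length][];
        for (int k = 0; k < outputs.length; k++) result[k] = fit(x, outputs[k], maxFeatures);
        return result;
    }

    /** Solves (X_A^T X_A) w = X_A^T y over the active columns by Gauss-Jordan elimination. */
    static double[] leastSquares(double[][] x, double[] y, List<Integer> cols) {
        int m = cols.size();
        double[][] a = new double[m][m + 1]; // augmented matrix [X_A^T X_A | X_A^T y]
        for (int p = 0; p < m; p++) {
            for (int q = 0; q < m; q++) {
                for (double[] row : x) a[p][q] += row[cols.get(p)] * row[cols.get(q)];
            }
            for (int i = 0; i < x.length; i++) a[p][m] += x[i][cols.get(p)] * y[i];
        }
        for (int c = 0; c < m; c++) {
            int piv = c; // partial pivoting for numerical stability
            for (int r = c + 1; r < m; r++) if (Math.abs(a[r][c]) > Math.abs(a[piv][c])) piv = r;
            double[] tmp = a[c]; a[c] = a[piv]; a[piv] = tmp;
            for (int r = 0; r < m; r++) {
                if (r == c) continue;
                double f = a[r][c] / a[c][c];
                for (int k = c; k <= m; k++) a[r][k] -= f * a[c][k];
            }
        }
        double[] w = new double[m];
        for (int p = 0; p < m; p++) w[p] = a[p][m] / a[p][p];
        return w;
    }
}
```

`fitEachOutput` runs one independent selection per output column. This matches the statement that output dimensions share no regularization. The normal-equation solve keeps the sketch short, but it assumes the selected columns are linearly independent. A QR or Cholesky solver would be the more robust choice.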
  • Field Details

  • Constructor Details

    • SLMTrainer

      public SLMTrainer(boolean normalize, int maxNumFeatures)
      Constructs a trainer for a sparse linear model using sequential forward selection.
      Parameters:
      normalize - Normalizes the data first (i.e., removes the bias term).
      maxNumFeatures - The maximum number of features to select. Supply -1 to select all features.
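The `normalize` description ties normalization to removing the bias term. One common way this works is sketched below for a single feature. The feature and the target are centered so the least-squares line passes through the origin. Only the slope is fitted, and the bias is recovered afterwards as b = mean(y) - w * mean(x). With several features, the recovery generalizes to b = mean(y) - sum_j w_j * mean(x_j). This is an illustrative assumption, not SLMTrainer's code. The description above does not say whether normalization also rescales features, and `CenteringSketch` is a hypothetical name.

```java
// Hypothetical sketch: centering removes the bias term from a linear fit.
final class CenteringSketch {

    static double mean(double[] v) {
        double sum = 0.0;
        for (double value : v) sum += value;
        return sum / v.length;
    }

    /**
     * Fits y ~ w * x + b for a single feature. Both x and y are centered, the
     * no-intercept problem is solved for w, and b is recovered from the means.
     * Returns {w, b}.
     */
    static double[] fitOneFeature(double[] x, double[] y) {
        double xMean = mean(x), yMean = mean(y);
        double num = 0.0, den = 0.0;
        for (int i = 0; i < x.length; i++) {
            double xc = x[i] - xMean; // centered feature
            double yc = y[i] - yMean; // centered target
            num += xc * yc;
            den += xc * xc;
        }
        double w = num / den;         // least-squares slope through the origin on centered data
        double b = yMean - w * xMean; // bias recovered after the fit
        return new double[] {w, b};
    }
}
```

For example, the points (1, 8), (2, 11), (3, 14) and (4, 17) give w = 3 and b = 5 without ever fitting an intercept column.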
    • SLMTrainer

      public SLMTrainer(boolean normalize)
      Constructs a trainer for a sparse linear model using sequential forward selection.

      Selects all the features.

      Parameters:
      normalize - Normalizes the data first (i.e., removes the bias term).
    • SLMTrainer

      protected SLMTrainer()
      For OLCUT.
  • Method Details

    • newWeights

      protected DenseVector newWeights(org.tribuo.regression.slm.SLMTrainer.SLMState state)
      Computes the new feature weights.

      This implementation returns the ordinary least squares solution for the current state.

      Parameters:
      state - The SLM state to operate on.
      Returns:
      The new feature weights.
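The ordinary least squares step is the standard normal-equation solution. The notation below is assumed for illustration, because the contents of SLMState are not documented here. $\mathcal{A}$ is the set of features selected so far, $X_{\mathcal{A}}$ is the design matrix restricted to those columns, and $y$ is the target for one output dimension.

```latex
\hat{w}_{\mathcal{A}}
  = \operatorname*{arg\,min}_{w} \lVert y - X_{\mathcal{A}} w \rVert_2^2
  = \left( X_{\mathcal{A}}^{\top} X_{\mathcal{A}} \right)^{-1} X_{\mathcal{A}}^{\top} y,
\qquad \hat{w}_j = 0 \ \text{for } j \notin \mathcal{A}.
```

The inverse exists only when the selected columns are linearly independent. In practice, a QR or Cholesky solve is usually preferred to forming the inverse explicitly.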
    • train

      public SparseLinearModel train(Dataset<Regressor> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
      Trains a sparse linear model.
      Specified by:
      train in interface SparseTrainer<Regressor>
      Specified by:
      train in interface Trainer<Regressor>
      Parameters:
      examples - The data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      Returns:
      A trained sparse linear model.
    • train

      public SparseLinearModel train(Dataset<Regressor> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
      Trains a sparse linear model.
      Specified by:
      train in interface SparseTrainer<Regressor>
      Specified by:
      train in interface Trainer<Regressor>
      Parameters:
      examples - The data set containing the examples.
      runProvenance - Training run specific provenance (e.g., fold number).
      invocationCount - The invocation count that the trainer's RNG state should be set to before training.
      Returns:
      A trained sparse linear model.
    • getInvocationCount

      public int getInvocationCount()
      Description copied from interface: Trainer
      The number of times this trainer instance has had its train method invoked.

      This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.

      Specified by:
      getInvocationCount in interface Trainer<Regressor>
      Returns:
      The number of train invocations.
    • setInvocationCount

      public void setInvocationCount(int invocationCount)
      Description copied from interface: Trainer
      Set the internal state of the trainer to the provided number of invocations of the train method.

      This is used when reproducing a Tribuo-trained model. It sets the RNG to the state it was in when Tribuo trained the original model, by simulating invocations of the train method. This method should ALWAYS be overridden; the default implementation exists purely for compatibility.

      In a future major release this default implementation will be removed.

      Specified by:
      setInvocationCount in interface Trainer<Regressor>
      Parameters:
      invocationCount - The number of invocations of the train method to simulate.
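The replay idea can be illustrated with a hypothetical trainer skeleton, which is not Tribuo's code. The sketch assumes the trainer derives one child RNG per train call via `java.util.SplittableRandom.split()`. Under that assumption, replaying the first n splits from the same seed restores the RNG to the state it had after n invocations. SLMTrainer's forward selection, as described above, does not obviously draw random numbers. The sketch only shows the general mechanism that this interface method describes, and `ReplayableTrainerSketch` is an invented name.

```java
import java.util.SplittableRandom;

// Hypothetical trainer skeleton showing how an invocation counter replays an RNG stream.
final class ReplayableTrainerSketch {
    private final long seed;
    private SplittableRandom rng;
    private int invocationCount = 0;

    ReplayableTrainerSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stands in for train(...): each call consumes one split of the RNG and returns a draw from it. */
    long train() {
        SplittableRandom localRng = rng.split();
        invocationCount++;
        return localRng.nextLong(); // placeholder for randomness used during training
    }

    int getInvocationCount() {
        return invocationCount;
    }

    /** Resets the RNG to the seed and replays `count` splits, as if train() had run that many times. */
    void setInvocationCount(int count) {
        if (count < 0) throw new IllegalArgumentException("invocationCount must be non-negative");
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) rng.split();
        invocationCount = count;
    }
}
```

Suppose a fresh trainer has its count set to 2 and is then trained. That call yields the same draw as the third `train()` call on an identically seeded trainer.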
    • getProvenance

      public TrainerProvenance getProvenance()
      Specified by:
      getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
    • toString

      public String toString()
      Overrides:
      toString in class Object