Package org.tribuo.regression.slm
Class SLMTrainer
java.lang.Object
org.tribuo.regression.slm.SLMTrainer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SparseTrainer<Regressor>, Trainer<Regressor>, WeightedExamples
- Direct Known Subclasses:
LARSLassoTrainer, LARSTrainer
A trainer for a sparse linear regression model.
Uses sequential forward selection to construct the model and can optionally normalize the data first. Each output dimension is trained independently, with no shared regularization.
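The selection loop can be sketched in plain Java. The following is a minimal, self-contained sketch, not Tribuo's implementation: it assumes the common greedy variant in which each step adds the not-yet-selected feature whose column has the largest absolute inner product with the current residual, then refits ordinary least squares on all selected features. The class `ForwardSelectionSketch` and its methods are hypothetical, and treating a negative `maxNumFeatures` as "no cap" is an assumption based on the `-1` convention documented for the constructor.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of greedy sequential forward selection (not Tribuo's implementation). */
class ForwardSelectionSketch {

    /** Runs forward selection separately for each output column; nothing is shared between outputs. */
    static List<List<Integer>> selectPerOutput(double[][] x, double[][] ys, int maxNumFeatures) {
        List<List<Integer>> perOutput = new ArrayList<>();
        for (double[] y : ys) {
            perOutput.add(select(x, y, maxNumFeatures));
        }
        return perOutput;
    }

    /** Returns the selected feature indices for one output, in the order they were added. */
    static List<Integer> select(double[][] x, double[] y, int maxNumFeatures) {
        int n = x.length;
        int d = x[0].length;
        // Assumption: a negative cap (e.g. -1) means "select up to all features".
        int limit = maxNumFeatures < 0 ? d : Math.min(maxNumFeatures, d);
        List<Integer> active = new ArrayList<>();
        double[] residual = y.clone();
        while (active.size() < limit) {
            // Pick the inactive feature whose column has the largest |inner product| with the residual.
            int best = -1;
            double bestCorr = 1e-12; // tolerance: stop once no feature explains the residual
            for (int j = 0; j < d; j++) {
                if (active.contains(j)) {
                    continue;
                }
                double c = 0.0;
                for (int i = 0; i < n; i++) {
                    c += x[i][j] * residual[i];
                }
                if (Math.abs(c) > bestCorr) {
                    bestCorr = Math.abs(c);
                    best = j;
                }
            }
            if (best < 0) {
                break;
            }
            active.add(best);
            // Refit ordinary least squares on the active set, then update the residual.
            double[] beta = leastSquares(x, y, active);
            for (int i = 0; i < n; i++) {
                double pred = 0.0;
                for (int k = 0; k < active.size(); k++) {
                    pred += x[i][active.get(k)] * beta[k];
                }
                residual[i] = y[i] - pred;
            }
        }
        return active;
    }

    /** Solves the normal equations restricted to the active columns by Gauss-Jordan elimination. */
    static double[] leastSquares(double[][] x, double[] y, List<Integer> active) {
        int m = active.size();
        double[][] a = new double[m][m + 1]; // augmented matrix [A^T A | A^T y]
        for (int p = 0; p < m; p++) {
            for (int i = 0; i < x.length; i++) {
                for (int q = 0; q < m; q++) {
                    a[p][q] += x[i][active.get(p)] * x[i][active.get(q)];
                }
                a[p][m] += x[i][active.get(p)] * y[i];
            }
        }
        for (int col = 0; col < m; col++) {
            int piv = col; // partial pivoting for numerical stability
            for (int r = col + 1; r < m; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[piv][col])) {
                    piv = r;
                }
            }
            double[] tmp = a[col];
            a[col] = a[piv];
            a[piv] = tmp;
            for (int r = 0; r < m; r++) {
                if (r != col) {
                    double f = a[r][col] / a[col][col];
                    for (int c = col; c <= m; c++) {
                        a[r][c] -= f * a[col][c];
                    }
                }
            }
        }
        double[] beta = new double[m];
        for (int p = 0; p < m; p++) {
            beta[p] = a[p][m] / a[p][p];
        }
        return beta;
    }
}
```

Passing one target array per output dimension to `selectPerOutput` mirrors the independence described above: each output runs its own selection, and nothing is shared between them.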
-
Field Summary
protected int maxNumFeatures
protected boolean normalize
protected int trainInvocationCounter
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
-
Constructor Summary
protected SLMTrainer(): For OLCUT.
SLMTrainer(boolean normalize): Constructs a trainer for a sparse linear model using sequential forward selection.
SLMTrainer(boolean normalize, int maxNumFeatures): Constructs a trainer for a sparse linear model using sequential forward selection.
-
Method Summary
int getInvocationCount(): The number of times this trainer instance has had its train method invoked.
TrainerProvenance getProvenance()
protected org.apache.commons.math3.linear.RealVector newWeights(org.tribuo.regression.slm.SLMTrainer.SLMState state)
void setInvocationCount(int invocationCount): Set the internal state of the trainer to the provided number of invocations of the train method.
String toString()
SparseLinearModel train(Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a sparse linear model.
SparseLinearModel train(Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount): Trains a sparse linear model.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig
Methods inherited from interface org.tribuo.SparseTrainer
train
-
Field Details
-
maxNumFeatures
@Config(description="Maximum number of features to use.") protected int maxNumFeatures
normalize
@Config(description="Normalize the data first.") protected boolean normalize
trainInvocationCounter
protected int trainInvocationCounter
-
-
Constructor Details
-
SLMTrainer
public SLMTrainer(boolean normalize, int maxNumFeatures)
Constructs a trainer for a sparse linear model using sequential forward selection.
- Parameters:
normalize
- Normalizes the data first (i.e., removes the bias term).
maxNumFeatures
- The maximum number of features to select. Supply -1 to select all features.
-
SLMTrainer
public SLMTrainer(boolean normalize)
Constructs a trainer for a sparse linear model using sequential forward selection. Selects all the features.
- Parameters:
normalize
- Normalizes the data first (i.e., removes the bias term).
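The documentation describes normalization only as removing the bias term. A common way to get that effect, assumed here, is to center every feature column and the target on their means, so that a fit through the origin needs no intercept, and to scale each centered column to unit L2 norm; the intercept is recovered after fitting. `NormalizeSketch` and `toRawScale` are hypothetical names, and whether Tribuo scales by the L2 norm or by some other statistic is not stated on this page.

```java
/** Hypothetical sketch of normalizing data so that a no-intercept fit suffices (not Tribuo's code). */
class NormalizeSketch {
    final double[] xMean;
    final double[] xNorm;
    final double yMean;
    final double[][] x; // centered columns, scaled to unit L2 norm
    final double[] y;   // centered target

    NormalizeSketch(double[][] rawX, double[] rawY) {
        int n = rawX.length;
        int d = rawX[0].length;
        xMean = new double[d];
        xNorm = new double[d];
        x = new double[n][d];
        y = new double[n];
        double ySum = 0.0;
        for (double v : rawY) {
            ySum += v;
        }
        yMean = ySum / n;
        for (int i = 0; i < n; i++) {
            y[i] = rawY[i] - yMean;
        }
        for (int j = 0; j < d; j++) {
            double sum = 0.0;
            for (int i = 0; i < n; i++) {
                sum += rawX[i][j];
            }
            xMean[j] = sum / n;
            double sq = 0.0;
            for (int i = 0; i < n; i++) {
                x[i][j] = rawX[i][j] - xMean[j];
                sq += x[i][j] * x[i][j];
            }
            // A constant column centers to all zeros; leave it unscaled to avoid dividing by zero.
            xNorm[j] = sq > 0 ? Math.sqrt(sq) : 1.0;
            for (int i = 0; i < n; i++) {
                x[i][j] /= xNorm[j];
            }
        }
    }

    /** Maps weights fitted on the normalized data back to raw-feature weights; the last slot is the intercept. */
    double[] toRawScale(double[] beta) {
        double[] out = new double[beta.length + 1];
        double intercept = yMean;
        for (int j = 0; j < beta.length; j++) {
            out[j] = beta[j] / xNorm[j];
            intercept -= out[j] * xMean[j];
        }
        out[beta.length] = intercept;
        return out;
    }
}
```

Fitting through the origin on `x` and `y`, then calling `toRawScale`, yields weights and an intercept on the original scale; for data generated by y = 2x + 1 this recovers 2 and 1.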
-
SLMTrainer
protected SLMTrainer()
For OLCUT.
-
-
Method Details
-
newWeights
protected org.apache.commons.math3.linear.RealVector newWeights(org.tribuo.regression.slm.SLMTrainer.SLMState state)
-
train
public SparseLinearModel train(Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Trains a sparse linear model.
-
train
public SparseLinearModel train(Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
Trains a sparse linear model.
- Specified by:
train
in interface SparseTrainer<Regressor>
- Specified by:
train
in interface Trainer<Regressor>
- Parameters:
examples
- The data set containing the examples.
runProvenance
- Training run specific provenance (e.g., fold number).
invocationCount
- The state of the RNG the trainer should be set to before training.
- Returns:
- A trained sparse linear model.
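The `invocationCount` overload can pin the trainer's counter before training. A minimal sketch of that bookkeeping, assuming the convention that the inherited constant `INCREMENT_INVOCATION_COUNT` (taken to be `-1` here) means "do not reset", and that every call advances the counter by one. `CountingTrainerSketch` is hypothetical and returns a descriptive string where a real trainer would return a `SparseLinearModel`.

```java
/** Hypothetical sketch of invocation-count bookkeeping in a trainer (not Tribuo's code). */
class CountingTrainerSketch {
    // Assumed sentinel value, mirroring Trainer.INCREMENT_INVOCATION_COUNT.
    static final int INCREMENT_INVOCATION_COUNT = -1;

    private int trainInvocationCounter = 0;

    /** Trains without resetting the counter. */
    String train(double[][] data) {
        return train(data, INCREMENT_INVOCATION_COUNT);
    }

    /** Optionally resets the counter to invocationCount, then records this invocation. */
    String train(double[][] data, int invocationCount) {
        if (invocationCount != INCREMENT_INVOCATION_COUNT) {
            setInvocationCount(invocationCount);
        }
        int thisInvocation = trainInvocationCounter;
        trainInvocationCounter++;
        // A real trainer would fit and return a model here; the string stands in for it.
        return "model from invocation " + thisInvocation + " on " + data.length + " rows";
    }

    int getInvocationCount() {
        return trainInvocationCounter;
    }

    void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("invocationCount must be non-negative");
        }
        trainInvocationCounter = invocationCount;
    }
}
```

After two plain `train` calls the counter reads 2; a call with `invocationCount = 10` trains as invocation 10 and leaves the counter at 11.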
-
getInvocationCount
public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked. This is used to determine how many times the trainer's RNG has been accessed, to ensure replicability in the random number stream.
- Specified by:
getInvocationCount
in interface Trainer<Regressor>
- Returns:
- The number of train invocations.
-
setInvocationCount
public void setInvocationCount(int invocationCount)
Description copied from interface: Trainer
Set the internal state of the trainer to the provided number of invocations of the train method. This is used when reproducing a Tribuo-trained model: simulating invocations of the train method puts the RNG back into the state it was in when Tribuo trained the original model. This method should ALWAYS be overridden; the default method exists purely for compatibility.
In a future major release this default implementation will be removed.
- Specified by:
setInvocationCount
in interface Trainer<Regressor>
- Parameters:
invocationCount
- The number of invocations of the train method to simulate.
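The replay mechanism described above can be illustrated with a seeded generator. A minimal sketch, assuming a trainer that derives one child generator per `train` call from a seeded `java.util.SplittableRandom`; `setInvocationCount` then rewinds to the seed and performs the same number of splits. The class `ReplayableRngSketch` is hypothetical, and nothing on this page says `SLMTrainer` itself draws random numbers, so for this trainer the counter may be pure bookkeeping.

```java
import java.util.SplittableRandom;

/** Hypothetical sketch of replaying RNG state via an invocation count (not Tribuo's code). */
class ReplayableRngSketch {
    private final long seed;
    private SplittableRandom rng;
    private int invocations = 0;

    ReplayableRngSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stands in for train(): each call consumes one split of the trainer-level RNG. */
    long train() {
        SplittableRandom local = rng.split();
        invocations++;
        return local.nextLong();
    }

    /** Rewinds to the seed and simulates the requested number of train() calls. */
    void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("invocationCount must be non-negative");
        }
        rng = new SplittableRandom(seed);
        for (int i = 0; i < invocationCount; i++) {
            rng.split();
        }
        invocations = invocationCount;
    }

    int getInvocationCount() {
        return invocations;
    }
}
```

A trainer that has been called twice and a fresh trainer set to an invocation count of 2 then return the same value on their next call.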
-
getProvenance
public TrainerProvenance getProvenance()
- Specified by:
getProvenance
in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
-
toString
public String toString()
- Overrides:
toString
in class java.lang.Object
-