Package org.tribuo.regression.slm
Class SLMTrainer
java.lang.Object
org.tribuo.regression.slm.SLMTrainer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, SparseTrainer<Regressor>, Trainer<Regressor>, WeightedExamples
- Direct Known Subclasses:
LARSLassoTrainer, LARSTrainer
A trainer for a sparse linear regression model.
Uses sequential forward selection to construct the model, and can optionally
normalize the data first. Each output dimension is trained independently
with no shared regularization.
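The selection loop described above can be sketched in plain Java. This is a minimal illustration, not Tribuo's implementation: the class ForwardSelectionSketch and its methods are hypothetical, it works on dense arrays for a single output dimension, and it assumes that each step adds the unselected feature most correlated with the current residual and then refits ordinary least squares on the selected features. A negative maxNumFeatures selects all features, mirroring the -1 convention of the constructor.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch class; not part of Tribuo.
class ForwardSelectionSketch {

    /** Solves a * w = b by Gaussian elimination with partial pivoting (modifies its inputs). */
    static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        for (int col = 0; col < n; col++) {
            int pivot = col;
            for (int r = col + 1; r < n; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) pivot = r;
            }
            double[] tmpRow = a[col]; a[col] = a[pivot]; a[pivot] = tmpRow;
            double tmp = b[col]; b[col] = b[pivot]; b[pivot] = tmp;
            for (int r = col + 1; r < n; r++) {
                double factor = a[r][col] / a[col][col];
                for (int c = col; c < n; c++) {
                    a[r][c] -= factor * a[col][c];
                }
                b[r] -= factor * b[col];
            }
        }
        // Back substitution on the upper-triangular system.
        double[] w = new double[n];
        for (int r = n - 1; r >= 0; r--) {
            double sum = b[r];
            for (int c = r + 1; c < n; c++) {
                sum -= a[r][c] * w[c];
            }
            w[r] = sum / a[r][r];
        }
        return w;
    }

    /** Ordinary least squares restricted to the active feature columns. */
    static double[] olsOnActive(double[][] x, double[] y, List<Integer> active) {
        int k = active.size();
        double[][] gram = new double[k][k];
        double[] xty = new double[k];
        for (int i = 0; i < x.length; i++) {
            for (int p = 0; p < k; p++) {
                double xp = x[i][active.get(p)];
                xty[p] += xp * y[i];
                for (int q = 0; q < k; q++) {
                    gram[p][q] += xp * x[i][active.get(q)];
                }
            }
        }
        return solve(gram, xty);
    }

    /** Greedy forward selection; returns one weight per feature, zero if never selected. */
    static double[] fit(double[][] x, double[] y, int maxNumFeatures) {
        int numFeatures = x[0].length;
        int limit = maxNumFeatures < 0 ? numFeatures : Math.min(maxNumFeatures, numFeatures);
        List<Integer> active = new ArrayList<>();
        double[] weights = new double[numFeatures];
        double[] residual = y.clone();
        while (active.size() < limit) {
            // Pick the unselected feature with the largest absolute correlation with the residual.
            int best = -1;
            double bestCorr = 1e-12; // tolerance: stop when nothing is left to explain
            for (int j = 0; j < numFeatures; j++) {
                if (active.contains(j)) continue;
                double corr = 0.0;
                for (int i = 0; i < x.length; i++) {
                    corr += x[i][j] * residual[i];
                }
                if (Math.abs(corr) > bestCorr) {
                    bestCorr = Math.abs(corr);
                    best = j;
                }
            }
            if (best < 0) break;
            active.add(best);
            // Refit all selected features jointly, then recompute the residual.
            double[] w = olsOnActive(x, y, active);
            weights = new double[numFeatures];
            for (int p = 0; p < active.size(); p++) {
                weights[active.get(p)] = w[p];
            }
            for (int i = 0; i < x.length; i++) {
                double prediction = 0.0;
                for (int j = 0; j < numFeatures; j++) {
                    prediction += weights[j] * x[i][j];
                }
                residual[i] = y[i] - prediction;
            }
        }
        return weights;
    }
}
```

For example, with the rows {1, 0, 1}, {2, 1, 0}, {3, 0, -1}, {4, 1, 0} and targets {2, 4, 6, 8}, fit(x, y, 1) selects only the first feature and returns the weights [2.0, 0.0, 0.0].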
-
Field Summary
Modifier and Type, Field, Description
protected int maxNumFeatures
- The maximum number of features to select.
protected boolean normalize
- Should the data be centred first? In most cases this should be true.
protected int trainInvocationCounter
- The number of times train(Dataset<Regressor>, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance>) has been called on this object.
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
-
Constructor Summary
Modifier, Constructor, Description
protected SLMTrainer()
- For OLCUT.
SLMTrainer(boolean normalize)
- Constructs a trainer for a sparse linear model using sequential forward selection.
SLMTrainer(boolean normalize, int maxNumFeatures)
- Constructs a trainer for a sparse linear model using sequential forward selection.
-
Method Summary
Modifier and Type, Method, Description
int getInvocationCount()
- The number of times this trainer instance has had its train method invoked.
protected DenseVector newWeights(org.tribuo.regression.slm.SLMTrainer.SLMState state)
- Computes the new feature weights.
void setInvocationCount(int invocationCount)
- Set the internal state of the trainer to the provided number of invocations of the train method.
String toString()
SparseLinearModel train(Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
- Trains a sparse linear model.
SparseLinearModel train(Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
- Trains a sparse linear model.
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface com.oracle.labs.mlrg.olcut.config.Configurable
postConfig
Methods inherited from interface org.tribuo.SparseTrainer
train
-
Field Details
-
maxNumFeatures
@Config(description="Maximum number of features to use.") protected int maxNumFeatures
The maximum number of features to select.
-
normalize
@Config(description="Normalize the data first.") protected boolean normalize
Should the data be centred first? In most cases this should be true.
-
trainInvocationCounter
protected int trainInvocationCounter
The number of times train(Dataset<Regressor>, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance>) has been called on this object.
-
-
Constructor Details
-
SLMTrainer
public SLMTrainer(boolean normalize, int maxNumFeatures)
Constructs a trainer for a sparse linear model using sequential forward selection.
- Parameters:
normalize
- Normalizes the data first (i.e., removes the bias term).
maxNumFeatures
- The maximum number of features to select. Supply -1 to select all features.
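The remark that normalizing removes the bias term can be illustrated with a small single-feature sketch. It assumes that normalization includes mean-centring both the feature and the output, as the normalize field description suggests; the class CentringSketch and its methods are hypothetical and are not part of Tribuo. After centring, a least squares fit through the origin gives the same slope as a fit with an explicit intercept, and the intercept can be recovered from the means afterwards.

```java
// Hypothetical sketch class; not part of Tribuo.
class CentringSketch {

    static double mean(double[] v) {
        double sum = 0.0;
        for (double d : v) {
            sum += d;
        }
        return sum / v.length;
    }

    /** Returns a copy of v with its mean subtracted from every element. */
    static double[] centre(double[] v) {
        double m = mean(v);
        double[] out = new double[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = v[i] - m;
        }
        return out;
    }

    /** Least squares with an explicit bias term: returns {slope, intercept}. */
    static double[] fitWithBias(double[] x, double[] y) {
        int n = x.length;
        double sx = 0.0, sy = 0.0, sxy = 0.0, sxx = 0.0;
        for (int i = 0; i < n; i++) {
            sx += x[i];
            sy += y[i];
            sxy += x[i] * y[i];
            sxx += x[i] * x[i];
        }
        double slope = (n * sxy - sx * sy) / (n * sxx - sx * sx);
        double intercept = (sy - slope * sx) / n;
        return new double[] {slope, intercept};
    }

    /** Least squares through the origin (no bias term): returns the slope. */
    static double fitNoBias(double[] x, double[] y) {
        double sxy = 0.0, sxx = 0.0;
        for (int i = 0; i < x.length; i++) {
            sxy += x[i] * y[i];
            sxx += x[i] * x[i];
        }
        return sxy / sxx;
    }
}
```

With x = {1, 2, 3, 4} and y = {5, 7, 9, 11}, fitWithBias(x, y) gives slope 2 and intercept 3, fitNoBias(centre(x), centre(y)) gives slope 2, and mean(y) - 2 * mean(x) recovers the intercept 3.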
-
SLMTrainer
public SLMTrainer(boolean normalize)
Constructs a trainer for a sparse linear model using sequential forward selection.
Selects all the features.
- Parameters:
normalize
- Normalizes the data first (i.e., removes the bias term).
-
SLMTrainer
protected SLMTrainer()
For OLCUT.
-
-
Method Details
-
newWeights
protected DenseVector newWeights(org.tribuo.regression.slm.SLMTrainer.SLMState state)
Computes the new feature weights.
In this version it returns the ordinary least squares solution for the current state.
- Parameters:
state
- The SLM state to operate on.
- Returns:
- The new feature weights.
-
train
public SparseLinearModel train(Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Trains a sparse linear model.
-
train
public SparseLinearModel train(Dataset<Regressor> examples, Map<String, com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance, int invocationCount)
Trains a sparse linear model.
- Specified by:
train
in interface SparseTrainer<Regressor>
- Specified by:
train
in interface Trainer<Regressor>
- Parameters:
examples
- The data set containing the examples.
invocationCount
- The state of the RNG the trainer should be set to before training.
runProvenance
- Training run specific provenance (e.g., fold number).
- Returns:
- A trained sparse linear model.
-
getInvocationCount
public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed to ensure replicability in the random number stream.
- Specified by:
getInvocationCount
in interface Trainer<Regressor>
- Returns:
- The number of train invocations.
-
setInvocationCount
public void setInvocationCount(int invocationCount)
Description copied from interface: Trainer
Set the internal state of the trainer to the provided number of invocations of the train method.
This is used when reproducing a Tribuo-trained model: it sets the state of the RNG to what it was when Tribuo trained the original model by simulating invocations of the train method. This method should ALWAYS be overridden, and the default method is purely for compatibility.
In a future major release this default implementation will be removed.
- Specified by:
setInvocationCount
in interface Trainer<Regressor>
- Parameters:
invocationCount
- the number of invocations of the train method to simulate
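The idea of simulating invocations to restore RNG state can be sketched with a toy class. This is an illustration under stated assumptions, not Tribuo's code: InvocationCountSketch is hypothetical, its train method is a stand-in that only draws a number, and it assumes each train call advances a java.util.SplittableRandom by one split. Resetting the RNG from the seed and splitting it invocationCount times then puts it in the same state as after that many real train calls.

```java
import java.util.SplittableRandom;

// Hypothetical toy trainer; not part of Tribuo.
class InvocationCountSketch {
    private final long seed;
    private SplittableRandom rng;
    private int trainInvocationCounter = 0;

    InvocationCountSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for train: takes a per-call RNG and returns a value drawn from it. */
    synchronized long train() {
        SplittableRandom localRng = rng.split();
        trainInvocationCounter++;
        return localRng.nextLong();
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    /** Resets the RNG from the seed, then replays the requested number of splits. */
    synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        rng = new SplittableRandom(seed);
        for (int i = 0; i < invocationCount; i++) {
            rng.split();
        }
        trainInvocationCounter = invocationCount;
    }
}
```

With this sketch, a fresh instance on which setInvocationCount(2) is called returns from its next train call the same value as the third train call on another instance built with the same seed.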
-
getProvenance
- Specified by:
getProvenance
in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
-
toString
-