public class SLMTrainer extends Object implements SparseTrainer<Regressor>, WeightedExamples
Modifier and Type | Field and Description |
---|---|
protected int | maxNumFeatures |
protected boolean | normalize |
protected int | trainInvocationCounter |
Fields inherited from interface Trainer: DEFAULT_SEED
Modifier | Constructor and Description |
---|---|
protected | SLMTrainer(): For OLCUT. |
 | SLMTrainer(boolean normalize): Constructs a trainer for a sparse linear model using sequential forward selection. |
 | SLMTrainer(boolean normalize, int maxNumFeatures): Constructs a trainer for a sparse linear model using sequential forward selection. |
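Sequential forward selection grows the model one feature at a time. The following is a minimal, hypothetical Java illustration of that greedy loop, not Tribuo's implementation: the class `ForwardSelectionSketch`, the selection criterion (largest absolute inner product between a column and the current residual, scaled by the column norm), and the plain normal-equations refit are all assumptions. It follows the documented convention that a feature limit of -1 selects all features, and it fits no intercept, so it expects data that is already centred or naturally passes through the origin.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of greedy sequential forward selection for least squares (not Tribuo code). */
class ForwardSelectionSketch {

    /** Selected feature indices, in selection order, and their fitted weights. */
    static final class Result {
        final List<Integer> active;
        final double[] weights; // weights[k] belongs to feature active.get(k)

        Result(List<Integer> active, double[] weights) {
            this.active = active;
            this.weights = weights;
        }
    }

    /**
     * Repeatedly adds the unused feature whose column best matches the current residual,
     * then refits ordinary least squares on all selected features.
     * maxNumFeatures == -1 selects all features. No intercept is fitted.
     */
    static Result select(double[][] x, double[] y, int maxNumFeatures) {
        int n = x.length;
        int d = x[0].length;
        int limit = (maxNumFeatures == -1) ? d : Math.min(maxNumFeatures, d);
        List<Integer> active = new ArrayList<>();
        double[] weights = new double[0];
        double[] residual = y.clone();
        while (active.size() < limit) {
            int best = -1;
            double bestScore = 1e-12; // ignore features with (numerically) no correlation
            for (int j = 0; j < d; j++) {
                if (active.contains(j)) continue;
                double dot = 0.0;
                double norm = 0.0;
                for (int i = 0; i < n; i++) {
                    dot += x[i][j] * residual[i];
                    norm += x[i][j] * x[i][j];
                }
                if (norm == 0.0) continue; // an all-zero column can never help
                double score = Math.abs(dot) / Math.sqrt(norm);
                if (score > bestScore) {
                    bestScore = score;
                    best = j;
                }
            }
            if (best < 0) break; // residual is orthogonal to every unused feature
            active.add(best);
            weights = leastSquares(x, y, active);
            for (int i = 0; i < n; i++) { // residual = y - X_active * weights
                double pred = 0.0;
                for (int k = 0; k < active.size(); k++) pred += x[i][active.get(k)] * weights[k];
                residual[i] = y[i] - pred;
            }
        }
        return new Result(active, weights);
    }

    /** Solves (A^T A) w = A^T y over the active columns by Gaussian elimination with partial pivoting. */
    static double[] leastSquares(double[][] x, double[] y, List<Integer> active) {
        int m = active.size();
        double[][] a = new double[m][m + 1]; // augmented matrix [A^T A | A^T y]
        for (int p = 0; p < m; p++) {
            int cp = active.get(p);
            for (int q = 0; q < m; q++) {
                int cq = active.get(q);
                double s = 0.0;
                for (double[] row : x) s += row[cp] * row[cq];
                a[p][q] = s;
            }
            double dotY = 0.0;
            for (int i = 0; i < x.length; i++) dotY += x[i][cp] * y[i];
            a[p][m] = dotY;
        }
        for (int col = 0; col < m; col++) {
            int pivot = col;
            for (int r = col + 1; r < m; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) pivot = r;
            }
            double[] tmp = a[col];
            a[col] = a[pivot];
            a[pivot] = tmp;
            for (int r = col + 1; r < m; r++) {
                double f = a[r][col] / a[col][col];
                for (int c = col; c <= m; c++) a[r][c] -= f * a[col][c];
            }
        }
        double[] w = new double[m];
        for (int r = m - 1; r >= 0; r--) { // back substitution
            double s = a[r][m];
            for (int c = r + 1; c < m; c++) s -= a[r][c] * w[c];
            w[r] = s / a[r][r];
        }
        return w;
    }
}
```

Refitting least squares on the whole active set after every addition is what makes this forward stepwise selection rather than a cheaper stagewise update. The sketch also assumes the selected columns are not collinear, since the elimination has no guard against a singular system.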
Modifier and Type | Method and Description |
---|---|
int | getInvocationCount(): The number of times this trainer instance has had its train method invoked. |
TrainerProvenance | getProvenance() |
protected org.apache.commons.math3.linear.RealVector | newWeights(org.tribuo.regression.slm.SLMTrainer.SLMState state) |
String | toString() |
SparseLinearModel | train(Dataset<Regressor> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance): Trains a sparse linear model. |
Methods inherited from class Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from the implemented trainer interfaces: train
@Config(description="Maximum number of features to use.") protected int maxNumFeatures
@Config(description="Normalize the data first.") protected boolean normalize
protected int trainInvocationCounter
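public SLMTrainer(boolean normalize, int maxNumFeatures)
Constructs a trainer for a sparse linear model using sequential forward selection.
Parameters:
normalize - Normalizes the data first (i.e., removes the bias term).
maxNumFeatures - The maximum number of features to select. Supply -1 to select all features.
public SLMTrainer(boolean normalize)
Constructs a trainer for a sparse linear model using sequential forward selection. Selects all the features.
Parameters:
normalize - Normalizes the data first (i.e., removes the bias term).
protected SLMTrainer()
For OLCUT.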
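The `normalize` parameter is documented as removing the bias term. One common way to achieve that, sketched here under assumptions, is to centre every feature column and the target so the fitted hyperplane passes through the origin, fit without an intercept, and then recover the intercept from the means. The class `CentreAndScale` is hypothetical, and scaling each centred column to unit L2 norm is an assumption; Tribuo's normalization may scale differently.

```java
/** Hypothetical sketch: centre and scale data so a linear model needs no bias term (not Tribuo code). */
final class CentreAndScale {
    final double[] xMean;
    final double[] xScale;
    final double yMean;
    final double[][] x; // centred and scaled features
    final double[] y;   // centred targets

    CentreAndScale(double[][] rawX, double[] rawY) {
        int n = rawX.length;
        int d = rawX[0].length;
        xMean = new double[d];
        xScale = new double[d];
        x = new double[n][d];
        y = new double[n];
        double ySum = 0.0;
        for (double v : rawY) ySum += v;
        yMean = ySum / n;
        for (int i = 0; i < n; i++) y[i] = rawY[i] - yMean;
        for (int j = 0; j < d; j++) {
            double sum = 0.0;
            for (int i = 0; i < n; i++) sum += rawX[i][j];
            xMean[j] = sum / n;
            double sq = 0.0;
            for (int i = 0; i < n; i++) {
                double c = rawX[i][j] - xMean[j];
                sq += c * c;
            }
            // Assumed scaling: unit L2 norm per centred column; constant columns keep scale 1.
            xScale[j] = sq > 0.0 ? Math.sqrt(sq) : 1.0;
            for (int i = 0; i < n; i++) x[i][j] = (rawX[i][j] - xMean[j]) / xScale[j];
        }
    }

    /** Maps weights learned on the normalized features back to weights on the raw features. */
    double[] rawWeights(double[] normWeights) {
        double[] w = new double[normWeights.length];
        for (int j = 0; j < w.length; j++) w[j] = normWeights[j] / xScale[j];
        return w;
    }

    /** Recovers the bias that centring removed: b = mean(y) - sum_j w_j * mean(x_j). */
    double intercept(double[] rawWeights) {
        double b = yMean;
        for (int j = 0; j < rawWeights.length; j++) b -= rawWeights[j] * xMean[j];
        return b;
    }
}
```

For example, fitting the one-feature data `x = {1, 2, 3}`, `y = {3, 5, 7}` on the normalized values and mapping back gives a raw weight of 2 and an intercept of 1, matching `y = 2x + 1`.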
protected org.apache.commons.math3.linear.RealVector newWeights(org.tribuo.regression.slm.SLMTrainer.SLMState state)
public SparseLinearModel train(Dataset<Regressor> examples, Map<String,com.oracle.labs.mlrg.olcut.provenance.Provenance> runProvenance)
Trains a sparse linear model.
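public int getInvocationCount()
Description copied from interface: Trainer
The number of times this trainer instance has had its train method invoked.
This is used to determine how many times the trainer's RNG has been accessed, so that the random number stream can be replicated.
Specified by:
getInvocationCount in interface Trainer<Regressor>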
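The invocation count exists so that trainers with randomness can be replayed exactly. SLMTrainer inherits this contract from `Trainer`; the following is a hypothetical, generic illustration of the idea, not Tribuo code. The class `CountingTrainerSketch`, its stand-in `train()` method, and the `replayTo` method are invented for this example. It uses `java.util.SplittableRandom` and splits the trainer-level generator once per call, so the count alone determines the generator's state.

```java
import java.util.SplittableRandom;

/** Hypothetical sketch: counting train calls so a fresh instance can replay the same RNG stream. */
class CountingTrainerSketch {
    private final long seed;
    private SplittableRandom rng;
    private int trainInvocationCounter = 0;

    CountingTrainerSketch(long seed) {
        this.seed = seed;
        this.rng = new SplittableRandom(seed);
    }

    /** Stand-in for a real train method: takes one split of the parent RNG and bumps the counter. */
    synchronized long train() {
        SplittableRandom local = rng.split();
        trainInvocationCounter++;
        return local.nextLong(); // placeholder for the randomness a real training run would consume
    }

    synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    /** Hypothetical helper: resets to the seed and replays `count` splits to reach the same state. */
    synchronized void replayTo(int count) {
        rng = new SplittableRandom(seed);
        for (int i = 0; i < count; i++) {
            rng.split();
        }
        trainInvocationCounter = count;
    }
}
```

Because each call consumes exactly one split of the parent generator, a new instance built with the same seed and advanced to the recorded count produces the same draws on its next call as the original did.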
public TrainerProvenance getProvenance()
Specified by:
getProvenance in interface com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.