Class LinearSGDTrainer
java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<T,U,V,LinearParameters>
org.tribuo.common.sgd.AbstractLinearSGDTrainer<Label,Integer,LinearSGDModel>
org.tribuo.classification.sgd.linear.LinearSGDTrainer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<Label>, WeightedExamples
- Direct Known Subclasses:
LogisticRegressionTrainer
A trainer for a linear classifier using SGD.
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
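To make the one-line description above more concrete, here is a minimal, self-contained sketch of SGD for a binary linear classifier. It is an illustration under stated assumptions, not Tribuo's implementation. The class and method names (HingeSgdSketch, train, score, predict) are hypothetical. Labels are assumed to be -1/+1, and the loss is fixed to hinge loss, whereas LinearSGDTrainer takes its loss from the supplied LabelObjective. A constant learning rate stands in for the StochasticGradientOptimiser, and the last weight acts as a bias feature.

```java
import java.util.Random;

/**
 * Hypothetical sketch: SGD for a binary linear classifier with hinge loss.
 * Not Tribuo's implementation; labels are assumed to be -1 or +1.
 */
final class HingeSgdSketch {
    /** Trains weights w where w[d] is the bias; returns the learned weights. */
    static double[] train(double[][] x, int[] y, int epochs, double learningRate, long seed) {
        int d = x[0].length;
        double[] w = new double[d + 1]; // d feature weights plus one bias weight
        Random rng = new Random(seed);
        int[] order = new int[x.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Fisher-Yates shuffle of the example order before each epoch.
            for (int i = order.length - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i]; order[i] = order[j]; order[j] = tmp;
            }
            // One update per example (minibatch size 1).
            for (int idx : order) {
                double margin = y[idx] * score(w, x[idx]);
                if (margin < 1.0) {
                    // Subgradient of max(0, 1 - y * w.x) is -y * x when the margin is below 1.
                    for (int k = 0; k < d; k++) w[k] += learningRate * y[idx] * x[idx][k];
                    w[d] += learningRate * y[idx];
                }
            }
        }
        return w;
    }

    /** Linear score w.x + bias. */
    static double score(double[] w, double[] xi) {
        double s = w[xi.length]; // bias term
        for (int k = 0; k < xi.length; k++) s += w[k] * xi[k];
        return s;
    }

    /** Predicts +1 for a non-negative score, otherwise -1. */
    static int predict(double[] w, double[] xi) {
        return score(w, xi) >= 0 ? 1 : -1;
    }

    public static void main(String[] args) {
        double[][] x = {{2, 2}, {3, 1}, {-2, -1}, {-1, -3}};
        int[] y = {1, 1, -1, -1};
        double[] w = train(x, y, 20, 0.1, 42L);
        for (int i = 0; i < x.length; i++) {
            System.out.println("predicted " + predict(w, x[i]) + ", expected " + y[i]);
        }
    }
}
```

The example order is reshuffled with a seeded java.util.Random, so re-running with the same seed reproduces the same weights. That mirrors the role the seed parameter plays in the constructors documented on this page.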
Field Summary
Fields inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
addBias, epochs, loggingInterval, minibatchSize, optimiser, rng, seed, shuffle
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
Constructor Summary
Constructor and Description
- LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed): Constructs an SGD trainer for a linear model.
- LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed): Constructs an SGD trainer for a linear model.
- LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed): Constructs an SGD trainer for a linear model.
Method Summary
Modifier and Type, Method and Description
- protected LinearSGDModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, LinearParameters parameters): Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
- protected String getModelClassName(): Returns the class name of the model that's produced by this trainer.
- protected SGDObjective<Integer> getObjective(): Returns the objective used by this trainer.
- protected Integer getTarget(ImmutableOutputInfo<Label> outputInfo, Label output): Extracts the appropriate training time representation from the supplied output.
- String toString()
Methods inherited from class org.tribuo.common.sgd.AbstractLinearSGDTrainer
createParameters, getName
Methods inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
getInvocationCount, getProvenance, postConfig, setInvocationCount, setShuffle, shuffleInPlace, train, train, train
Constructor Details
LinearSGDTrainer
public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Constructs an SGD trainer for a linear model.
- Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
minibatchSize - The size of each minibatch.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
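The constructor parameters above jointly determine the shape of training. The hypothetical MinibatchSchedule class below sketches one plausible interaction, under stated assumptions:

- The examples are reshuffled with a seeded java.util.Random before each epoch.
- The shuffled order is cut into minibatches of minibatchSize, and the last minibatch may be smaller.
- shouldLog treats a loggingInterval of -1 (or any non-positive value) as "never log".

It does not claim to reproduce Tribuo's internal loop. In particular, this page does not specify whether an "iteration" counts examples or minibatches.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical sketch of how epochs, minibatchSize, loggingInterval and seed
 * could interact in an SGD loop. Names are illustrative, not Tribuo's API.
 */
final class MinibatchSchedule {
    /** Returns, in order, the example indices used by every minibatch across all epochs. */
    static List<int[]> batches(int numExamples, int epochs, int minibatchSize, long seed) {
        Random rng = new Random(seed);
        int[] order = new int[numExamples];
        for (int i = 0; i < numExamples; i++) order[i] = i;
        List<int[]> result = new ArrayList<>();
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Reshuffle before each epoch using the seeded generator (Fisher-Yates).
            for (int i = numExamples - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i]; order[i] = order[j]; order[j] = tmp;
            }
            // Cut the shuffled order into consecutive minibatches; the last may be smaller.
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                int[] batch = new int[end - start];
                System.arraycopy(order, start, batch, 0, batch.length);
                result.add(batch);
            }
        }
        return result;
    }

    /** True if the loss would be logged after this 1-based iteration; non-positive intervals disable logging. */
    static boolean shouldLog(int iteration, int loggingInterval) {
        return loggingInterval > 0 && iteration % loggingInterval == 0;
    }

    public static void main(String[] args) {
        List<int[]> plan = batches(10, 2, 4, 1L);
        System.out.println(plan.size() + " minibatches"); // 2 epochs * ceil(10 / 4) = 6
    }
}
```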
LinearSGDTrainer
public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
Constructs an SGD trainer for a linear model. Sets the minibatch size to 1.
- Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
LinearSGDTrainer
public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed)
Constructs an SGD trainer for a linear model. Sets the minibatch size to 1 and the logging interval to 1000.
- Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
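The two shorter constructors only fix defaults for the longest one. A natural way to express that in Java is constructor chaining with this(...). The hypothetical ChainedTrainerConfig class below is a sketch of the pattern, not Tribuo's source. The objective and optimiser arguments are omitted to keep it self-contained. The only defaults it encodes are the ones documented above: minibatch size 1 and logging interval 1000.

```java
/**
 * Hypothetical configuration class showing constructor chaining with the
 * documented defaults. Not Tribuo's LinearSGDTrainer source.
 */
final class ChainedTrainerConfig {
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    /** Full constructor: every value is supplied explicitly. */
    ChainedTrainerConfig(int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    /** Delegates to the full constructor with minibatchSize = 1. */
    ChainedTrainerConfig(int epochs, int loggingInterval, long seed) {
        this(epochs, loggingInterval, 1, seed);
    }

    /** Delegates with minibatchSize = 1 and loggingInterval = 1000. */
    ChainedTrainerConfig(int epochs, long seed) {
        this(epochs, 1000, seed);
    }

    public static void main(String[] args) {
        ChainedTrainerConfig c = new ChainedTrainerConfig(5, 42L);
        System.out.println("minibatch=" + c.minibatchSize + ", logging=" + c.loggingInterval);
    }
}
```

Keeping all field assignments in one constructor means a default can only be changed in one place. The shorter overloads cannot drift out of sync with the full one.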
Method Details
getTarget
protected Integer getTarget(ImmutableOutputInfo<Label> outputInfo, Label output)
Description copied from class: AbstractSGDTrainer
Extracts the appropriate training time representation from the supplied output.
- Specified by:
getTarget in class AbstractSGDTrainer<Label,Integer,LinearSGDModel,LinearParameters>
- Parameters:
outputInfo - The output info to use.
output - The output to extract.
- Returns:
The training time representation of the output.
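For a classifier, getTarget returns an Integer as the training-time representation. This suggests each Label is mapped to an integer index taken from the ImmutableOutputInfo. The hypothetical LabelIdSketch class below illustrates that kind of lookup with plain strings. This page does not say how Tribuo assigns ids or how it handles an unseen label. The insertion-order ids and the exception for unknown labels are therefore assumptions of the sketch.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical stand-in for an output-info lookup: each distinct label gets a
 * dense integer id, and getTarget returns that id. Not Tribuo's ImmutableOutputInfo.
 */
final class LabelIdSketch {
    private final Map<String, Integer> ids = new HashMap<>();

    LabelIdSketch(String... labels) {
        for (String label : labels) {
            // The first occurrence of a label receives the next unused id; repeats are ignored.
            ids.putIfAbsent(label, ids.size());
        }
    }

    /** Returns the integer training-time target for a label; throws if the label is unknown (assumed behaviour). */
    int getTarget(String label) {
        Integer id = ids.get(label);
        if (id == null) throw new IllegalArgumentException("Unknown label: " + label);
        return id;
    }

    /** Number of distinct labels. */
    int size() {
        return ids.size();
    }

    public static void main(String[] args) {
        LabelIdSketch info = new LabelIdSketch("cat", "dog", "cat", "bird");
        System.out.println("dog -> " + info.getTarget("dog"));
    }
}
```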
getObjective
protected SGDObjective<Integer> getObjective()
Description copied from class: AbstractSGDTrainer
Returns the objective used by this trainer.
- Specified by:
getObjective in class AbstractSGDTrainer<Label,Integer,LinearSGDModel,LinearParameters>
- Returns:
The SGDObjective used by this trainer.
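getObjective returns an SGDObjective<Integer>, which is presumably a loss defined over the integer target produced by getTarget. The sketch below takes one concrete, assumed choice: softmax cross-entropy, often called log-multiclass loss, over a vector of class scores. It also computes the gradient with respect to those scores. SoftmaxObjectiveSketch and its methods are hypothetical names and do not reproduce Tribuo's objective classes.

```java
/**
 * Hypothetical sketch of a softmax cross-entropy (log-multiclass) objective
 * over an integer target. Not Tribuo's implementation.
 */
final class SoftmaxObjectiveSketch {
    /** Loss = -log softmax(scores)[target] = logSumExp(scores) - scores[target]. */
    static double loss(double[] scores, int target) {
        return logSumExp(scores) - scores[target];
    }

    /** Gradient of the loss with respect to the scores: softmax(scores) minus the one-hot target. */
    static double[] gradient(double[] scores, int target) {
        double lse = logSumExp(scores);
        double[] grad = new double[scores.length];
        for (int k = 0; k < scores.length; k++) {
            grad[k] = Math.exp(scores[k] - lse); // softmax probability of class k
        }
        grad[target] -= 1.0;
        return grad;
    }

    /** Numerically stable log(sum(exp(scores))), subtracting the max before exponentiating. */
    static double logSumExp(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) max = Math.max(max, s);
        double sum = 0.0;
        for (double s : scores) sum += Math.exp(s - max);
        return max + Math.log(sum);
    }

    public static void main(String[] args) {
        double[] scores = {0.0, 0.0, 0.0};
        System.out.println("loss = " + loss(scores, 0)); // equals ln(3) for uniform scores
    }
}
```

For a linear model with scores W·x, the gradient with respect to W is the outer product of this score gradient with the feature vector x. An SGD step would scale that product by the learning rate, or hand it to an optimiser.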
createModel
protected LinearSGDModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, LinearParameters parameters)
Description copied from class: AbstractSGDTrainer
Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
- Specified by:
createModel in class AbstractSGDTrainer<Label,Integer,LinearSGDModel,LinearParameters>
- Parameters:
name - The model name.
provenance - The model provenance.
featureMap - The feature map.
outputInfo - The output info.
parameters - The model parameters.
- Returns:
A new instance of the appropriate subclass of Model.
getModelClassName
protected String getModelClassName()
Description copied from class: AbstractSGDTrainer
Returns the class name of the model that's produced by this trainer.
- Specified by:
getModelClassName in class AbstractSGDTrainer<Label,Integer,LinearSGDModel,LinearParameters>
- Returns:
The model class name.
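createModel and getModelClassName are hooks that AbstractSGDTrainer leaves to its subclasses. The sketch below shows that template-method structure with hypothetical classes: SketchBaseTrainer, SketchLinearTrainer and SketchLinearModel. The base class owns a fixed train method and asks the subclass to wrap the fitted parameters in the concrete model type. The optimisation step is elided. Nothing here reproduces Tribuo's actual signatures, which also take provenance, feature-map and output-info arguments.

```java
/** Hypothetical base model holding a name and a weight vector. */
abstract class SketchModel {
    final String name;
    final double[] weights;

    SketchModel(String name, double[] weights) {
        this.name = name;
        this.weights = weights;
    }
}

/** Hypothetical concrete model type produced by the linear trainer. */
final class SketchLinearModel extends SketchModel {
    SketchLinearModel(String name, double[] weights) {
        super(name, weights);
    }
}

/** Hypothetical generic trainer: fixed training skeleton, subclass-specific model construction. */
abstract class SketchBaseTrainer<M extends SketchModel> {
    final M train(String name, int numFeatures) {
        double[] parameters = new double[numFeatures]; // optimisation elided in this sketch
        return createModel(name, parameters);
    }

    /** Subclass hook that builds the concrete model type from the fitted parameters. */
    protected abstract M createModel(String name, double[] parameters);

    /** Subclass hook that reports which model class it produces. */
    protected abstract String getModelClassName();
}

/** Hypothetical subclass that fills in both hooks for the linear model. */
final class SketchLinearTrainer extends SketchBaseTrainer<SketchLinearModel> {
    @Override
    protected SketchLinearModel createModel(String name, double[] parameters) {
        return new SketchLinearModel(name, parameters);
    }

    @Override
    protected String getModelClassName() {
        return SketchLinearModel.class.getName();
    }

    public static void main(String[] args) {
        SketchLinearTrainer trainer = new SketchLinearTrainer();
        SketchLinearModel model = trainer.train("demo", 3);
        System.out.println(trainer.getModelClassName() + " named " + model.name);
    }
}
```

Because train is final and generic in M, each subclass gets a correctly typed model back without casts. The shared training logic stays in one place.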
toString
public String toString()