Class LinearSGDTrainer
java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
org.tribuo.common.sgd.AbstractLinearSGDTrainer<Label, Integer, LinearSGDModel>
org.tribuo.classification.sgd.linear.LinearSGDTrainer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<Label>, WeightedExamples
- Direct Known Subclasses:
LogisticRegressionTrainer
A trainer for a linear classifier using SGD.
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
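The class description and the Bottou reference outline SGD only in general terms. The plain Java class below is a rough, hypothetical sketch of the general technique, not Tribuo's implementation. It trains a dense multiclass linear model with per-example SGD on the softmax (multinomial logistic) loss. Everything in it is an assumption made for illustration: the class name `SoftmaxSgdSketch`, dense `double[]` features, a fixed learning rate standing in for a `StochasticGradientOptimiser`, and the softmax loss standing in for whichever `LabelObjective` is configured.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical illustration only: a dense multiclass linear classifier trained with
// plain SGD on the softmax (multinomial logistic) loss. This is not Tribuo's code.
class SoftmaxSgdSketch {
    // weights[k] holds the feature weights for class k; the last entry is the bias.
    final double[][] weights;

    SoftmaxSgdSketch(int numClasses, int numFeatures) {
        weights = new double[numClasses][numFeatures + 1];
    }

    // Raw class scores s_k = w_k . x + b_k.
    double[] scores(double[] x) {
        double[] s = new double[weights.length];
        for (int k = 0; k < weights.length; k++) {
            double acc = weights[k][x.length]; // bias term
            for (int j = 0; j < x.length; j++) {
                acc += weights[k][j] * x[j];
            }
            s[k] = acc;
        }
        return s;
    }

    // Softmax with the maximum score subtracted for numerical stability.
    static double[] softmax(double[] s) {
        double max = s[0];
        for (double v : s) {
            max = Math.max(max, v);
        }
        double[] p = new double[s.length];
        double sum = 0.0;
        for (int k = 0; k < s.length; k++) {
            p[k] = Math.exp(s[k] - max);
            sum += p[k];
        }
        for (int k = 0; k < s.length; k++) {
            p[k] /= sum;
        }
        return p;
    }

    // One SGD step on one example; the loss gradient w.r.t. score k is p_k - [k == y].
    void step(double[] x, int y, double learningRate) {
        double[] p = softmax(scores(x));
        for (int k = 0; k < weights.length; k++) {
            double g = p[k] - (k == y ? 1.0 : 0.0);
            for (int j = 0; j < x.length; j++) {
                weights[k][j] -= learningRate * g * x[j];
            }
            weights[k][x.length] -= learningRate * g;
        }
    }

    // Predicts the class with the highest score.
    int predict(double[] x) {
        double[] s = scores(x);
        int best = 0;
        for (int k = 1; k < s.length; k++) {
            if (s[k] > s[best]) {
                best = k;
            }
        }
        return best;
    }

    // Runs several epochs, reshuffling the example order with a seeded RNG before each one.
    void train(double[][] xs, int[] ys, int epochs, double learningRate, long seed) {
        Random rng = new Random(seed);
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < xs.length; i++) {
            order.add(i);
        }
        for (int e = 0; e < epochs; e++) {
            Collections.shuffle(order, rng);
            for (int i : order) {
                step(xs[i], ys[i], learningRate);
            }
        }
    }
}
```

Swapping the loss only changes how the per-class gradient `g` is computed in `step`, while the update rule stays the same; that split loosely mirrors the separate `objective` and `optimiser` constructor arguments.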
Field Summary
Fields inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
addBias, epochs, loggingInterval, minibatchSize, optimiser, rng, seed, shuffle
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
Constructor Summary
Constructors
- LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed): Constructs an SGD trainer for a linear model.
- LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed): Constructs an SGD trainer for a linear model with a minibatch size of 1.
- LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed): Constructs an SGD trainer for a linear model with a minibatch size of 1 and a logging interval of 1000.
Method Summary
- protected LinearSGDModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, LinearParameters parameters): Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
- protected String getModelClassName(): Returns the class name of the model that's produced by this trainer.
- protected SGDObjective<Integer> getObjective(): Returns the objective used by this trainer.
- protected Integer getTarget(ImmutableOutputInfo<Label> outputInfo, Label output): Extracts the appropriate training-time representation from the supplied output.
- String toString()
Methods inherited from class org.tribuo.common.sgd.AbstractLinearSGDTrainer
createParameters, getName
Methods inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
getInvocationCount, getProvenance, postConfig, setInvocationCount, setShuffle, shuffleInPlace, train, train, train
Constructor Details
LinearSGDTrainer
public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Constructs an SGD trainer for a linear model.
- Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, don't log anything.
minibatchSize - The number of examples in each minibatch.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
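To make the parameter descriptions concrete, here is a hypothetical sketch of how `epochs`, `minibatchSize`, `loggingInterval`, and `seed` might combine in a training loop. It assumes that one iteration means one minibatch update, and it treats any non-positive `loggingInterval` as disabling logging (the documentation only specifies -1). The names `EpochLoopSketch`, `schedule`, and `shouldLog` are invented for illustration and are not part of Tribuo.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of epoch/minibatch/logging scheduling; not Tribuo's code.
final class EpochLoopSketch {
    // Returns every minibatch (a list of example indices) in the order it would be visited.
    static List<List<Integer>> schedule(int numExamples, int epochs, int minibatchSize, long seed) {
        Random rng = new Random(seed); // one RNG for the whole run, so results are reproducible
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        List<List<Integer>> batches = new ArrayList<>();
        for (int e = 0; e < epochs; e++) {
            Collections.shuffle(order, rng); // reshuffle before each epoch
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples); // last batch may be short
                batches.add(new ArrayList<>(order.subList(start, end)));
            }
        }
        return batches;
    }

    // True if the loss would be logged after this 1-based iteration (assumed: one minibatch).
    static boolean shouldLog(int iteration, int loggingInterval) {
        return loggingInterval > 0 && iteration % loggingInterval == 0;
    }
}
```

Because the `Random` is created once from `seed` and reused across epochs, each epoch sees a different example order while the whole run remains reproducible.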
LinearSGDTrainer
public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
Constructs an SGD trainer for a linear model. Sets the minibatch size to 1.
- Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, don't log anything.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
LinearSGDTrainer
public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed)
Constructs an SGD trainer for a linear model. Sets the minibatch size to 1 and the logging interval to 1000.
- Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
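The three overloads differ only in which arguments fall back to the documented defaults: a minibatch size of 1 and a logging interval of 1000. A common Java idiom for this is constructor chaining. The sketch below is hypothetical: `TrainerConfigSketch` uses plain strings in place of `LabelObjective` and `StochasticGradientOptimiser`, and it makes no claim about how `LinearSGDTrainer` is written internally.

```java
// Hypothetical mirror of the three overloads; strings stand in for the objective and optimiser.
final class TrainerConfigSketch {
    static final int DEFAULT_LOGGING_INTERVAL = 1000; // documented default for the 4-argument form
    static final int DEFAULT_MINIBATCH_SIZE = 1;      // documented default for the 4- and 5-argument forms

    final String objective;
    final String optimiser;
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    // Full form: every value supplied explicitly.
    TrainerConfigSketch(String objective, String optimiser, int epochs, int loggingInterval,
                        int minibatchSize, long seed) {
        this.objective = objective;
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    // Defaults the minibatch size to 1.
    TrainerConfigSketch(String objective, String optimiser, int epochs, int loggingInterval, long seed) {
        this(objective, optimiser, epochs, loggingInterval, DEFAULT_MINIBATCH_SIZE, seed);
    }

    // Defaults the logging interval to 1000 (and, via chaining, the minibatch size to 1).
    TrainerConfigSketch(String objective, String optimiser, int epochs, long seed) {
        this(objective, optimiser, epochs, DEFAULT_LOGGING_INTERVAL, seed);
    }
}
```

Chaining keeps each default in exactly one place, so the shorter overloads cannot drift out of sync with the full constructor.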
Method Details
getTarget
protected Integer getTarget(ImmutableOutputInfo<Label> outputInfo, Label output)
Description copied from class: AbstractSGDTrainer
Extracts the appropriate training-time representation from the supplied output.
- Specified by:
getTarget in class AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
- Parameters:
outputInfo - The output info to use.
output - The output to extract.
- Returns:
- The training-time representation of the output.
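For this trainer the training-time representation is an `Integer`, the second type argument of `AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>`. One plausible reading is that each `Label` is mapped to a dense integer index. The sketch below illustrates that idea with a hypothetical `LabelIdMapSketch`, which stands in for `ImmutableOutputInfo<Label>` and uses `String` label names instead of `Label` objects; it is an assumption, not Tribuo's code.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for an immutable output info: label name -> dense integer id.
final class LabelIdMapSketch {
    private final Map<String, Integer> ids = new HashMap<>();

    LabelIdMapSketch(List<String> labels) {
        for (String label : labels) {
            // First occurrence gets the next free id; duplicates keep their original id.
            ids.putIfAbsent(label, ids.size());
        }
    }

    // Analogue of getTarget: the training-time form of a label is its integer id.
    Integer getTarget(String label) {
        Integer id = ids.get(label);
        if (id == null) {
            throw new IllegalArgumentException("Unknown label: " + label);
        }
        return id;
    }
}
```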
getObjective
protected SGDObjective<Integer> getObjective()
Description copied from class: AbstractSGDTrainer
Returns the objective used by this trainer.
- Specified by:
getObjective in class AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
- Returns:
- The SGDObjective used by this trainer.
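The objective is typed on `Integer`, matching the target produced by `getTarget`. Conceptually, an SGD objective turns a target and the model's raw class scores into a loss value plus a gradient with respect to those scores. The hypothetical sketch below shows that contract using an invented `ScoreObjective` interface and a Crammer-Singer style multiclass hinge loss, chosen purely as an example. Neither the interface nor the loss is claimed to match Tribuo's `SGDObjective` or any particular `LabelObjective`.

```java
// Hypothetical result holder: a loss value and its gradient w.r.t. the class scores.
final class LossAndGradient {
    final double loss;
    final double[] gradient;

    LossAndGradient(double loss, double[] gradient) {
        this.loss = loss;
        this.gradient = gradient;
    }
}

// Hypothetical objective contract keyed on an integer class target.
interface ScoreObjective {
    LossAndGradient lossAndGradient(int target, double[] scores);
}

// Crammer-Singer multiclass hinge: max(0, 1 + max_{k != y} s_k - s_y).
final class MulticlassHingeSketch implements ScoreObjective {
    @Override
    public LossAndGradient lossAndGradient(int target, double[] scores) {
        if (scores.length < 2) {
            throw new IllegalArgumentException("Need at least two classes");
        }
        // Find the highest-scoring wrong class (the "rival").
        int rival = -1;
        for (int k = 0; k < scores.length; k++) {
            if (k != target && (rival < 0 || scores[k] > scores[rival])) {
                rival = k;
            }
        }
        double loss = Math.max(0.0, 1.0 + scores[rival] - scores[target]);
        double[] gradient = new double[scores.length];
        if (loss > 0.0) {
            // A descent step lowers the rival's score and raises the true class's score.
            gradient[rival] = 1.0;
            gradient[target] = -1.0;
        }
        return new LossAndGradient(loss, gradient);
    }
}
```

When the true class already wins by a margin of at least 1, the loss and gradient are both zero, so the optimiser leaves the weights unchanged for that example.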
createModel
protected LinearSGDModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, LinearParameters parameters)
Description copied from class: AbstractSGDTrainer
Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
- Specified by:
createModel in class AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
- Parameters:
name - The model name.
provenance - The model provenance.
featureMap - The feature map.
outputInfo - The output info.
parameters - The model parameters.
- Returns:
- A new instance of the appropriate subclass of Model.
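As its description suggests, `createModel` acts as a hook: the shared training path inherited from `AbstractSGDTrainer` produces the parameters, and each subclass decides which concrete model wraps them (here a `LinearSGDModel`). The hypothetical sketch below shows this template-method arrangement with invented types (`ParamsSketch`, `BaseTrainerSketch`, `LinearTrainerSketch`, `LinearModelSketch`). The training loop is elided and the signatures are simplified relative to the real `createModel`.

```java
// Hypothetical types illustrating a template-method hook; not Tribuo's classes.

// Learned parameters: one weight row per class (bias omitted for brevity).
final class ParamsSketch {
    final double[][] weights;

    ParamsSketch(double[][] weights) {
        this.weights = weights;
    }
}

// Shared trainer logic; subclasses choose the concrete model type M.
abstract class BaseTrainerSketch<M> {
    // Hook: wrap the fitted parameters in the appropriate model class.
    protected abstract M createModel(String name, ParamsSketch parameters);

    // Shared path: allocate parameters, (in a real trainer) fit them, then call the hook.
    final M train(String name, int numClasses, int numFeatures) {
        ParamsSketch params = new ParamsSketch(new double[numClasses][numFeatures]);
        // Parameter fitting is elided here, so the weights stay at zero.
        return createModel(name, params);
    }
}

// Concrete linear model: predicts the class whose weight row scores highest.
final class LinearModelSketch {
    final String name;
    final double[][] weights;

    LinearModelSketch(String name, double[][] weights) {
        this.name = name;
        this.weights = weights;
    }

    int predict(double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < weights.length; k++) {
            double score = 0.0;
            for (int j = 0; j < x.length; j++) {
                score += weights[k][j] * x[j];
            }
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }
}

// Concrete trainer: only decides which model type to build.
final class LinearTrainerSketch extends BaseTrainerSketch<LinearModelSketch> {
    @Override
    protected LinearModelSketch createModel(String name, ParamsSketch parameters) {
        return new LinearModelSketch(name, parameters.weights);
    }
}
```

With this split, a new model type needs only a new `createModel` override; the shared training path is untouched.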
getModelClassName
protected String getModelClassName()
Description copied from class: AbstractSGDTrainer
Returns the class name of the model that's produced by this trainer.
- Specified by:
getModelClassName in class AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
- Returns:
- The model class name.
toString
public String toString()