Class LinearSGDTrainer
java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<MultiLabel, SGDVector, AbstractLinearSGDModel<MultiLabel>, LinearParameters>
org.tribuo.common.sgd.AbstractLinearSGDTrainer<MultiLabel, SGDVector>
org.tribuo.multilabel.sgd.linear.LinearSGDTrainer
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<MultiLabel>, WeightedExamples
A trainer for a multi-label linear model which uses SGD.
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
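To make "multi-label linear model trained with SGD" concrete, here is a minimal, self-contained sketch. It does not use Tribuo and is not the library's implementation. The class `MultiLabelSgdSketch`, the toy dataset, the fixed learning rate, and the choice of a sigmoid output with a binary cross-entropy loss are all assumptions made for illustration; in Tribuo the loss is supplied by the `MultiLabelObjective` and the update rule by the `StochasticGradientOptimiser`.

```java
// Hypothetical sketch, not Tribuo code: one sigmoid output per label,
// trained with per-example SGD on a binary cross-entropy loss (an assumption).
final class MultiLabelSgdSketch {
    private final double[][] weights; // [label][feature]; the final column holds the bias
    private final double learningRate;

    MultiLabelSgdSketch(int numLabels, int numFeatures, double learningRate) {
        this.weights = new double[numLabels][numFeatures + 1];
        this.learningRate = learningRate;
    }

    /** Returns an independent probability for each label. */
    double[] predict(double[] x) {
        double[] probs = new double[weights.length];
        for (int l = 0; l < weights.length; l++) {
            double z = weights[l][x.length]; // bias term
            for (int j = 0; j < x.length; j++) {
                z += weights[l][j] * x[j];
            }
            probs[l] = 1.0 / (1.0 + Math.exp(-z));
        }
        return probs;
    }

    /** One SGD step; y[l] is 1.0 if label l is present and 0.0 otherwise. */
    void step(double[] x, double[] y) {
        double[] probs = predict(x);
        for (int l = 0; l < weights.length; l++) {
            double error = probs[l] - y[l]; // gradient of the loss w.r.t. the logit
            for (int j = 0; j < x.length; j++) {
                weights[l][j] -= learningRate * error * x[j];
            }
            weights[l][x.length] -= learningRate * error;
        }
    }

    /** Trains on a toy dataset where label l is present exactly when feature l is on. */
    static MultiLabelSgdSketch trainToyModel(int epochs) {
        double[][] xs = {{1, 0}, {0, 1}, {1, 1}};
        double[][] ys = {{1, 0}, {0, 1}, {1, 1}};
        MultiLabelSgdSketch model = new MultiLabelSgdSketch(2, 2, 0.5);
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i = 0; i < xs.length; i++) {
                model.step(xs[i], ys[i]);
            }
        }
        return model;
    }

    public static void main(String[] args) {
        MultiLabelSgdSketch model = trainToyModel(1000);
        System.out.println(java.util.Arrays.toString(model.predict(new double[]{1, 0})));
    }
}
```

Each label gets its own weight row, so an example can switch on any subset of labels; that independence is what separates this setup from a multi-class model.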
-
Field Summary
Fields inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
addBias, epochs, loggingInterval, minibatchSize, optimiser, rng, seed, shuffle
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED
-
Constructor Summary
Constructors (Constructor, Description):
LinearSGDTrainer(MultiLabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
    Constructs an SGD trainer for a linear model.
LinearSGDTrainer(MultiLabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
    Constructs an SGD trainer for a linear model.
LinearSGDTrainer(MultiLabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed)
    Constructs an SGD trainer for a linear model.
-
Method Summary
Methods (Modifier and Type, Method, Description):
protected LinearSGDModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<MultiLabel> outputInfo, LinearParameters parameters)
    Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
protected String getModelClassName()
    Returns the class name of the model that's produced by this trainer.
protected SGDObjective<SGDVector> getObjective()
    Returns the objective used by this trainer.
protected SparseVector getTarget(ImmutableOutputInfo<MultiLabel> outputInfo, MultiLabel output)
    Extracts the appropriate training time representation from the supplied output.
String toString()
Methods inherited from class org.tribuo.common.sgd.AbstractLinearSGDTrainer
createParameters, getName
Methods inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
getInvocationCount, getProvenance, postConfig, setShuffle, shuffleInPlace, train, train
-
Constructor Details
-
LinearSGDTrainer
public LinearSGDTrainer(MultiLabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Constructs an SGD trainer for a linear model.
- Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1 don't log anything.
minibatchSize - The size of any minibatches.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
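The interaction between `epochs`, `loggingInterval`, `minibatchSize` and `seed` can be sketched as a plain loop. This is a hypothetical illustration, not the code in `AbstractSGDTrainer`: the class `SgdLoopSketch`, the `run` method, and the choice to count one iteration per minibatch are assumptions. Only the behaviours described above (shuffling before each epoch from a seeded RNG, and -1 disabling logging) are taken from the parameter documentation.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch, not Tribuo's implementation: it only illustrates how
// epochs, loggingInterval, minibatchSize and seed could interact.
final class SgdLoopSketch {

    /** Returns the example indices in the order the loop visited them. */
    static List<Integer> run(int numExamples, int epochs, int loggingInterval,
                             int minibatchSize, long seed) {
        Random rng = new Random(seed);
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        List<Integer> visited = new ArrayList<>();
        int iteration = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Shuffle before each epoch, driven by the seeded RNG.
            Collections.shuffle(order, rng);
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                // A real trainer would compute gradients over this minibatch and
                // hand them to the optimiser; this sketch only records the visit.
                visited.addAll(order.subList(start, end));
                iteration++;
                // Any non-positive interval (such as -1) disables logging here.
                if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                    System.out.println("iteration " + iteration + " (epoch " + epoch + ")");
                }
            }
        }
        return visited;
    }

    public static void main(String[] args) {
        // 5 examples, 2 epochs, log every 2 iterations, minibatches of 2.
        System.out.println(run(5, 2, 2, 2, 42L));
    }
}
```

Because the shuffle is driven entirely by the seed, two runs with the same arguments visit the examples in the same order, which is what makes a seeded trainer reproducible.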
-
LinearSGDTrainer
public LinearSGDTrainer(MultiLabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
Constructs an SGD trainer for a linear model.
Sets the minibatch size to 1.
- Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1 don't log anything.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
-
LinearSGDTrainer
public LinearSGDTrainer(MultiLabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed)
Constructs an SGD trainer for a linear model.
Sets the minibatch size to 1 and the logging interval to 1000.
- Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
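The shorter constructors only fill in defaults for the arguments they omit. A common way to express that in Java is constructor delegation, sketched below with a hypothetical stand-in class. `TrainerDefaultsSketch` and its fields are assumptions, and the objective and optimiser arguments are left out for brevity; only the default values (minibatch size 1, logging interval 1000) come from the constructor descriptions above.

```java
// Hypothetical stand-in, not the Tribuo class: the delegation pattern is an
// assumption; only the default values 1 and 1000 come from the documentation.
final class TrainerDefaultsSketch {
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    /** Full form: every setting is supplied by the caller. */
    TrainerDefaultsSketch(int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    /** Omits the minibatch size, which defaults to 1. */
    TrainerDefaultsSketch(int epochs, int loggingInterval, long seed) {
        this(epochs, loggingInterval, 1, seed);
    }

    /** Omits both, so the minibatch size is 1 and the logging interval is 1000. */
    TrainerDefaultsSketch(int epochs, long seed) {
        this(epochs, 1000, 1, seed);
    }

    public static void main(String[] args) {
        TrainerDefaultsSketch t = new TrainerDefaultsSketch(5, 12345L);
        System.out.println("minibatchSize=" + t.minibatchSize
                + ", loggingInterval=" + t.loggingInterval);
    }
}
```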
-
-
Method Details
-
getTarget
Description copied from class: AbstractSGDTrainer
Extracts the appropriate training time representation from the supplied output.
- Specified by:
getTarget in class AbstractSGDTrainer<MultiLabel, SGDVector, AbstractLinearSGDModel<MultiLabel>, LinearParameters>
- Parameters:
outputInfo - The output info to use.
output - The output to extract.
- Returns:
The training time representation of the output.
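As an illustration of what a "training time representation" of a multi-label output might look like, the sketch below turns a set of label names into a sparse 0/1 target. It is an assumption-laden stand-in rather than Tribuo's conversion: the class `TargetSketch`, the `Map<String, Integer>` standing in for the output info, and the index-array encoding standing in for a sparse vector are all hypothetical.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.Set;

// Hypothetical sketch, not Tribuo code: a Map plays the role of the output
// info and a sorted int[] plays the role of a sparse vector of 1.0 values.
final class TargetSketch {

    /** Returns the sorted ids of the labels that are present (implicit value 1.0). */
    static int[] toSparseIndices(Map<String, Integer> labelIds, Set<String> labels) {
        return labels.stream()
                .mapToInt(label -> {
                    Integer id = labelIds.get(label);
                    if (id == null) {
                        throw new IllegalArgumentException("Unknown label: " + label);
                    }
                    return id;
                })
                .sorted()
                .toArray();
    }

    /** Expands the sparse indices into a dense 0/1 vector over all labels. */
    static double[] toDense(int[] indices, int numLabels) {
        double[] dense = new double[numLabels];
        for (int index : indices) {
            dense[index] = 1.0;
        }
        return dense;
    }

    public static void main(String[] args) {
        Map<String, Integer> labelIds = Map.of("cat", 0, "dog", 1, "bird", 2);
        int[] sparse = toSparseIndices(labelIds, Set.of("bird", "cat"));
        System.out.println(Arrays.toString(sparse));
        System.out.println(Arrays.toString(toDense(sparse, labelIds.size())));
    }
}
```

A sparse encoding is a natural fit when each example carries only a few of many possible labels, since only the present labels need to be stored.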
-
getObjective
Description copied from class: AbstractSGDTrainer
Returns the objective used by this trainer.
- Specified by:
getObjective in class AbstractSGDTrainer<MultiLabel, SGDVector, AbstractLinearSGDModel<MultiLabel>, LinearParameters>
- Returns:
The SGDObjective used by this trainer.
-
createModel
protected LinearSGDModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<MultiLabel> outputInfo, LinearParameters parameters)
Description copied from class: AbstractSGDTrainer
Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
- Specified by:
createModel in class AbstractSGDTrainer<MultiLabel, SGDVector, AbstractLinearSGDModel<MultiLabel>, LinearParameters>
- Parameters:
name - The model name.
provenance - The model provenance.
featureMap - The feature map.
outputInfo - The output info.
parameters - The model parameters.
- Returns:
A new instance of the appropriate subclass of Model.
-
getModelClassName
Description copied from class: AbstractSGDTrainer
Returns the class name of the model that's produced by this trainer.
- Specified by:
getModelClassName in class AbstractSGDTrainer<MultiLabel, SGDVector, AbstractLinearSGDModel<MultiLabel>, LinearParameters>
- Returns:
The model class name.
-
toString
-