Class LinearSGDTrainer
java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
org.tribuo.common.sgd.AbstractLinearSGDTrainer<Label, Integer, LinearSGDModel>
org.tribuo.classification.sgd.linear.LinearSGDTrainer
- All Implemented Interfaces:
- com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<Label>, WeightedExamples
- Direct Known Subclasses:
- LogisticRegressionTrainer
A trainer for a linear classifier using stochastic gradient descent (SGD).
 
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
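The following standalone sketch illustrates the SGD update from the cited paper applied to a linear classifier. It assumes the model keeps one weight row and one bias per label, so the scores are s = W x + b. It also assumes the objective supplies g, the gradient of the loss with respect to s. LinearSgdStep and its methods are hypothetical names, not part of Tribuo. They only show the update W <- W - eta * g * x^T and b <- b - eta * g.

```java
// Hypothetical sketch (not Tribuo's implementation) of one SGD update for a
// multiclass linear model. Assumptions: scores s = W x + b with one row of W and
// one bias entry per label, and g = dLoss/ds is supplied by some objective.
final class LinearSgdStep {
    private LinearSgdStep() {}

    /** Computes s = W x + b for a dense feature vector x. */
    static double[] scores(double[][] weights, double[] bias, double[] x) {
        double[] s = new double[weights.length];
        for (int k = 0; k < weights.length; k++) {
            double sum = bias[k];
            for (int j = 0; j < x.length; j++) {
                sum += weights[k][j] * x[j];
            }
            s[k] = sum;
        }
        return s;
    }

    /**
     * Applies W <- W - eta * g x^T and b <- b - eta * g in place,
     * where g is the gradient of the loss with respect to the scores.
     */
    static void update(double[][] weights, double[] bias, double[] x, double[] g, double eta) {
        for (int k = 0; k < weights.length; k++) {
            for (int j = 0; j < x.length; j++) {
                weights[k][j] -= eta * g[k] * x[j];
            }
            bias[k] -= eta * g[k];
        }
    }
}
```

The gradient with respect to W is the outer product of g and x. Any objective that can produce g therefore fits the same update, which is one reason an SGD trainer can take its objective as a constructor argument.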
Field Summary
- Fields inherited from class org.tribuo.common.sgd.AbstractSGDTrainer: addBias, epochs, loggingInterval, minibatchSize, optimiser, rng, seed, shuffle
- Fields inherited from interface org.tribuo.Trainer: DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
Constructor Summary
- LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed): Constructs an SGD trainer for a linear model.
- LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed): Constructs an SGD trainer for a linear model with a minibatch size of 1.
- LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed): Constructs an SGD trainer for a linear model with a minibatch size of 1 and a logging interval of 1000.
Method Summary
- protected LinearSGDModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, LinearParameters parameters): Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
- protected String getModelClassName(): Returns the class name of the model that's produced by this trainer.
- protected SGDObjective<Integer> getObjective(): Returns the objective used by this trainer.
- protected Integer getTarget(ImmutableOutputInfo<Label> outputInfo, Label output): Extracts the appropriate training-time representation from the supplied output.
- public String toString()
- Methods inherited from class org.tribuo.common.sgd.AbstractLinearSGDTrainer: createParameters, getName
- Methods inherited from class org.tribuo.common.sgd.AbstractSGDTrainer: getInvocationCount, getProvenance, postConfig, setInvocationCount, setShuffle, shuffleInPlace, train (three overloads)
Constructor Details
- public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Constructs an SGD trainer for a linear model.
Parameters:
- objective: The objective function to optimise.
- optimiser: The gradient optimiser to use.
- epochs: The number of epochs (complete passes through the training data).
- loggingInterval: Log the loss after this many iterations. If -1, don't log anything.
- minibatchSize: The number of examples in each minibatch.
- seed: A seed for the random number generator, used to shuffle the examples before each epoch.
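The hypothetical sketch below shows one plausible way that epochs, loggingInterval, minibatchSize and seed could interact in a training loop. It assumes that one iteration is one minibatch update, and it hides the gradient step behind a callback. SgdLoopSketch and MinibatchStep are illustrative names, not Tribuo's internals.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch: how epochs, loggingInterval, minibatchSize and seed could
// drive an SGD loop. Assumption: one "iteration" is one minibatch update.
final class SgdLoopSketch {
    private SgdLoopSketch() {}

    /** Stands in for "compute the loss on this minibatch and apply an update". */
    interface MinibatchStep {
        double apply(List<Integer> exampleIndices);
    }

    /** Runs the loop and returns the iteration numbers at which the loss was logged. */
    static List<Integer> run(int numExamples, int epochs, int loggingInterval,
                             int minibatchSize, long seed, MinibatchStep step) {
        Random rng = new Random(seed);
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        List<Integer> loggedAt = new ArrayList<>();
        int iteration = 0;
        double lossSinceLog = 0.0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(order, rng); // seeded shuffle before each epoch
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                lossSinceLog += step.apply(order.subList(start, end));
                iteration++;
                // -1 (or, in this sketch, any non-positive value) disables logging
                if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                    System.out.printf("iteration %d, mean loss %.4f%n",
                            iteration, lossSinceLog / loggingInterval);
                    loggedAt.add(iteration);
                    lossSinceLog = 0.0;
                }
            }
        }
        return loggedAt;
    }
}
```

Take 10 examples, a minibatch size of 2 and 2 epochs. This sketch then performs 10 updates, so a logging interval of 3 logs after iterations 3, 6 and 9. Reusing the same seed reproduces the same shuffle order.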
 
- public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
Constructs an SGD trainer for a linear model. Sets the minibatch size to 1.
Parameters:
- objective: The objective function to optimise.
- optimiser: The gradient optimiser to use.
- epochs: The number of epochs (complete passes through the training data).
- loggingInterval: Log the loss after this many iterations. If -1, don't log anything.
- seed: A seed for the random number generator, used to shuffle the examples before each epoch.
 
- public LinearSGDTrainer(LabelObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed)
Constructs an SGD trainer for a linear model. Sets the minibatch size to 1 and the logging interval to 1000.
Parameters:
- objective: The objective function to optimise.
- optimiser: The gradient optimiser to use.
- epochs: The number of epochs (complete passes through the training data).
- seed: A seed for the random number generator, used to shuffle the examples before each epoch.
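The two shorter constructors fill in defaults for the full one. If they simply delegate, the pattern looks like the hypothetical sketch below. The sketch drops the objective and optimiser parameters to stay self-contained, and TrainerConfigSketch is an illustrative name, not a Tribuo class.

```java
// Hypothetical sketch of constructor delegation, assuming the shorter overloads
// forward to the full constructor with the documented defaults
// (minibatch size 1; logging interval 1000). Objective and optimiser omitted.
final class TrainerConfigSketch {
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    /** Mirrors the full constructor. */
    TrainerConfigSketch(int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    /** Mirrors the overload that sets the minibatch size to 1. */
    TrainerConfigSketch(int epochs, int loggingInterval, long seed) {
        this(epochs, loggingInterval, 1, seed);
    }

    /** Mirrors the overload that sets the minibatch size to 1 and the logging interval to 1000. */
    TrainerConfigSketch(int epochs, long seed) {
        this(epochs, 1000, seed);
    }
}
```

Chaining each overload to the next keeps every default in one place. Changing the default logging interval then touches a single line.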
 
 
Method Details
- protected Integer getTarget(ImmutableOutputInfo<Label> outputInfo, Label output)
Description copied from class: AbstractSGDTrainer
Extracts the appropriate training-time representation from the supplied output.
Specified by:
- getTarget in class AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
Parameters:
- outputInfo: The output info to use.
- output: The output to extract.
Returns:
- The training-time representation of the output.
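For this trainer, the training-time representation of a Label is an Integer. One plausible reading is that the Integer is the label's dense id in the output domain. The sketch below assumes exactly that, using a plain HashMap in place of ImmutableOutputInfo and label names in place of Label objects. LabelIdSketch is a hypothetical name.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: a label's training-time representation is assumed to be a
// dense integer id (0..k-1) from a fixed output domain. A HashMap stands in for
// ImmutableOutputInfo, and label names stand in for Label objects.
final class LabelIdSketch {
    private final Map<String, Integer> idsByName = new HashMap<>();
    private final String[] namesById;

    LabelIdSketch(String... labelNames) {
        namesById = labelNames.clone();
        for (int i = 0; i < namesById.length; i++) {
            idsByName.put(namesById[i], i);
        }
    }

    /** Analogue of getTarget: maps a label name to its integer id. */
    int getTarget(String labelName) {
        Integer id = idsByName.get(labelName);
        if (id == null) {
            throw new IllegalArgumentException("Unknown label: " + labelName);
        }
        return id;
    }

    /** Inverse mapping, as needed when turning a predicted id back into a label. */
    String getLabel(int id) {
        return namesById[id];
    }
}
```

Dense ids let the id index rows of a weight matrix and entries of a score vector directly. This fits a linear model with one weight row per label.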
 
- protected SGDObjective<Integer> getObjective()
Description copied from class: AbstractSGDTrainer
Returns the objective used by this trainer.
Specified by:
- getObjective in class AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
Returns:
- The SGDObjective used by this trainer.
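The objective returned here works on Integer targets. To illustrate what such an objective computes, this sketch evaluates the multiclass logistic (softmax cross-entropy) loss for one example, together with its gradient with respect to the scores. It is a standalone example, not Tribuo's LabelObjective or SGDObjective interface, and SoftmaxLogLossSketch is a hypothetical name. The objective actually used is whichever LabelObjective is passed to the constructor.

```java
// Hypothetical sketch of an objective over integer targets: returns the loss for
// one example and its gradient with respect to the model's scores. Uses softmax
// cross-entropy as the example loss; not Tribuo's actual objective API.
final class SoftmaxLogLossSketch {
    private SoftmaxLogLossSketch() {}

    /** Pair of the loss value and dLoss/dScores. */
    static final class LossAndGradient {
        final double loss;
        final double[] gradient;

        LossAndGradient(double loss, double[] gradient) {
            this.loss = loss;
            this.gradient = gradient;
        }
    }

    static LossAndGradient lossAndGradient(int target, double[] scores) {
        // Subtract the max score before exponentiating for numerical stability.
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double sumExp = 0.0;
        double[] probs = new double[scores.length];
        for (int k = 0; k < scores.length; k++) {
            probs[k] = Math.exp(scores[k] - max);
            sumExp += probs[k];
        }
        for (int k = 0; k < scores.length; k++) {
            probs[k] /= sumExp;
        }
        // loss = -log softmax(scores)[target], computed without taking log of a tiny probability
        double loss = Math.log(sumExp) - (scores[target] - max);
        // gradient = softmax(scores) - onehot(target)
        double[] gradient = probs.clone();
        gradient[target] -= 1.0;
        return new LossAndGradient(loss, gradient);
    }
}
```

With scores (0, 0) and target 0, the softmax gives probabilities (0.5, 0.5). The loss is therefore log 2 and the gradient is (-0.5, 0.5). A linear SGD step would then scale that gradient by the features to update each label's weights.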
 
- protected LinearSGDModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Label> outputInfo, LinearParameters parameters)
Description copied from class: AbstractSGDTrainer
Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
Specified by:
- createModel in class AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
Parameters:
- name: The model name.
- provenance: The model provenance.
- featureMap: The feature map.
- outputInfo: The output info.
- parameters: The model parameters.
Returns:
- A new instance of the appropriate subclass of Model.
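The sketch below makes the role of createModel concrete. It bundles a model name, a weight matrix and the id-to-label mapping into an immutable object that predicts the highest-scoring label. In this sketch, each label has one weight row with its bias stored in the last column. LinearModelSketch is hypothetical and omits the provenance and feature map that the real method receives.

```java
// Hypothetical sketch of the kind of object createModel returns: an immutable
// linear classifier holding a name, weights (one row per label, bias in the last
// column in this sketch) and the id-to-label mapping. Predicts by argmax score.
final class LinearModelSketch {
    private final String name;
    private final double[][] weights; // [numLabels][numFeatures + 1]
    private final String[] labelsById;

    LinearModelSketch(String name, double[][] weights, String[] labelsById) {
        if (weights.length != labelsById.length) {
            throw new IllegalArgumentException("Need one weight row per label");
        }
        this.name = name;
        this.weights = new double[weights.length][];
        for (int k = 0; k < weights.length; k++) {
            this.weights[k] = weights[k].clone(); // defensive copy keeps the model immutable
        }
        this.labelsById = labelsById.clone();
    }

    String name() {
        return name;
    }

    /** Returns the label whose score w_k . x + b_k is largest. */
    String predict(double[] features) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < weights.length; k++) {
            double[] row = weights[k];
            double score = row[row.length - 1]; // bias term
            for (int j = 0; j < features.length; j++) {
                score += row[j] * features[j];
            }
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return labelsById[best];
    }
}
```

Copying the weights at construction time means later updates to the trainer's parameters cannot silently change an already-built model.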
 
- protected String getModelClassName()
Description copied from class: AbstractSGDTrainer
Returns the class name of the model that's produced by this trainer.
Specified by:
- getModelClassName in class AbstractSGDTrainer<Label, Integer, LinearSGDModel, LinearParameters>
Returns:
- The model class name.
 
- public String toString()