public class LinearSGDTrainer extends AbstractLinearSGDTrainer<Regressor,DenseVector>
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
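The class applies stochastic gradient descent, as described by Bottou (2010), to a linear regression model. The plain-Java sketch below illustrates the technique only and is not Tribuo's implementation. It fits a single-output model with squared loss. It assumes a fixed learning rate in place of a StochasticGradientOptimiser, visits examples in a fixed order, and stores the bias as the last weight. The class and method names are hypothetical.

```java
/**
 * Hypothetical sketch: single-output linear regression fitted by plain SGD on squared loss.
 * Fixed learning rate, fixed example order, no minibatches. Not Tribuo's implementation.
 */
final class LinearSgdSketch {
    private LinearSgdSketch() {}

    /**
     * Fits y ~ w.x + b one example at a time.
     * Returns the weights followed by the bias as the final element.
     */
    static double[] fit(double[][] xs, double[] ys, int epochs, double learningRate) {
        int d = xs[0].length;
        double[] w = new double[d + 1]; // w[d] holds the bias
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i = 0; i < xs.length; i++) {
                double prediction = w[d];
                for (int k = 0; k < d; k++) {
                    prediction += w[k] * xs[i][k];
                }
                // derivative of 0.5 * (prediction - y)^2 with respect to the prediction
                double error = prediction - ys[i];
                for (int k = 0; k < d; k++) {
                    w[k] -= learningRate * error * xs[i][k];
                }
                w[d] -= learningRate * error;
            }
        }
        return w;
    }
}
```

For example, on points generated from y = 2x + 1 at x = 0, 0.25, 0.5, 0.75 and 1.0, running 2000 epochs with a learning rate of 0.1 gives an array that converges towards [2.0, 1.0].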
Fields inherited from class AbstractSGDTrainer: addBias, epochs, loggingInterval, minibatchSize, optimiser, rng, seed, shuffle
Fields inherited from interface Trainer: DEFAULT_SEED
Constructor | Description |
---|---|
LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) | Constructs an SGD trainer for a linear model. |
LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) | Constructs an SGD trainer for a linear model. Sets the minibatch size to 1. |
LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed) | Constructs an SGD trainer for a linear model. Sets the minibatch size to 1 and the logging interval to 1000. |
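According to the constructor details, the shorter constructors fill in defaults: the form without minibatchSize sets the minibatch size to 1, and the form that also omits loggingInterval sets the logging interval to 1000. The hypothetical class below sketches that telescoping-constructor pattern using only the numeric arguments. Whether LinearSGDTrainer delegates between its constructors in exactly this way is an assumption.

```java
/**
 * Hypothetical configuration holder mirroring the three LinearSGDTrainer constructor shapes.
 * The objective and optimiser arguments are left out to keep the sketch self-contained.
 */
final class SgdTrainerConfig {
    final int epochs;
    final int loggingInterval;
    final int minibatchSize;
    final long seed;

    SgdTrainerConfig(int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
    }

    // Documented default: minibatch size of 1.
    SgdTrainerConfig(int epochs, int loggingInterval, long seed) {
        this(epochs, loggingInterval, 1, seed);
    }

    // Documented defaults: minibatch size of 1 and logging interval of 1000.
    SgdTrainerConfig(int epochs, long seed) {
        this(epochs, 1000, seed);
    }
}
```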
Modifier and Type | Method | Description |
---|---|---|
protected LinearSGDModel | createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo, LinearParameters parameters) | Creates the appropriate model subclass for this subclass of AbstractSGDTrainer. |
protected String | getModelClassName() | Returns the class name of the model produced by this trainer. |
protected SGDObjective<DenseVector> | getObjective() | Returns the objective used by this trainer. |
protected DenseVector | getTarget(ImmutableOutputInfo<Regressor> outputInfo, Regressor output) | Extracts the appropriate training-time representation from the supplied output. |
String | toString() | |
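getTarget converts a Regressor, which may hold several named dimensions, into the DenseVector that the SGD loop trains against. The following is a hypothetical sketch of such a conversion. It assumes the output info assigns each dimension name a unique integer id from 0 to size - 1. Plain maps stand in for ImmutableOutputInfo and Regressor, and the names are illustrative rather than Tribuo API.

```java
import java.util.Map;

/**
 * Hypothetical sketch: map named regression outputs onto a dense target array.
 * Assumes dimensionIds assigns each dimension name a unique id in [0, size).
 */
final class TargetSketch {
    private TargetSketch() {}

    static double[] toDenseTarget(Map<String, Integer> dimensionIds, Map<String, Double> output) {
        double[] target = new double[dimensionIds.size()];
        for (Map.Entry<String, Double> e : output.entrySet()) {
            Integer id = dimensionIds.get(e.getKey());
            if (id == null) {
                throw new IllegalArgumentException("Unknown dimension: " + e.getKey());
            }
            // place each value at the index assigned to its dimension
            target[id] = e.getValue();
        }
        return target;
    }
}
```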
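Methods inherited from class AbstractLinearSGDTrainer: createParameters, getName
Methods inherited from class AbstractSGDTrainer: getInvocationCount, getProvenance, postConfig, setShuffle, shuffleInPlace, train, train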
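The seed parameter is documented as driving the shuffle applied before each epoch. minibatchSize sets how many examples are grouped into each gradient step. The inherited setShuffle and shuffleInPlace methods belong to the same stage of training. The sketch below plans one such schedule using a seeded Fisher-Yates shuffle from java.util.SplittableRandom. It is an illustration under assumptions, not Tribuo's training loop. For instance, it assumes a final short minibatch is kept. The names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.SplittableRandom;

/**
 * Hypothetical sketch of per-epoch shuffling and minibatching. Not Tribuo's code.
 */
final class EpochBatcher {
    private EpochBatcher() {}

    /** Returns, for each epoch, the minibatches of example indices in visiting order. */
    static List<List<int[]>> plan(int numExamples, int epochs, int minibatchSize, long seed) {
        if (minibatchSize < 1) {
            throw new IllegalArgumentException("minibatchSize must be positive");
        }
        SplittableRandom rng = new SplittableRandom(seed);
        int[] order = new int[numExamples];
        for (int i = 0; i < numExamples; i++) order[i] = i;
        List<List<int[]>> schedule = new ArrayList<>();
        for (int e = 0; e < epochs; e++) {
            // Fisher-Yates shuffle before every epoch, driven by the seeded RNG
            for (int i = numExamples - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[j];
                order[j] = tmp;
            }
            List<int[]> batches = new ArrayList<>();
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples);
                // copy so later shuffles do not alter stored batches; the last one may be short
                batches.add(Arrays.copyOfRange(order, start, end));
            }
            schedule.add(batches);
        }
        return schedule;
    }
}
```

Calling plan twice with the same seed yields identical batches. This is what makes a fixed seed useful for reproducible training runs.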
public LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Constructs an SGD trainer for a linear model.
Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, don't log anything.
minibatchSize - The number of examples in each minibatch.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
public LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed)
Constructs an SGD trainer for a linear model. Sets the minibatch size to 1.
Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, don't log anything.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
public LinearSGDTrainer(RegressionObjective objective, StochasticGradientOptimiser optimiser, int epochs, long seed)
Constructs an SGD trainer for a linear model. Sets the minibatch size to 1 and the logging interval to 1000.
Parameters:
objective - The objective function to optimise.
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
protected DenseVector getTarget(ImmutableOutputInfo<Regressor> outputInfo, Regressor output)
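Description copied from class: AbstractSGDTrainer
Extracts the appropriate training-time representation from the supplied output.
Specified by:
getTarget in class AbstractSGDTrainer<Regressor,DenseVector,AbstractLinearSGDModel<Regressor>,LinearParameters>
Parameters:
outputInfo - The output info to use.
output - The output to extract.
protected SGDObjective<DenseVector> getObjective()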
Description copied from class: AbstractSGDTrainer
Returns the objective used by this trainer.
Specified by:
getObjective in class AbstractSGDTrainer<Regressor,DenseVector,AbstractLinearSGDModel<Regressor>,LinearParameters>
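The objective returned by getObjective scores a prediction against the DenseVector target and supplies the gradient that SGD steps along. As an example of one regression objective, here is a hypothetical squared-loss sketch. The 0.5 scaling is a convention chosen for this sketch and may differ from any RegressionObjective that Tribuo ships.

```java
/**
 * Hypothetical squared-loss objective over dense vectors:
 * loss = 0.5 * sum_i (prediction_i - target_i)^2, gradient_i = prediction_i - target_i.
 */
final class SquaredLossSketch {
    private SquaredLossSketch() {}

    /** Holds the scalar loss and the gradient with respect to the prediction. */
    static final class LossAndGradient {
        final double loss;
        final double[] gradient;

        LossAndGradient(double loss, double[] gradient) {
            this.loss = loss;
            this.gradient = gradient;
        }
    }

    static LossAndGradient evaluate(double[] target, double[] prediction) {
        if (target.length != prediction.length) {
            throw new IllegalArgumentException("Length mismatch");
        }
        double loss = 0.0;
        double[] grad = new double[target.length];
        for (int i = 0; i < target.length; i++) {
            double diff = prediction[i] - target[i];
            loss += 0.5 * diff * diff; // per-dimension squared error, halved
            grad[i] = diff;            // derivative of the term above w.r.t. prediction[i]
        }
        return new LossAndGradient(loss, grad);
    }
}
```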
protected LinearSGDModel createModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo, LinearParameters parameters)
Description copied from class: AbstractSGDTrainer
Creates the appropriate model subclass for this subclass of AbstractSGDTrainer.
Specified by:
createModel in class AbstractSGDTrainer<Regressor,DenseVector,AbstractLinearSGDModel<Regressor>,LinearParameters>
Parameters:
name - The model name.
provenance - The model provenance.
featureMap - The feature map.
outputInfo - The output info.
parameters - The model parameters.
Returns:
Model.
protected String getModelClassName()
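Description copied from class: AbstractSGDTrainer
Returns the class name of the model produced by this trainer.
Specified by:
getModelClassName in class AbstractSGDTrainer<Regressor,DenseVector,AbstractLinearSGDModel<Regressor>,LinearParameters>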
Copyright © 2015–2021 Oracle and/or its affiliates. All rights reserved.