Package org.tribuo.common.sgd
Class AbstractLinearSGDTrainer<T extends Output<T>,U,V extends AbstractLinearSGDModel<T>>
java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<T,U,V,LinearParameters>
org.tribuo.common.sgd.AbstractLinearSGDTrainer<T,U,V>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
- Direct Known Subclasses:
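org.tribuo.classification.sgd.linear.LinearSGDTrainer, org.tribuo.multilabel.sgd.linear.LinearSGDTrainer, org.tribuo.regression.sgd.linear.LinearSGDTrainer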
public abstract class AbstractLinearSGDTrainer<T extends Output<T>,U,V extends AbstractLinearSGDModel<T>>
extends AbstractSGDTrainer<T,U,V,LinearParameters>
A trainer for a linear model which uses stochastic gradient descent (SGD).
It's an AbstractSGDTrainer operating on LinearParameters, with the bias folded into the features.
See:
Bottou L. "Large-Scale Machine Learning with Stochastic Gradient Descent" Proceedings of COMPSTAT, 2010.
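The phrase "bias folded into the features" can be illustrated with a minimal sketch in plain Java. It does not use Tribuo's classes: `BiasFoldingSketch`, `foldBias` and `predict` are hypothetical names, and the layout (a constant 1.0 appended as the last feature, so the last weight column acts as the bias) is an assumption made for illustration only.

```java
import java.util.Arrays;

/** Illustrative only: not Tribuo's implementation. */
public final class BiasFoldingSketch {

    /** Returns a copy of the features with a constant 1.0 appended as the bias feature. */
    public static double[] foldBias(double[] features) {
        double[] withBias = Arrays.copyOf(features, features.length + 1);
        withBias[features.length] = 1.0;
        return withBias;
    }

    /**
     * Computes one score per output. The weights have shape
     * [numOutputs][numFeatures + 1]; the last column multiplies the
     * constant 1.0, so no separate bias vector is needed.
     */
    public static double[] predict(double[][] weights, double[] features) {
        double[] x = foldBias(features);
        double[] scores = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            double sum = 0.0;
            for (int j = 0; j < x.length; j++) {
                sum += weights[i][j] * x[j];
            }
            scores[i] = sum;
        }
        return scores;
    }

    public static void main(String[] args) {
        // Two outputs, two real features, plus one bias column.
        double[][] weights = {{2.0, -1.0, 0.5}, {0.0, 3.0, -2.0}};
        double[] scores = predict(weights, new double[] {1.0, 4.0});
        System.out.println(Arrays.toString(scores)); // prints [-1.5, 10.0]
    }
}
```

With this layout a single matrix-vector product covers both the weights and the bias, which is why one weight matrix can be the only trainable parameter.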
-
Field Summary
Fields inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
addBias, epochs, loggingInterval, minibatchSize, optimiser, rng, seed, shuffle
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
-
Constructor Summary
protected AbstractLinearSGDTrainer()
For olcut.
protected AbstractLinearSGDTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Constructs an SGD trainer for a linear model.
-
Method Summary
protected LinearParameters createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG)
Constructs the trainable parameters object, in this case a LinearParameters containing a single weight matrix.
protected String getName()
Returns the default model name.
Methods inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
createModel, getInvocationCount, getModelClassName, getObjective, getProvenance, getTarget, postConfig, setInvocationCount, setShuffle, shuffleInPlace, train, train, train
-
Constructor Details
-
AbstractLinearSGDTrainer
protected AbstractLinearSGDTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed)
Constructs an SGD trainer for a linear model.
- Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, don't log anything.
minibatchSize - The size of any minibatches.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
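How these parameters interact can be sketched with a small, self-contained minibatch SGD loop. This is not Tribuo's training code: `MinibatchSgdSketch` is a hypothetical class, the model is a one-feature least-squares fit, the `learningRate` argument stands in for the optimiser, and treating one minibatch update as one "iteration" for logging is an assumption.

```java
import java.util.SplittableRandom;

/** Illustrative only: a plain-Java minibatch SGD loop, not Tribuo's trainer. */
public final class MinibatchSgdSketch {

    /**
     * Fits y ~ w * x + b by minibatch SGD on squared loss and returns {w, b}.
     * loggingInterval is assumed to be positive, or -1 to disable logging.
     */
    public static double[] train(double[] xs, double[] ys, double learningRate,
                                 int epochs, int loggingInterval,
                                 int minibatchSize, long seed) {
        SplittableRandom rng = new SplittableRandom(seed);
        int n = xs.length;
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        double[] weights = new double[2]; // {slope, bias}, both start at zero
        int iteration = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Shuffle the example order before each epoch (Fisher-Yates).
            for (int i = n - 1; i > 0; i--) {
                int j = rng.nextInt(i + 1);
                int tmp = order[i];
                order[i] = order[j];
                order[j] = tmp;
            }
            for (int start = 0; start < n; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, n);
                double gradW = 0.0;
                double gradB = 0.0;
                double loss = 0.0;
                for (int k = start; k < end; k++) {
                    int idx = order[k];
                    double error = weights[0] * xs[idx] + weights[1] - ys[idx];
                    gradW += error * xs[idx];
                    gradB += error;
                    loss += 0.5 * error * error;
                }
                int batch = end - start;
                // Average the gradient over the minibatch, then take one step.
                weights[0] -= learningRate * gradW / batch;
                weights[1] -= learningRate * gradB / batch;
                iteration++;
                if (loggingInterval != -1 && iteration % loggingInterval == 0) {
                    System.out.println("iteration " + iteration + " loss " + (loss / batch));
                }
            }
        }
        return weights;
    }

    public static void main(String[] args) {
        double[] xs = {0.0, 1.0, 2.0, 3.0};
        double[] ys = {1.0, 3.0, 5.0, 7.0}; // generated from y = 2x + 1
        double[] w = train(xs, ys, 0.05, 5000, -1, 2, 42L);
        System.out.println("slope=" + w[0] + " bias=" + w[1]);
    }
}
```

Because the shuffle is driven only by the seeded generator, two runs with the same seed visit the examples in the same order and produce identical weights.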
-
AbstractLinearSGDTrainer
protected AbstractLinearSGDTrainer()
For olcut.
-
-
Method Details
-
getName
protected String getName()
Returns the default model name.
- Specified by:
getName in class AbstractSGDTrainer<T extends Output<T>, U, V extends AbstractLinearSGDModel<T>, LinearParameters>
- Returns:
The default model name.
-
createParameters
protected LinearParameters createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG)
Constructs the trainable parameters object, in this case a LinearParameters containing a single weight matrix.
- Specified by:
createParameters in class AbstractSGDTrainer<T extends Output<T>, U, V extends AbstractLinearSGDModel<T>, LinearParameters>
- Parameters:
numFeatures - The number of input features.
numOutputs - The number of output dimensions.
localRNG - The RNG to use for parameter initialisation.
- Returns:
The trainable parameters.
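The relationship between the abstract createParameters hook and its linear specialisation can be sketched with stand-in classes. Everything here is hypothetical: `SketchSGDTrainer`, `SketchLinearTrainer` and `createWeights` are not Tribuo types, a `double[][]` stands in for the single weight matrix, and the small random initialisation is an assumption made so that the RNG argument has a visible effect.

```java
import java.util.SplittableRandom;

/** Illustrative only: hypothetical stand-ins, not Tribuo's classes. */
public final class CreateParametersSketch {

    /** Stand-in for a generic SGD trainer that delegates parameter creation. */
    abstract static class SketchSGDTrainer<P> {
        /** Hook implemented by subclasses to build their parameter object. */
        protected abstract P createParameters(int numFeatures, int numOutputs,
                                              SplittableRandom localRNG);

        /** Seeds a local RNG and asks the subclass for its parameters. */
        public P initialise(int numFeatures, int numOutputs, long seed) {
            return createParameters(numFeatures, numOutputs, new SplittableRandom(seed));
        }
    }

    /** Linear specialisation: the parameters are one weight matrix. */
    static final class SketchLinearTrainer extends SketchSGDTrainer<double[][]> {
        @Override
        protected double[][] createParameters(int numFeatures, int numOutputs,
                                              SplittableRandom localRNG) {
            double[][] weights = new double[numOutputs][numFeatures];
            for (double[] row : weights) {
                for (int j = 0; j < row.length; j++) {
                    // Assumed initialisation: small values in [-0.01, 0.01).
                    row[j] = localRNG.nextDouble() * 0.02 - 0.01;
                }
            }
            return weights;
        }
    }

    /** Convenience entry point returning a numOutputs x numFeatures matrix. */
    public static double[][] createWeights(int numFeatures, int numOutputs, long seed) {
        return new SketchLinearTrainer().initialise(numFeatures, numOutputs, seed);
    }

    public static void main(String[] args) {
        double[][] weights = createWeights(4, 3, 1L);
        System.out.println(weights.length + " x " + weights[0].length); // prints 3 x 4
    }
}
```

Keeping the hook abstract in the base class lets the shared training loop stay independent of the parameter layout, so a different model family only needs to supply its own parameter object.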
-