Package org.tribuo.common.sgd
Class AbstractFMTrainer<T extends Output<T>,U,V extends AbstractFMModel<T>>
java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<T,U,V,FMParameters>
org.tribuo.common.sgd.AbstractFMTrainer<T,U,V>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
- Direct Known Subclasses:
FMClassificationTrainer, FMMultiLabelTrainer, FMRegressionTrainer
public abstract class AbstractFMTrainer<T extends Output<T>,U,V extends AbstractFMModel<T>>
extends AbstractSGDTrainer<T,U,V,FMParameters>
A trainer for a quadratic (degree-2) factorization machine model that is trained with SGD. It is an AbstractSGDTrainer operating on FMParameters.
See:
Rendle, S. Factorization machines. 2010 IEEE International Conference on Data Mining
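For reference, the degree-2 model in Rendle (2010) scores an input x as y(x) = w0 + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j. Here each feature i has a factor vector v_i, and the length of v_i is the factorized dimension. The Java sketch that follows evaluates this equation for a single output in two ways: directly, and with Rendle's O(kn) rewrite of the pairwise term. The class and method names are hypothetical, and the code is not how Tribuo's AbstractFMModel computes predictions.

```java
/**
 * Minimal sketch of a degree-2 factorization machine score for a single output,
 * following the model equation in Rendle (2010). Hypothetical helper, not Tribuo code.
 */
final class FMPredictionSketch {

    private FMPredictionSketch() {}

    /**
     * Computes w0 + sum_i w[i] x[i] + sum_{i<j} <v[i], v[j]> x[i] x[j]
     * in O(k n) time using Rendle's reformulation of the pairwise term.
     *
     * @param w0 global bias
     * @param w  linear feature weights, length n
     * @param v  factor matrix, n rows, each of length k (the factorized dimension)
     * @param x  dense feature vector, length n
     */
    static double predict(double w0, double[] w, double[][] v, double[] x) {
        double linear = w0;
        for (int i = 0; i < x.length; i++) {
            linear += w[i] * x[i];
        }
        int k = v[0].length;
        double pairwise = 0.0;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;   // sum_i v[i][f] * x[i]
            double sumSq = 0.0; // sum_i (v[i][f] * x[i])^2
            for (int i = 0; i < x.length; i++) {
                double t = v[i][f] * x[i];
                sum += t;
                sumSq += t * t;
            }
            pairwise += sum * sum - sumSq;
        }
        return linear + 0.5 * pairwise;
    }

    /** Direct O(k n^2) evaluation of the same model, useful as a cross-check. */
    static double predictNaive(double w0, double[] w, double[][] v, double[] x) {
        double y = w0;
        for (int i = 0; i < x.length; i++) {
            y += w[i] * x[i];
        }
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                y += dot * x[i] * x[j];
            }
        }
        return y;
    }

    public static void main(String[] args) {
        double[] w = {0.5, -1.0, 2.0};
        double[][] v = {{1.0, 0.0}, {0.0, 1.0}, {1.0, 1.0}};
        double[] x = {1.0, 2.0, 3.0};
        System.out.println(predict(0.1, w, v, x));      // linear part 4.6 plus pairwise part 9.0
        System.out.println(predictNaive(0.1, w, v, x)); // same value, computed pair by pair
    }
}
```

The fast path uses the identity sum_{i<j} <v_i, v_j> x_i x_j = 1/2 sum_f [(sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2]. Because of this identity, both the score and its gradient cost time linear in the number of features. That property is what makes SGD training of factorization machines practical.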
Field Summary
Fields inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
addBias, epochs, loggingInterval, minibatchSize, optimiser, rng, seed, shuffle
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT
Constructor Summary
protected AbstractFMTrainer()
For OLCUT.
protected AbstractFMTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance)
Constructs an SGD trainer for a factorization machine.
Method Summary
protected FMParameters createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG)
Constructs the trainable parameters object, in this case an FMParameters containing a weight matrix for the feature weights and a series of weight matrices for the factorized feature representation.
protected String getName()
Returns the default model name.
void postConfig()
Used by the OLCUT configuration system, and should not be called by external code.
Methods inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
createModel, getInvocationCount, getModelClassName, getObjective, getProvenance, getTarget, setInvocationCount, setShuffle, shuffleInPlace, train, train, train
Field Details
factorizedDimSize
@Config(mandatory=true, description="The size of the factorized feature representation.") protected int factorizedDimSize
variance
@Config(mandatory=true, description="The variance of the initializer.") protected double variance
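This page says only that variance is "the variance of the initializer". The Java sketch that follows assumes a zero-mean Gaussian initializer and shows how factorizedDimSize and variance could together determine a factor matrix. Note that draws are scaled by sqrt(variance), not by variance itself. FactorInitSketch and its methods are hypothetical. Standard normals come from the Box-Muller transform, because SplittableRandom (the RNG type that createParameters receives) has no nextGaussian method before Java 17.

```java
import java.util.SplittableRandom;

/**
 * Hypothetical sketch of a Gaussian factor initializer, ASSUMING "variance" is the
 * variance of a zero-mean normal distribution. Not Tribuo's initialisation code.
 */
final class FactorInitSketch {

    private FactorInitSketch() {}

    /** Standard normal draw via the Box-Muller transform (works with Java 8+ SplittableRandom). */
    static double standardNormal(SplittableRandom rng) {
        double u1 = 1.0 - rng.nextDouble(); // in (0, 1], so log(u1) is finite
        double u2 = rng.nextDouble();
        return Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2);
    }

    /** Returns a numFeatures x factorizedDimSize matrix with entries drawn from N(0, variance). */
    static double[][] initFactors(int numFeatures, int factorizedDimSize, double variance,
                                  SplittableRandom rng) {
        if (factorizedDimSize < 1) {
            throw new IllegalArgumentException("factorizedDimSize must be at least 1");
        }
        if (variance < 0.0) {
            throw new IllegalArgumentException("variance must be non-negative");
        }
        double stdDev = Math.sqrt(variance); // scale a standard normal by sqrt(variance)
        double[][] factors = new double[numFeatures][factorizedDimSize];
        for (int i = 0; i < numFeatures; i++) {
            for (int f = 0; f < factorizedDimSize; f++) {
                factors[i][f] = stdDev * standardNormal(rng);
            }
        }
        return factors;
    }

    /** Population variance of all entries, for sanity-checking the initializer. */
    static double empiricalVariance(double[][] m) {
        double sum = 0.0;
        double sumSq = 0.0;
        long count = 0;
        for (double[] row : m) {
            for (double value : row) {
                sum += value;
                sumSq += value * value;
                count++;
            }
        }
        double mean = sum / count;
        return sumSq / count - mean * mean;
    }

    public static void main(String[] args) {
        double[][] v = initFactors(1000, 8, 0.01, new SplittableRandom(42L));
        System.out.printf("empirical variance: %.5f%n", empiricalVariance(v)); // should be near 0.01
    }
}
```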
Constructor Details
AbstractFMTrainer
protected AbstractFMTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance)
Constructs an SGD trainer for a factorization machine.
- Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
minibatchSize - The size of each minibatch.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
factorizedDimSize - The size of the factorized feature representation.
variance - The variance of the initializer.
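The constructor parameters describe a standard minibatch SGD schedule. The Java sketch that follows walks through that schedule without doing any of the maths, to show how epochs, minibatchSize, loggingInterval and seed interact. It assumes that one "iteration" is one minibatch step and that logging fires whenever the iteration count is a multiple of loggingInterval. Both are assumptions, and SgdScheduleSketch is a hypothetical class, not the AbstractSGDTrainer training loop.

```java
import java.util.SplittableRandom;

/**
 * Hypothetical sketch of how epochs, minibatchSize, loggingInterval and seed interact
 * in a generic minibatch SGD loop. Not the AbstractSGDTrainer implementation.
 */
final class SgdScheduleSketch {

    private SgdScheduleSketch() {}

    /** In-place Fisher-Yates shuffle using the supplied RNG. */
    static void shuffle(int[] indices, SplittableRandom rng) {
        for (int i = indices.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
    }

    /** Returns {SGD iterations, loss log lines, examples visited}. */
    static int[] run(int numExamples, int epochs, int minibatchSize, int loggingInterval, long seed) {
        SplittableRandom rng = new SplittableRandom(seed);
        int[] order = new int[numExamples];
        for (int i = 0; i < numExamples; i++) {
            order[i] = i;
        }
        int iteration = 0;
        int logLines = 0;
        int examplesSeen = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            shuffle(order, rng); // new example order every epoch, reproducible from the seed
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples); // last batch may be smaller
                // A real trainer would compute the loss and gradient over order[start..end)
                // here and pass the gradient to the optimiser.
                examplesSeen += end - start;
                iteration++;
                // loggingInterval == -1 disables logging; the > 0 guard also avoids % 0.
                if (loggingInterval > 0 && iteration % loggingInterval == 0) {
                    logLines++;
                }
            }
        }
        return new int[] {iteration, logLines, examplesSeen};
    }

    public static void main(String[] args) {
        // 10 examples in minibatches of 3 -> 4 iterations per epoch, 20 over 5 epochs.
        int[] result = run(10, 5, 3, 7, 1L);
        System.out.println("iterations=" + result[0] + ", log lines=" + result[1]);
    }
}
```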
AbstractFMTrainer
protected AbstractFMTrainer()
For OLCUT.
Method Details
postConfig
public void postConfig()
Description copied from class: AbstractSGDTrainer
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Overrides:
postConfig in class AbstractSGDTrainer<T extends Output<T>, U, V extends AbstractFMModel<T>, FMParameters>
getName
protected String getName()
Returns the default model name.
- Specified by:
getName in class AbstractSGDTrainer<T extends Output<T>, U, V extends AbstractFMModel<T>, FMParameters>
- Returns:
The default model name.
createParameters
protected FMParameters createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG)
Constructs the trainable parameters object, in this case an FMParameters containing a weight matrix for the feature weights and a series of weight matrices for the factorized feature representation.
- Specified by:
createParameters in class AbstractSGDTrainer<T extends Output<T>, U, V extends AbstractFMModel<T>, FMParameters>
- Parameters:
numFeatures - The number of input features.
numOutputs - The number of output dimensions.
localRNG - The RNG to use for parameter initialisation.
- Returns:
The trainable parameters.
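The description above mentions one weight matrix for the linear feature weights plus "a series of weight matrices" for the factorized representation. The Java sketch that follows shows one plausible layout and the parameter count it implies. It assumes one factor matrix per output dimension, each shaped numFeatures x factorizedDimSize, and it ignores any separate bias term. These are assumptions, and FMParametersSketch is a hypothetical class, not Tribuo's FMParameters. The factor initializer is passed in as a DoubleSupplier so the sketch stays agnostic about the distribution.

```java
import java.util.SplittableRandom;
import java.util.function.DoubleSupplier;

/**
 * Hypothetical parameter layout matching the description of createParameters: one linear
 * weight matrix plus one factor matrix per output dimension (ASSUMED). Not Tribuo's FMParameters.
 */
final class FMParametersSketch {

    final double[][] weights;   // [numOutputs][numFeatures], linear feature weights
    final double[][][] factors; // [numOutputs][numFeatures][factorizedDimSize]

    FMParametersSketch(int numFeatures, int numOutputs, int factorizedDimSize,
                       DoubleSupplier factorInit) {
        weights = new double[numOutputs][numFeatures]; // left at zero in this sketch
        factors = new double[numOutputs][numFeatures][factorizedDimSize];
        for (double[][] perOutput : factors) {
            for (double[] row : perOutput) {
                for (int f = 0; f < factorizedDimSize; f++) {
                    // All-zero factors would receive zero gradient, so they need random starts.
                    row[f] = factorInit.getAsDouble();
                }
            }
        }
    }

    /** Total trainable values: numOutputs * numFeatures * (1 + factorizedDimSize). */
    long numParameters() {
        long total = 0;
        for (double[] row : weights) {
            total += row.length;
        }
        for (double[][] perOutput : factors) {
            for (double[] row : perOutput) {
                total += row.length;
            }
        }
        return total;
    }

    public static void main(String[] args) {
        SplittableRandom localRNG = new SplittableRandom(7L);
        // Placeholder initializer: small uniform values in [-0.05, 0.05).
        FMParametersSketch params =
            new FMParametersSketch(100, 3, 10, () -> 0.1 * (localRNG.nextDouble() - 0.5));
        System.out.println(params.numParameters()); // 3 * 100 * (1 + 10)
    }
}
```

Under this layout, the factor matrices dominate the parameter count as soon as factorizedDimSize is larger than 1. The factorized dimension is therefore the main lever on model size.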