Class AbstractFMTrainer<T extends Output<T>, U, V extends AbstractFMModel<T>>
java.lang.Object
org.tribuo.common.sgd.AbstractSGDTrainer<T, U, V, FMParameters>
org.tribuo.common.sgd.AbstractFMTrainer<T,U,V>
- All Implemented Interfaces:
com.oracle.labs.mlrg.olcut.config.Configurable, com.oracle.labs.mlrg.olcut.provenance.Provenancable<TrainerProvenance>, Trainer<T>, WeightedExamples
- Direct Known Subclasses:
FMClassificationTrainer, FMMultiLabelTrainer, FMRegressionTrainer
public abstract class AbstractFMTrainer<T extends Output<T>, U, V extends AbstractFMModel<T>>
extends AbstractSGDTrainer<T, U, V, FMParameters>
A trainer for a quadratic (degree-2) factorization machine model, trained with stochastic gradient descent (SGD).
It is an AbstractSGDTrainer operating on FMParameters.
See:
Rendle, S. "Factorization Machines." In 2010 IEEE International Conference on Data Mining (ICDM).
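For reference, the degree-2 factorization machine in Rendle's paper predicts y(x) = w_0 + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j, where each feature i has a k-dimensional factor vector v_i (k plays the role of factorizedDimSize). The standalone sketch below does not use Tribuo, and its class and method names are hypothetical. It evaluates the model for a single output both naively and with Rendle's O(kn) reformulation, and computes the per-factor partial derivative that an SGD update would use. How Tribuo handles multiple outputs or the bias term is not shown here.

```java
import java.util.Random;

/**
 * Standalone sketch of a degree-2 factorization machine with a single output.
 * Hypothetical names; this is not Tribuo's AbstractFMModel.
 */
public class FMSketch {
    /** Naive O(k * n^2) evaluation: w0 + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j. */
    static double predictNaive(double w0, double[] w, double[][] v, double[] x) {
        double y = w0;
        for (int i = 0; i < x.length; i++) {
            y += w[i] * x[i];
        }
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                y += dot * x[i] * x[j];
            }
        }
        return y;
    }

    /**
     * O(k * n) evaluation using Rendle's identity:
     * sum_{i<j} <v_i, v_j> x_i x_j = 0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2].
     */
    static double predictFast(double w0, double[] w, double[][] v, double[] x) {
        double y = w0;
        for (int i = 0; i < x.length; i++) {
            y += w[i] * x[i];
        }
        int k = v[0].length;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;
            double sumSq = 0.0;
            for (int i = 0; i < x.length; i++) {
                double t = v[i][f] * x[i];
                sum += t;
                sumSq += t * t;
            }
            y += 0.5 * (sum * sum - sumSq);
        }
        return y;
    }

    /** Partial derivative of the prediction w.r.t. v_if: x_i * sum_j v_jf x_j - v_if * x_i^2. */
    static double gradFactor(double[][] v, double[] x, int i, int f) {
        double sum = 0.0;
        for (int j = 0; j < x.length; j++) {
            sum += v[j][f] * x[j];
        }
        return x[i] * sum - v[i][f] * x[i] * x[i];
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        int n = 5; // number of features
        int k = 3; // factorized dimension, the role played by factorizedDimSize
        double[] x = new double[n];
        double[] w = new double[n];
        double[][] v = new double[n][k];
        for (int i = 0; i < n; i++) {
            x[i] = rng.nextDouble();
            w[i] = rng.nextGaussian();
            for (int f = 0; f < k; f++) {
                v[i][f] = 0.1 * rng.nextGaussian();
            }
        }
        // Both evaluations agree up to floating-point rounding.
        System.out.println(predictNaive(0.5, w, v, x));
        System.out.println(predictFast(0.5, w, v, x));
        System.out.println(gradFactor(v, x, 0, 0));
    }
}
```

The O(kn) form is what makes factorization machines practical on wide, sparse inputs: the pairwise term never has to enumerate feature pairs.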
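Field Summary
Fields
- protected int factorizedDimSize: The size of the factorized feature representation.
- protected double variance: The variance of the initializer.
Fields inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
addBias, epochs, loggingInterval, minibatchSize, optimiser, rng, seed, shuffle
Fields inherited from interface org.tribuo.Trainer
DEFAULT_SEED, INCREMENT_INVOCATION_COUNT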
Constructor Summary
Constructors
- protected AbstractFMTrainer(): For OLCUT.
- protected AbstractFMTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance): Constructs an SGD trainer for a factorization machine.
Method Summary
- protected FMParameters createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG): Constructs the trainable parameters object, in this case an FMParameters containing a weight matrix for the feature weights and a series of weight matrices for the factorized feature representation.
- protected String getName(): Returns the default model name.
- public void postConfig(): Used by the OLCUT configuration system, and should not be called by external code.
Methods inherited from class org.tribuo.common.sgd.AbstractSGDTrainer
createModel, getInvocationCount, getModelClassName, getObjective, getProvenance, getTarget, setInvocationCount, setShuffle, shuffleInPlace, train, train, train
Field Details
factorizedDimSize
@Config(mandatory=true, description="The size of the factorized feature representation.")
protected int factorizedDimSize
variance
@Config(mandatory=true, description="The variance of the initializer.")
protected double variance
Constructor Details
AbstractFMTrainer
protected AbstractFMTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed, int factorizedDimSize, double variance)
Constructs an SGD trainer for a factorization machine.
- Parameters:
optimiser - The gradient optimiser to use.
epochs - The number of epochs (complete passes through the training data).
loggingInterval - Log the loss after this many iterations. If -1, nothing is logged.
minibatchSize - The number of examples in each minibatch.
seed - A seed for the random number generator, used to shuffle the examples before each epoch.
factorizedDimSize - The size of the factorized feature representation.
variance - The variance of the initializer.
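The constructor parameters map onto a conventional minibatch SGD loop. The standalone sketch below is a hypothetical illustration of such a loop, not Tribuo's AbstractSGDTrainer code: the class and method names are invented, it counts one logging iteration per minibatch (an assumption, since the documentation does not say whether iterations count examples or minibatches), and it omits the loss, gradients and optimiser entirely. It shows the seed fixing a reproducible shuffle, a reshuffle before each epoch, minibatches of at most minibatchSize examples, and a loggingInterval of -1 disabling logging.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.SplittableRandom;

/**
 * Hypothetical sketch of how epochs, minibatchSize, loggingInterval and seed
 * typically drive a minibatch SGD loop. Not Tribuo's AbstractSGDTrainer.
 */
public class SgdLoopSketch {
    /** In-place Fisher-Yates shuffle driven by the supplied RNG. */
    static void shuffle(int[] indices, SplittableRandom rng) {
        for (int i = indices.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
    }

    /** Runs the loop skeleton and returns every minibatch (as example indices) in visiting order. */
    static List<int[]> run(int numExamples, int epochs, int minibatchSize, int loggingInterval, long seed) {
        SplittableRandom rng = new SplittableRandom(seed); // same seed => same visiting order
        int[] indices = new int[numExamples];
        for (int i = 0; i < numExamples; i++) {
            indices[i] = i;
        }
        List<int[]> batches = new ArrayList<>();
        int iteration = 0; // this sketch counts one iteration per minibatch (an assumption)
        for (int epoch = 0; epoch < epochs; epoch++) {
            shuffle(indices, rng); // reshuffle before each epoch
            for (int start = 0; start < numExamples; start += minibatchSize) {
                int end = Math.min(start + minibatchSize, numExamples); // last batch may be smaller
                batches.add(Arrays.copyOfRange(indices, start, end));
                // A real trainer would compute the loss and gradients for this batch here
                // and hand them to the StochasticGradientOptimiser.
                iteration++;
                if (loggingInterval > 0 && iteration % loggingInterval == 0) { // -1 disables logging
                    System.out.println("epoch " + epoch + ", iteration " + iteration + ": log the loss here");
                }
            }
        }
        return batches;
    }

    public static void main(String[] args) {
        // 10 examples, 2 epochs, minibatches of 3, log every 2 minibatches, seed 12345.
        List<int[]> batches = run(10, 2, 3, 2, 12345L);
        for (int[] batch : batches) {
            System.out.println(Arrays.toString(batch));
        }
    }
}
```

Copying each minibatch out of the index array matters here because the array is reshuffled in place at the start of the next epoch.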
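AbstractFMTrainer
protected AbstractFMTrainer()
For OLCUT: the no-argument constructor used by the configuration system, which then populates the @Config fields.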
Method Details
postConfig
public void postConfig()
Description copied from class: AbstractSGDTrainer
Used by the OLCUT configuration system, and should not be called by external code.
- Specified by:
postConfig in interface com.oracle.labs.mlrg.olcut.config.Configurable
- Overrides:
postConfig in class AbstractSGDTrainer<T extends Output<T>, U, V extends AbstractFMModel<T>, FMParameters>
getName
protected String getName()
Returns the default model name.
- Specified by:
getName in class AbstractSGDTrainer<T extends Output<T>, U, V extends AbstractFMModel<T>, FMParameters>
- Returns:
The default model name.
createParameters
protected FMParameters createParameters(int numFeatures, int numOutputs, SplittableRandom localRNG)
Constructs the trainable parameters object, in this case an FMParameters containing a weight matrix for the feature weights and a series of weight matrices for the factorized feature representation.
- Specified by:
createParameters in class AbstractSGDTrainer<T extends Output<T>, U, V extends AbstractFMModel<T>, FMParameters>
- Parameters:
numFeatures - The number of input features.
numOutputs - The number of output dimensions.
localRNG - The RNG to use for parameter initialisation.
- Returns:
The trainable parameters.
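As a rough picture of what createParameters builds, the sketch below allocates a linear weight matrix and one factor matrix per output, filling the factors from a zero-mean Gaussian whose variance is the variance config value. The class name, the array layout, the zero-initialised linear weights and the Gaussian initializer are all assumptions for illustration, not a description of Tribuo's FMParameters; the sketch also ignores any bias term (compare the inherited addBias field). It samples with the Box-Muller transform because SplittableRandom has no nextGaussian() before Java 17.

```java
import java.util.SplittableRandom;

/**
 * Hypothetical stand-in for the trainable parameters of a factorization machine:
 * a linear weight matrix plus one factor matrix per output. Not Tribuo's FMParameters.
 */
public class FMInitSketch {
    final double[][] weights;   // [numOutputs][numFeatures], linear term, zero-initialised (assumption)
    final double[][][] factors; // [numOutputs][factorizedDimSize][numFeatures] (layout is an assumption)

    FMInitSketch(int numFeatures, int numOutputs, int factorizedDimSize, double variance, SplittableRandom localRNG) {
        weights = new double[numOutputs][numFeatures]; // Java zero-fills new arrays
        factors = new double[numOutputs][factorizedDimSize][numFeatures];
        double stdDev = Math.sqrt(variance); // the config gives a variance; sampling needs a std dev
        for (int o = 0; o < numOutputs; o++) {
            for (int f = 0; f < factorizedDimSize; f++) {
                for (int i = 0; i < numFeatures; i++) {
                    factors[o][f][i] = stdDev * standardNormal(localRNG);
                }
            }
        }
    }

    /** Standard normal sample via Box-Muller, using only SplittableRandom.nextDouble(). */
    static double standardNormal(SplittableRandom rng) {
        double u1 = 1.0 - rng.nextDouble(); // in (0, 1], so log(u1) is finite
        double u2 = rng.nextDouble();
        return Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2);
    }

    public static void main(String[] args) {
        FMInitSketch p = new FMInitSketch(1000, 2, 10, 0.25, new SplittableRandom(7L));
        double sum = 0.0;
        double sumSq = 0.0;
        int count = 0;
        for (double[][] perOutput : p.factors) {
            for (double[] row : perOutput) {
                for (double value : row) {
                    sum += value;
                    sumSq += value * value;
                    count++;
                }
            }
        }
        double mean = sum / count;
        System.out.println("empirical mean = " + mean + ", empirical variance = " + (sumSq / count - mean * mean));
    }
}
```

Small random factors, rather than zeros, are the usual choice: if every v_i started at zero, the gradient with respect to each factor (x_i * sum_j v_jf x_j - v_if * x_i^2) would also be zero and the factors would never move.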