Class FMParameters

java.lang.Object
org.tribuo.common.sgd.FMParameters
All Implemented Interfaces:
Serializable, FeedForwardParameters, Parameters

public final class FMParameters extends Object implements FeedForwardParameters
A Parameters implementation for factorization machines.
  • Constructor Details

    • FMParameters

      public FMParameters(SplittableRandom rng, int numFeatures, int numLabels, int numFactors, double variance)
      Constructor. The number of features and the number of outputs must be fixed and known in advance.
      Parameters:
      rng - The RNG to use for initialization.
      numFeatures - The number of features in the training dataset.
      numLabels - The number of outputs in the training dataset.
      numFactors - The size of the factorized feature representation.
      variance - The variance of the factor initializer.
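The variance argument is easy to confuse with a standard deviation. The following plain-Java sketch shows one way a factor matrix could be filled from an initializer with a given variance. It assumes a zero-mean Gaussian and a numFactors x numFeatures layout; neither is stated above. The class FactorInitSketch and its method are hypothetical and are not part of Tribuo.

```java
import java.util.Random;
import java.util.SplittableRandom;

/** Hypothetical helper, not part of Tribuo: fills a factor matrix with zero-mean Gaussian draws. */
final class FactorInitSketch {
    /**
     * @param rng         source of randomness, as passed to the constructor
     * @param numFactors  number of rows in the factor matrix (assumed layout)
     * @param numFeatures number of columns in the factor matrix (assumed layout)
     * @param variance    variance of the assumed zero-mean Gaussian initializer
     */
    static double[][] initFactors(SplittableRandom rng, int numFactors, int numFeatures, double variance) {
        // java.util.Random.nextGaussian() draws from N(0, 1); it is seeded from the supplied rng.
        Random gaussianSource = new Random(rng.nextLong());
        // Scaling a standard normal draw by the standard deviation gives the requested variance.
        double stdDev = Math.sqrt(variance);
        double[][] factors = new double[numFactors][numFeatures];
        for (int f = 0; f < numFactors; f++) {
            for (int i = 0; i < numFeatures; i++) {
                factors[f][i] = gaussianSource.nextGaussian() * stdDev;
            }
        }
        return factors;
    }
}
```

Because the generator is seeded from rng, two calls with identically seeded SplittableRandom instances produce identical matrices.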
  • Method Details

    • predict

      public DenseVector predict(SGDVector example)
      Generates an unnormalised prediction by multiplying the weights with the incoming features, adding the bias, and adding the contribution from the feature factors.
      Specified by:
      predict in interface FeedForwardParameters
      Parameters:
      example - A feature vector.
      Returns:
      A DenseVector containing a score for each label.
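As an illustration of the three terms in the prediction (bias, linear weights, feature factors), here is a plain-Java sketch of the standard factorization machine score for dense inputs. It is not Tribuo's implementation: the class FMPredictSketch, the array layouts, and the use of double arrays in place of DenseVector and SGDVector are all assumptions made for the example. The pairwise term uses the usual identity that avoids an explicit loop over feature pairs.

```java
/** Plain-Java sketch of a factorization machine score; not Tribuo's implementation. */
final class FMPredictSketch {
    /**
     * @param bias    bias[l], one bias per label
     * @param weights weights[l][i], linear weight of feature i for label l
     * @param factors factors[l][f][i], factor f of feature i for label l (assumed layout)
     * @param x       dense feature vector
     * @return one unnormalised score per label
     */
    static double[] predict(double[] bias, double[][] weights, double[][][] factors, double[] x) {
        double[] scores = new double[bias.length];
        for (int l = 0; l < bias.length; l++) {
            // Bias plus the linear term w_l . x
            double score = bias[l];
            for (int i = 0; i < x.length; i++) {
                score += weights[l][i] * x[i];
            }
            // Pairwise term: sum_{i<j} <v_i, v_j> x_i x_j
            //   = 0.5 * sum_f [ (sum_i v_fi x_i)^2 - sum_i (v_fi x_i)^2 ]
            double interaction = 0.0;
            for (double[] factorRow : factors[l]) {
                double sum = 0.0;
                double sumOfSquares = 0.0;
                for (int i = 0; i < x.length; i++) {
                    double term = factorRow[i] * x[i];
                    sum += term;
                    sumOfSquares += term * term;
                }
                interaction += sum * sum - sumOfSquares;
            }
            scores[l] = score + 0.5 * interaction;
        }
        return scores;
    }
}
```

For example, with bias 0.5, weights (1, 2), a single factor row (3, 4), and x = (1, 1), the linear part is 3.5 and the pairwise part is 3 * 4 = 12, giving a score of 15.5.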
    • gradients

      public Tensor[] gradients(com.oracle.labs.mlrg.olcut.util.Pair<Double,SGDVector> score, SGDVector features)
      Generates the gradients for a particular feature vector given the loss and the per-output gradients.

      This method returns a Tensor array with numLabels + 2 elements.

      Specified by:
      gradients in interface FeedForwardParameters
      Parameters:
      score - The Pair of loss and per-output gradients returned by the objective.
      features - The feature vector.
      Returns:
      A Tensor array containing all the gradients.
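To make the factor part of the gradient concrete, the sketch below computes the textbook derivative of the pairwise interaction term with respect to one label's factors, scaled by that label's output gradient. It is a plain-Java illustration under assumed array layouts, not Tribuo's code, and FMGradientSketch is a hypothetical name. How the numLabels + 2 returned tensors are ordered is not stated above, so the sketch covers a single factor matrix only.

```java
/** Plain-Java sketch of the gradient of the FM interaction term; not Tribuo's implementation. */
final class FMGradientSketch {
    /**
     * @param factors        factors[f][i], factor f of feature i for one label (assumed layout)
     * @param x              dense feature vector
     * @param outputGradient gradient of the loss with respect to this label's score
     * @return gradient of the loss with respect to each factor entry
     */
    static double[][] factorGradient(double[][] factors, double[] x, double outputGradient) {
        double[][] grad = new double[factors.length][x.length];
        for (int f = 0; f < factors.length; f++) {
            // s_f = sum_j v_fj x_j, shared by every feature in this factor row
            double sum = 0.0;
            for (int i = 0; i < x.length; i++) {
                sum += factors[f][i] * x[i];
            }
            // d(interaction)/d(v_fi) = x_i * (s_f - v_fi x_i); chain rule multiplies by outputGradient
            for (int i = 0; i < x.length; i++) {
                grad[f][i] = outputGradient * x[i] * (sum - factors[f][i] * x[i]);
            }
        }
        return grad;
    }
}
```

With a single factor row (3, 4), x = (1, 1), and an output gradient of 2, the interaction term is v1 * v2, so the gradients are 2 * 4 = 8 and 2 * 3 = 6.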
    • getEmptyCopy

      public Tensor[] getEmptyCopy()
      Returns an empty copy of the parameters: a Tensor array in which each element has the same size as the corresponding parameter Tensor.
      Specified by:
      getEmptyCopy in interface Parameters
      Returns:
      A Tensor array of empty Tensors with the same sizes as the parameters.
    • get

      public Tensor[] get()
      Description copied from interface: Parameters
      Get a reference to the underlying Tensor array.
      Specified by:
      get in interface Parameters
      Returns:
      The parameters.
    • set

      public void set(Tensor[] newWeights)
      Description copied from interface: Parameters
      Set the underlying Tensor array to newWeights.
      Specified by:
      set in interface Parameters
      Parameters:
      newWeights - New parameters to store in this object.
    • update

      public void update(Tensor[] gradients)
      Description copied from interface: Parameters
      Apply gradients to the parameters. Assumes that gradients is the same length as the parameters, and each Tensor is the same size as the corresponding one from the parameters.

      The gradients are added to the parameters.

      Specified by:
      update in interface Parameters
      Parameters:
      gradients - A Tensor array of updates, with the length equal to Parameters.get().length.
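Because the gradients are added rather than subtracted, the caller is responsible for the sign and scale of the update. The following plain-Java sketch mirrors that contract, with each Tensor flattened to a double array for simplicity. UpdateSketch is a hypothetical class, not part of Tribuo.

```java
/** Plain-Java sketch of the "gradients are added" contract; not Tribuo's implementation. */
final class UpdateSketch {
    /**
     * Adds each gradient element to the matching parameter element, in place.
     *
     * @param parameters parameters[t][i], element i of parameter tensor t (flattened)
     * @param gradients  updates with the same shape as parameters
     */
    static void update(double[][] parameters, double[][] gradients) {
        if (parameters.length != gradients.length) {
            throw new IllegalArgumentException("Expected " + parameters.length
                    + " gradient tensors, found " + gradients.length);
        }
        for (int t = 0; t < parameters.length; t++) {
            if (parameters[t].length != gradients[t].length) {
                throw new IllegalArgumentException("Size mismatch in tensor " + t);
            }
            for (int i = 0; i < parameters[t].length; i++) {
                // Addition, not subtraction: a descent step must arrive already negated and scaled.
                parameters[t][i] += gradients[t][i];
            }
        }
    }
}
```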
    • merge

      public Tensor[] merge(Tensor[][] gradients, int size)
      Description copied from interface: Parameters
      Merge together an array of gradient arrays. Assumes the first dimension is the number of gradient arrays and the second dimension is the number of parameter Tensors.

      For performance reasons this call may mutate the input gradient array, and may return a subset of those elements as the merge output.

      Specified by:
      merge in interface Parameters
      Parameters:
      gradients - An array of gradient update arrays.
      size - The number of elements of gradients to merge. Allows gradients to have unused elements.
      Returns:
      A single Tensor array of the summed gradients.
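The note about mutation can be illustrated with a plain-Java sketch that sums the first size gradient arrays into the first one and returns it. This is one possible strategy consistent with the description above, not a statement of how Tribuo merges. MergeSketch is a hypothetical class, and each Tensor is flattened to a double array.

```java
/** Plain-Java sketch of an in-place gradient merge; not Tribuo's implementation. */
final class MergeSketch {
    /**
     * Sums the first {@code size} gradient arrays element-wise.
     *
     * @param gradients gradients[g][t][i], element i of tensor t in gradient array g
     * @param size      number of leading entries of gradients to merge
     * @return gradients[0], which now holds the summed values
     */
    static double[][] merge(double[][][] gradients, int size) {
        if (size < 1 || size > gradients.length) {
            throw new IllegalArgumentException("size must be between 1 and " + gradients.length);
        }
        // Reuse the first gradient array as the accumulator, so the input is mutated.
        double[][] merged = gradients[0];
        for (int g = 1; g < size; g++) {
            for (int t = 0; t < merged.length; t++) {
                for (int i = 0; i < merged[t].length; i++) {
                    merged[t][i] += gradients[g][t][i];
                }
            }
        }
        return merged;
    }
}
```

A caller that still needs the individual gradient arrays after merging should copy them first, since the returned array aliases the first input.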
    • copy

      public FMParameters copy()
      Description copied from interface: FeedForwardParameters
      Returns a copy of the parameters.
      Specified by:
      copy in interface FeedForwardParameters
      Returns:
      A copy of the model parameters.