Class ShrinkingVector

java.lang.Object
org.tribuo.math.la.DenseVector
org.tribuo.math.optimisers.util.ShrinkingVector
All Implemented Interfaces:
Serializable, Iterable<VectorTuple>, SGDVector, Tensor, ShrinkingTensor, ProtoSerializable<org.tribuo.math.protos.TensorProto>

public class ShrinkingVector extends DenseVector implements ShrinkingTensor
A subclass of DenseVector which shrinks its values every time an update is added via intersectAndAddInPlace(Tensor, DoubleUnaryOperator).

Be careful when modifying this or DenseVector.

  • Constructor Details

    • ShrinkingVector

      public ShrinkingVector(DenseVector v, double baseRate, boolean scaleShrinking)
      Constructs a shrinking vector copy of the supplied dense vector.

      This vector shrinks during each call to intersectAndAddInPlace(Tensor, DoubleUnaryOperator).

      Parameters:
      v - The vector to copy.
      baseRate - The base amount of shrinking to apply after each update.
      scaleShrinking - If true, reduce the amount of shrinking over time in proportion to the number of updates.
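The two constructor modes can be pictured as a multiplicative shrink factor applied on each update. The sketch below is a minimal illustration under stated assumptions, not Tribuo's code. It assumes that a fixed-rate vector is multiplied by 1 - baseRate per update, and that scaleShrinking divides the rate by the 1-based update count t, giving 1 - baseRate / t. The class ShrinkSchedule and the method shrinkFactor are hypothetical names.

```java
// Hypothetical sketch of a per-update shrink factor; not Tribuo's implementation.
// Assumption: fixed mode uses (1 - baseRate); scaled mode uses (1 - baseRate / t),
// where t is the 1-based number of updates applied so far.
class ShrinkSchedule {
    static double shrinkFactor(double baseRate, boolean scaleShrinking, int t) {
        if (t < 1) {
            throw new IllegalArgumentException("update count t must be >= 1, got " + t);
        }
        double rate = scaleShrinking ? baseRate / t : baseRate;
        return 1.0 - rate;
    }

    public static void main(String[] args) {
        double baseRate = 0.1;
        for (int t = 1; t <= 4; t++) {
            System.out.printf("t=%d fixed=%.4f scaled=%.4f%n",
                    t, shrinkFactor(baseRate, false, t), shrinkFactor(baseRate, true, t));
        }
    }
}
```

With baseRate = 0.1, the fixed mode keeps removing 10% of each weight per update, whereas under the assumed scaled schedule the removal drops to 5% at t = 2, about 3.3% at t = 3, and so on.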
    • ShrinkingVector

      public ShrinkingVector(DenseVector v, double baseRate, double lambda)
      Constructs a shrinking vector copy of the supplied dense vector.

      This vector shrinks during each call to intersectAndAddInPlace(Tensor, DoubleUnaryOperator), and then reprojects the vector so it has the same twoNorm.

      Parameters:
      v - The vector to copy.
      baseRate - The base rate of shrinkage.
      lambda - The lambda value (see Pegasos).
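Pegasos-style training typically follows each update with a projection onto the ball of radius 1/sqrt(lambda), scaling w by min(1, (1/sqrt(lambda)) / ||w||). The sketch below applies that standard Pegasos rule to a plain array. Whether this constructor's reprojection uses exactly this rule is an assumption, and PegasosProjection.project is a hypothetical helper.

```java
// Sketch of the standard Pegasos projection step on a plain array.
// Assumption: ShrinkingVector's reprojection follows this rule; not confirmed here.
// After the call, ||w|| <= 1 / sqrt(lambda).
class PegasosProjection {
    static void project(double[] w, double lambda) {
        double normSq = 0.0;
        for (double x : w) {
            normSq += x * x;
        }
        double norm = Math.sqrt(normSq);
        if (norm == 0.0) {
            return; // the zero vector already lies inside the ball
        }
        double factor = (1.0 / Math.sqrt(lambda)) / norm;
        if (factor < 1.0) {
            // Only shrink; vectors already inside the ball are left untouched.
            for (int i = 0; i < w.length; i++) {
                w[i] *= factor;
            }
        }
    }

    public static void main(String[] args) {
        double[] w = {3.0, 4.0}; // norm 5
        project(w, 1.0);         // radius 1 / sqrt(1) = 1
        System.out.println(java.util.Arrays.toString(w));
    }
}
```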
  • Method Details

    • deserializeFromProto

      public static ShrinkingVector deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
      Deserialization factory.
      Parameters:
      version - The serialized object version.
      className - The class name.
      message - The serialized data.
      Returns:
      The deserialized object.
      Throws:
      com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.
    • serialize

      public org.tribuo.math.protos.TensorProto serialize()
      Description copied from interface: ProtoSerializable
      Serializes this object to a protobuf.
      Specified by:
      serialize in interface ProtoSerializable<org.tribuo.math.protos.TensorProto>
      Overrides:
      serialize in class DenseVector
      Returns:
      The protobuf.
    • convertToDense

      public DenseVector convertToDense()
      Description copied from interface: ShrinkingTensor
      Converts the tensor into a dense tensor.
      Specified by:
      convertToDense in interface ShrinkingTensor
      Returns:
      A dense tensor copy of this shrinking tensor.
    • copy

      public ShrinkingVector copy()
      Description copied from interface: SGDVector
      Returns a deep copy of this vector.
      Specified by:
      copy in interface SGDVector
      Specified by:
      copy in interface Tensor
      Overrides:
      copy in class DenseVector
      Returns:
      A copy of this vector.
    • toArray

      public double[] toArray()
      Description copied from class: DenseVector
      Generates a copy of the values in this DenseVector.

      This implementation uses Arrays.copyOf and should be overridden if the get function has been modified.

      Specified by:
      toArray in interface SGDVector
      Overrides:
      toArray in class DenseVector
      Returns:
      A copy of the values in this DenseVector.
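The note about overriding toArray applies directly here, since ShrinkingVector also overrides get. Assuming, hypothetically, that the shrinkage is stored lazily as a scalar multiplier so that get(i) returns stored[i] * multiplier, a bare Arrays.copyOf of the backing array would return stale, unscaled values. LazyScaledCopy.toArray is a hypothetical helper showing the corrected copy.

```java
import java.util.Arrays;

// Hypothetical illustration: logical values are stored[i] * multiplier.
// A raw Arrays.copyOf(stored, ...) misses the multiplier, so toArray must apply it.
class LazyScaledCopy {
    static double[] toArray(double[] stored, double multiplier) {
        double[] out = Arrays.copyOf(stored, stored.length);
        for (int i = 0; i < out.length; i++) {
            out[i] *= multiplier; // materialise the logical value, matching get(i)
        }
        return out;
    }

    public static void main(String[] args) {
        double[] stored = {2.0, -4.0, 8.0};
        System.out.println(Arrays.toString(Arrays.copyOf(stored, stored.length))); // unscaled
        System.out.println(Arrays.toString(toArray(stored, 0.5)));                 // scaled
    }
}
```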
    • get

      public double get(int index)
      Description copied from interface: SGDVector
      Gets an element from this vector.
      Specified by:
      get in interface SGDVector
      Overrides:
      get in class DenseVector
      Parameters:
      index - The index of the element.
      Returns:
      The value at that index.
    • sum

      public double sum()
      Description copied from interface: SGDVector
      Calculates the sum of this vector.
      Specified by:
      sum in interface SGDVector
      Overrides:
      sum in class DenseVector
      Returns:
      The sum.
    • intersectAndAddInPlace

      public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f)
      Description copied from interface: Tensor
      Updates this Tensor by adding all the values from the intersection with other.

      The function f is applied to all values from other before the addition.

      Each value is updated as value += f(otherValue).

      Specified by:
      intersectAndAddInPlace in interface Tensor
      Overrides:
      intersectAndAddInPlace in class DenseVector
      Parameters:
      other - The other Tensor.
      f - A function to apply.
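The update rule above only touches indices present in other. The sketch below assumes a sparse other given as parallel index and value arrays (a hypothetical representation, not Tribuo's SparseVector) and, as an additional assumption about ordering, shrinks every element before applying f to the intersecting values. IntersectAndAdd.shrinkThenAdd is a hypothetical name.

```java
import java.util.function.DoubleUnaryOperator;

// Hypothetical sketch of value += f(otherValue) over the intersection with a sparse
// "other" (parallel index/value arrays). Assumption: the shrinking variant first
// multiplies every element by shrinkFactor, then applies the sparse update.
class IntersectAndAdd {
    static void shrinkThenAdd(double[] values, double shrinkFactor,
                              int[] otherIndices, double[] otherValues,
                              DoubleUnaryOperator f) {
        if (otherIndices.length != otherValues.length) {
            throw new IllegalArgumentException("index and value arrays differ in length");
        }
        for (int i = 0; i < values.length; i++) {
            values[i] *= shrinkFactor; // shrink every element
        }
        for (int k = 0; k < otherIndices.length; k++) {
            // Only indices present in "other" receive an update.
            values[otherIndices[k]] += f.applyAsDouble(otherValues[k]);
        }
    }

    public static void main(String[] args) {
        double[] w = {1.0, 2.0, 3.0};
        shrinkThenAdd(w, 0.5, new int[]{0, 2}, new double[]{2.0, 4.0}, g -> -0.5 * g);
        System.out.println(java.util.Arrays.toString(w));
    }
}
```

Here f is g -> -0.5 * g, so the call behaves like a gradient step with learning rate 0.5 after a 50% shrink; index 1 is shrunk but not updated.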
    • indexOfMax

      public int indexOfMax()
      Description copied from interface: SGDVector
      Returns the index of the maximum value. Requires probing the array.
      Specified by:
      indexOfMax in interface SGDVector
      Overrides:
      indexOfMax in class DenseVector
      Returns:
      The index of the maximum value.
    • dot

      public double dot(SGDVector other)
      Description copied from interface: SGDVector
      Calculates the dot product between this vector and other.
      Specified by:
      dot in interface SGDVector
      Overrides:
      dot in class DenseVector
      Parameters:
      other - The other vector.
      Returns:
      The dot product.
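If the vector's logical values are stored[i] * multiplier (a hypothetical lazy-scaling layout, not confirmed by this page), the shared multiplier can be factored out of the dot product, since (m x) · y = m (x · y). LazyScaledDot.dot is a hypothetical helper.

```java
// Hypothetical: dot product of a lazily scaled vector (stored * multiplier)
// with a dense array, multiplying by the shared multiplier only once.
class LazyScaledDot {
    static double dot(double[] stored, double multiplier, double[] other) {
        if (stored.length != other.length) {
            throw new IllegalArgumentException(
                    "length mismatch: " + stored.length + " vs " + other.length);
        }
        double sum = 0.0;
        for (int i = 0; i < stored.length; i++) {
            sum += stored[i] * other[i];
        }
        return multiplier * sum; // factor the shared multiplier out of the sum
    }

    public static void main(String[] args) {
        System.out.println(dot(new double[]{1.0, 2.0, 3.0}, 0.5, new double[]{4.0, 5.0, 6.0}));
    }
}
```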
    • scaleInPlace

      public void scaleInPlace(double value)
      Description copied from interface: Tensor
      Scales each element of this Tensor by value.
      Specified by:
      scaleInPlace in interface Tensor
      Parameters:
      value - The coefficient of scaling.
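Shrinking multiplies every element by the same factor, so one common technique is to keep a single scalar multiplier and make scaleInPlace O(1) rather than O(n). The class below is a hypothetical sketch of that technique, not the actual implementation. It folds the multiplier back into the array once the multiplier falls below an arbitrarily chosen threshold, so the multiplier does not drift towards zero.

```java
// Hypothetical sketch: logical value i is elements[i] * multiplier.
class LazyScaledVector {
    private static final double REIFY_THRESHOLD = 1e-6; // arbitrary choice for the sketch
    private final double[] elements;
    private double multiplier = 1.0;

    LazyScaledVector(double[] values) {
        this.elements = values.clone();
    }

    double get(int i) {
        return elements[i] * multiplier;
    }

    // O(1): only the shared multiplier changes, unless it needs folding back in.
    void scaleInPlace(double value) {
        multiplier *= value;
        if (Math.abs(multiplier) < REIFY_THRESHOLD) {
            reify();
        }
    }

    // Folds the multiplier into the stored elements and resets it to 1.
    private void reify() {
        for (int i = 0; i < elements.length; i++) {
            elements[i] *= multiplier;
        }
        multiplier = 1.0;
    }

    public static void main(String[] args) {
        LazyScaledVector v = new LazyScaledVector(new double[]{4.0, -2.0});
        v.scaleInPlace(0.5);
        v.scaleInPlace(0.5);
        System.out.println(v.get(0) + " " + v.get(1));
    }
}
```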
    • twoNorm

      public double twoNorm()
      Description copied from interface: SGDVector
      Calculates the euclidean norm for this vector.
      Specified by:
      twoNorm in interface SGDVector
      Specified by:
      twoNorm in interface Tensor
      Overrides:
      twoNorm in class DenseVector
      Returns:
      The euclidean norm.
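Assuming, hypothetically, that the logical values are m x_i for a stored array x and a scalar shrink multiplier m, the Euclidean norm needs only the stored array's norm and one multiplication:

```latex
\|m\,x\|_2 = \sqrt{\sum_i (m\,x_i)^2} = \sqrt{m^2 \sum_i x_i^2} = |m|\,\|x\|_2
```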
    • maxValue

      public double maxValue()
      Description copied from interface: SGDVector
      Returns the maximum value. Requires probing the array.
      Specified by:
      maxValue in interface SGDVector
      Overrides:
      maxValue in class DenseVector
      Returns:
      The maximum value.
    • minValue

      public double minValue()
      Description copied from interface: SGDVector
      Returns the minimum value. Requires probing the array.
      Specified by:
      minValue in interface SGDVector
      Overrides:
      minValue in class DenseVector
      Returns:
      The minimum value.
    • iterator

      public VectorIterator iterator()
      Specified by:
      iterator in interface Iterable<VectorTuple>
      Overrides:
      iterator in class DenseVector