Interface Tensor

All Superinterfaces:
ProtoSerializable<org.tribuo.math.protos.TensorProto>, Serializable
All Known Subinterfaces:
Matrix, SGDVector
All Known Implementing Classes:
DenseMatrix, DenseSparseMatrix, DenseVector, ShrinkingMatrix, ShrinkingVector, SparseVector

public interface Tensor extends ProtoSerializable<org.tribuo.math.protos.TensorProto>, Serializable
An interface for Tensors, currently Vectors and Matrices.
  • Method Details

    • shapeSum

      static int shapeSum(int[] shape)
      The number of elements in a tensor of this shape, i.e., the product of the entries in the shape array.
      Parameters:
      shape - The tensor shape.
      Returns:
      The total number of elements.
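Despite the name, shapeSum multiplies the dimensions rather than adding them. The following minimal sketch restates that contract on a plain int array. It is an illustrative re-implementation, not Tribuo's source, and it does not guard against integer overflow for very large shapes.

```java
// Illustrative re-implementation of the shapeSum contract; not Tribuo's source.
public final class ShapeSumSketch {

    /** Returns the number of elements in a tensor of the given shape (the product of its dimensions). */
    static int shapeSum(int[] shape) {
        int total = 1;
        for (int dim : shape) {
            total *= dim; // multiply, do not add, each dimension
        }
        return total;
    }

    public static void main(String[] args) {
        // A 3x4 matrix holds 12 elements; a length-5 vector holds 5.
        System.out.println(shapeSum(new int[]{3, 4})); // 12
        System.out.println(shapeSum(new int[]{5}));    // 5
    }
}
```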
    • shapeCheck

      static boolean shapeCheck(Tensor first, Tensor second)
      Checks that the two tensors have compatible shapes.

      Compatible shapes are those which are exactly equal, as Tribuo does not support broadcasting.

      Parameters:
      first - The first tensor.
      second - The second tensor.
      Returns:
      True if the shapes are the same.
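Because Tribuo does not broadcast, compatibility reduces to exact equality of the two shape arrays. The sketch below illustrates that rule. It uses a hypothetical helper, shapesMatch, that compares raw int[] shapes (as if obtained from getShape()) so that it stays self-contained. The real shapeCheck takes two Tensor objects.

```java
import java.util.Arrays;

// Sketch of the shapeCheck rule on raw shape arrays (hypothetical helper, not Tribuo's code).
public final class ShapeCheckSketch {

    /** Shapes are compatible only if they are exactly equal: no broadcasting. */
    static boolean shapesMatch(int[] first, int[] second) {
        return Arrays.equals(first, second);
    }

    public static void main(String[] args) {
        System.out.println(shapesMatch(new int[]{2, 3}, new int[]{2, 3})); // true
        System.out.println(shapesMatch(new int[]{2, 3}, new int[]{3, 2})); // false: same size, different shape
        System.out.println(shapesMatch(new int[]{1, 3}, new int[]{2, 3})); // false: would require broadcasting
    }
}
```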
    • getShape

      int[] getShape()
      Returns an int array specifying the shape of this Tensor.
      Returns:
      An int array.
    • reshape

      Tensor reshape(int[] shape)
      Reshapes the Tensor to the supplied shape. Throws IllegalArgumentException if the shape isn't compatible.
      Parameters:
      shape - The desired shape.
      Returns:
      A Tensor of the desired shape.
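The documentation does not define "compatible". A common reading is that the new shape must hold the same number of elements as the old one. The following hypothetical sketch assumes that reading. It stores the values as a flat row-major buffer and copies the data on reshape. Whether Tribuo's implementations copy or share storage is not stated here, and the class name ReshapeSketch is purely illustrative.

```java
import java.util.Arrays;

// Hypothetical sketch (not Tribuo's code): a dense row-major buffer whose shape can change
// only when the element count is preserved, which is the assumed meaning of "compatible".
public final class ReshapeSketch {

    private final double[] data;
    private final int[] shape;

    ReshapeSketch(double[] data, int[] shape) {
        int required = elementCount(shape);
        if (required != data.length) {
            throw new IllegalArgumentException("Shape " + Arrays.toString(shape)
                    + " needs " + required + " elements but " + data.length + " were supplied");
        }
        this.data = data;
        this.shape = shape.clone();
    }

    /** Product of the dimensions, mirroring shapeSum. */
    static int elementCount(int[] shape) {
        int total = 1;
        for (int dim : shape) {
            total *= dim;
        }
        return total;
    }

    /** Returns a new tensor over a copy of the same row-major data with the requested shape. */
    ReshapeSketch reshape(int[] newShape) {
        // The constructor throws IllegalArgumentException if the element counts differ.
        return new ReshapeSketch(data.clone(), newShape);
    }

    int[] getShape() {
        return shape.clone();
    }

    public static void main(String[] args) {
        ReshapeSketch m = new ReshapeSketch(new double[]{1, 2, 3, 4, 5, 6}, new int[]{2, 3});
        System.out.println(Arrays.toString(m.reshape(new int[]{3, 2}).getShape())); // [3, 2]
        System.out.println(Arrays.toString(m.reshape(new int[]{6}).getShape()));    // [6]
        try {
            m.reshape(new int[]{4, 2}); // 8 elements requested, only 6 available
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```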
    • copy

      Tensor copy()
      Returns a copy of this Tensor.
      Returns:
      A copy of the Tensor.
    • intersectAndAddInPlace

      void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f)
      Updates this Tensor by adding all the values from the intersection with other.

      The function f is applied to all values from other before the addition.

      Each value is updated as value += f(otherValue).

      Parameters:
      other - The other Tensor.
      f - A function to apply.
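The "intersection" semantics are easiest to see with sparse data. The following hypothetical sketch represents two sparse vectors as index-to-value maps and updates only the indices present in both, applying value += f(otherValue). It illustrates the documented update rule and is not Tribuo's implementation. Real implementations work on Tribuo's own vector and matrix classes.

```java
import java.util.Map;
import java.util.TreeMap;
import java.util.function.DoubleUnaryOperator;

// Hypothetical sparse-vector sketch of intersectAndAddInPlace semantics; not Tribuo's implementation.
public final class IntersectAddSketch {

    /** Adds f(otherValue) to every entry of self whose index is also present in other. */
    static void intersectAndAddInPlace(Map<Integer, Double> self, Map<Integer, Double> other, DoubleUnaryOperator f) {
        for (Map.Entry<Integer, Double> e : self.entrySet()) {
            Double otherValue = other.get(e.getKey());
            if (otherValue != null) {
                // value += f(otherValue), applied only on the shared indices
                e.setValue(e.getValue() + f.applyAsDouble(otherValue));
            }
        }
    }

    public static void main(String[] args) {
        Map<Integer, Double> self = new TreeMap<>(Map.of(0, 1.0, 2, 2.0, 5, 3.0));
        Map<Integer, Double> other = new TreeMap<>(Map.of(2, 10.0, 4, 7.0, 5, -1.0));
        intersectAndAddInPlace(self, other, v -> 0.5 * v);
        // Index 0 is only in self, index 4 is only in other: neither changes self.
        System.out.println(self); // {0=1.0, 2=7.0, 5=2.5}
    }
}
```

Passing DoubleUnaryOperator.identity() as f reproduces the behaviour of the single-argument overload, value += otherValue.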
    • intersectAndAddInPlace

      default void intersectAndAddInPlace(Tensor other)
      Same as intersectAndAddInPlace(org.tribuo.math.la.Tensor, java.util.function.DoubleUnaryOperator), but applies the identity function.

      Each value is updated as value += otherValue.

      Parameters:
      other - The other Tensor.
    • hadamardProductInPlace

      void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f)
      Updates this Tensor with the Hadamard product (i.e., an element-wise multiplication) of this and other.

      The function f is applied to all values from other before the multiplication.

      Each value is updated as value *= f(otherValue).

      Parameters:
      other - The other Tensor.
      f - A function to apply.
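The following minimal dense sketch illustrates the update rule value *= f(otherValue) on plain double arrays. The length check mirrors shapeCheck's exact-equality rule. It is an assumption about how a mismatch is handled, not documented Tribuo behaviour.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

// Hypothetical dense-array sketch of hadamardProductInPlace; not Tribuo's code.
public final class HadamardSketch {

    static void hadamardProductInPlace(double[] self, double[] other, DoubleUnaryOperator f) {
        // Assumed behaviour: reject mismatched shapes, since Tribuo does not broadcast.
        if (self.length != other.length) {
            throw new IllegalArgumentException("Shapes must match exactly: " + self.length + " != " + other.length);
        }
        for (int i = 0; i < self.length; i++) {
            self[i] *= f.applyAsDouble(other[i]); // value *= f(otherValue)
        }
    }

    public static void main(String[] args) {
        double[] self = {1.0, 2.0, 3.0};
        hadamardProductInPlace(self, new double[]{4.0, 5.0, 6.0}, DoubleUnaryOperator.identity());
        System.out.println(Arrays.toString(self)); // [4.0, 10.0, 18.0]
        hadamardProductInPlace(self, new double[]{1.0, 0.5, 2.0}, v -> v * v);
        System.out.println(Arrays.toString(self)); // [4.0, 2.5, 72.0]
    }
}
```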
    • hadamardProductInPlace

      default void hadamardProductInPlace(Tensor other)
      Same as hadamardProductInPlace(org.tribuo.math.la.Tensor, java.util.function.DoubleUnaryOperator), but applies the identity function.

      Each value is updated as value *= otherValue.

      Parameters:
      other - The other Tensor.
    • foreachInPlace

      void foreachInPlace(DoubleUnaryOperator f)
      Applies a DoubleUnaryOperator elementwise to this Tensor.
      Parameters:
      f - The function to apply.
    • scaleInPlace

      default void scaleInPlace(double coefficient)
      Scales each element of this Tensor by coefficient.
      Parameters:
      coefficient - The coefficient of scaling.
    • scalarAddInPlace

      default void scalarAddInPlace(double scalar)
      Adds scalar to each element of this Tensor.
      Parameters:
      scalar - The scalar to add.
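Since scaleInPlace and scalarAddInPlace are default methods, a natural implementation expresses both through foreachInPlace. The following sketch assumes that design. The ElementwiseOps interface and the ElementwiseSketch class are hypothetical stand-ins for Tensor and a dense implementation. Tribuo's actual default bodies may differ, and concrete classes may override them.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

// Hypothetical stand-in for Tensor: the two defaults delegate to foreachInPlace (assumed design).
interface ElementwiseOps {

    /** Applies f to every element in place. */
    void foreachInPlace(DoubleUnaryOperator f);

    default void scaleInPlace(double coefficient) {
        foreachInPlace(v -> v * coefficient);
    }

    default void scalarAddInPlace(double scalar) {
        foreachInPlace(v -> v + scalar);
    }
}

// Minimal dense implementation backed by a double array.
public final class ElementwiseSketch implements ElementwiseOps {
    final double[] values;

    ElementwiseSketch(double... values) {
        this.values = values;
    }

    @Override
    public void foreachInPlace(DoubleUnaryOperator f) {
        for (int i = 0; i < values.length; i++) {
            values[i] = f.applyAsDouble(values[i]);
        }
    }

    public static void main(String[] args) {
        ElementwiseSketch t = new ElementwiseSketch(1.0, -2.0, 3.0);
        t.scaleInPlace(2.0);         // [2.0, -4.0, 6.0]
        t.scalarAddInPlace(1.0);     // [3.0, -3.0, 7.0]
        t.foreachInPlace(Math::abs); // [3.0, 3.0, 7.0]
        System.out.println(Arrays.toString(t.values)); // [3.0, 3.0, 7.0]
    }
}
```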
    • twoNorm

      double twoNorm()
      Calculates the Euclidean norm of this Tensor.
      Returns:
      The Euclidean norm.
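The Euclidean norm is the square root of the sum of the squared elements. The following sketch computes it over a flat array. It assumes the norm is taken over all elements, which for a matrix gives the Frobenius norm. It is an illustration, not Tribuo's code.

```java
// Hypothetical sketch of the Euclidean (L2) norm over a flat array of elements; not Tribuo's code.
public final class TwoNormSketch {

    static double twoNorm(double[] values) {
        double sumOfSquares = 0.0;
        for (double v : values) {
            sumOfSquares += v * v;
        }
        return Math.sqrt(sumOfSquares);
    }

    public static void main(String[] args) {
        System.out.println(twoNorm(new double[]{3.0, 4.0})); // 5.0
    }
}
```

The plain formula can overflow for very large elements. A more robust variant could rescale the values first, in the spirit of Math.hypot. The sketch keeps the direct computation for clarity.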
    • deserialize

      static Tensor deserialize(org.tribuo.math.protos.TensorProto proto)
      Deserialize a tensor proto into a Tensor.

      Throws IllegalArgumentException if the proto is invalid.

      Parameters:
      proto - The proto to deserialize.
      Returns:
      The tensor.