Package org.tribuo.math.la
Interface Tensor
- All Superinterfaces:
Serializable
- All Known Implementing Classes:
DenseMatrix, DenseSparseMatrix, DenseVector, ShrinkingMatrix, ShrinkingVector, SparseVector
An interface for Tensors, currently Vectors and Matrices.
-
Method Summary
Modifier and Type | Method | Description
Tensor | copy() | Returns a copy of this Tensor.
void | foreachInPlace(DoubleUnaryOperator f) | Applies a DoubleUnaryOperator elementwise to this Tensor.
int[] | getShape() | Returns an int array specifying the shape of this Tensor.
default void | hadamardProductInPlace(Tensor other) | Same as hadamardProductInPlace(org.tribuo.math.la.Tensor, java.util.function.DoubleUnaryOperator), but applies the identity function.
void | hadamardProductInPlace(Tensor other, DoubleUnaryOperator f) | Updates this Tensor with the Hadamard product (i.e., a term-by-term multiply) of this and other.
default void | intersectAndAddInPlace(Tensor other) | Same as intersectAndAddInPlace(org.tribuo.math.la.Tensor, java.util.function.DoubleUnaryOperator), but applies the identity function.
void | intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) | Updates this Tensor by adding all the values from the intersection with other.
Tensor | reshape(int[] shape) | Reshapes the Tensor to the supplied shape.
default void | scalarAddInPlace(double scalar) | Adds scalar to each element of this Tensor.
default void | scaleInPlace(double coefficient) | Scales each element of this Tensor by coefficient.
static boolean | shapeCheck(Tensor first, Tensor second) | Checks that the two tensors have compatible shapes.
static int | shapeSum(int[] shape) | The number of elements in this shape, i.e., the product of the shape array.
double | twoNorm() | Calculates the Euclidean norm of this Tensor.
-
Method Details
-
shapeSum
static int shapeSum(int[] shape)
The number of elements in this shape, i.e., the product of the shape array.
- Parameters:
shape - The tensor shape.
- Returns:
The total number of elements.
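A minimal sketch of the computation shapeSum describes, written as a standalone helper on a raw int[] (the class name ShapeSumSketch is hypothetical; this is not Tribuo's source): the element count is the product of every dimension in the shape array.

```java
final class ShapeSumSketch {
    // Hypothetical stand-in for Tensor.shapeSum: multiplies every dimension together.
    static int shapeSum(int[] shape) {
        int total = 1;
        for (int dim : shape) {
            total *= dim;
        }
        return total;
    }

    public static void main(String[] args) {
        int[] vectorShape = {5};    // a length-5 vector
        int[] matrixShape = {3, 4}; // a 3x4 matrix
        System.out.println(shapeSum(vectorShape)); // 5
        System.out.println(shapeSum(matrixShape)); // 12
    }
}
```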
-
shapeCheck
static boolean shapeCheck(Tensor first, Tensor second)
Checks that the two tensors have compatible shapes. Compatible shapes are those which are exactly equal, as Tribuo does not support broadcasting.
- Parameters:
first - The first tensor.
second - The second tensor.
- Returns:
True if the shapes are the same.
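Because compatibility means exact equality, the check reduces to comparing the two shape arrays element by element. The sketch below works on raw shape arrays rather than Tensor objects, and the names ShapeCheckSketch and shapesCompatible are hypothetical; it illustrates the documented rule, not Tribuo's implementation.

```java
import java.util.Arrays;

final class ShapeCheckSketch {
    // Hypothetical stand-in for Tensor.shapeCheck on raw shapes.
    // No broadcasting: {3, 1} and {3, 4} are rejected even though
    // a broadcasting library would accept them.
    static boolean shapesCompatible(int[] first, int[] second) {
        return Arrays.equals(first, second);
    }

    public static void main(String[] args) {
        System.out.println(shapesCompatible(new int[]{3, 4}, new int[]{3, 4})); // true
        System.out.println(shapesCompatible(new int[]{3, 4}, new int[]{4, 3})); // false
        System.out.println(shapesCompatible(new int[]{3, 1}, new int[]{3, 4})); // false
    }
}
```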
-
getShape
int[] getShape()
Returns an int array specifying the shape of this Tensor.
- Returns:
An int array.
-
reshape
Tensor reshape(int[] shape)
Reshapes the Tensor to the supplied shape. Throws IllegalArgumentException if the shape isn't compatible.
- Parameters:
shape - The desired shape.
- Returns:
A Tensor of the desired shape.
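The text above does not define "compatible". The sketch below assumes, purely for illustration, that a target shape is compatible when it holds the same number of elements as the current one; individual Tribuo implementations may impose stricter rules. ReshapeSketch is a hypothetical class with flat row-major storage, not a Tribuo type.

```java
import java.util.Arrays;

final class ReshapeSketch {
    // Hypothetical flat, row-major storage plus a shape.
    final double[] data;
    final int[] shape;

    ReshapeSketch(double[] data, int[] shape) {
        if (data.length != count(shape)) {
            throw new IllegalArgumentException("Data length does not match shape " + Arrays.toString(shape));
        }
        this.data = data;
        this.shape = shape.clone();
    }

    // Product of the dimensions, mirroring the shapeSum contract.
    static int count(int[] shape) {
        int n = 1;
        for (int d : shape) {
            n *= d;
        }
        return n;
    }

    // Assumption: "compatible" means the same element count.
    // Returns a new object with a copy of the data.
    ReshapeSketch reshape(int[] newShape) {
        if (count(newShape) != data.length) {
            throw new IllegalArgumentException(
                "Cannot reshape " + Arrays.toString(shape) + " to " + Arrays.toString(newShape));
        }
        return new ReshapeSketch(data.clone(), newShape);
    }

    public static void main(String[] args) {
        ReshapeSketch m = new ReshapeSketch(new double[]{1, 2, 3, 4, 5, 6}, new int[]{2, 3});
        System.out.println(Arrays.toString(m.reshape(new int[]{3, 2}).shape)); // [3, 2]
        try {
            m.reshape(new int[]{4, 2});
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```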
-
copy
Tensor copy()
Returns a copy of this Tensor.
- Returns:
A copy of the Tensor.
-
intersectAndAddInPlace
void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f)
Updates this Tensor by adding all the values from the intersection with other. The function f is applied to all values from other before the addition. Each value is updated as value += f(otherValue).
- Parameters:
other - The other Tensor.
f - A function to apply.
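The word "intersection" matters most for sparse data. The sketch below models a sparse vector as a hypothetical index-to-value map (not Tribuo's SparseVector) and applies value += f(otherValue) only at indices stored in both operands; entries present in just one side are left alone.

```java
import java.util.Map;
import java.util.TreeMap;
import java.util.function.DoubleUnaryOperator;

final class IntersectAddSketch {
    // Hypothetical sparse stand-in: index -> stored value.
    // Only indices present in both maps are updated.
    static void intersectAndAddInPlace(Map<Integer, Double> self, Map<Integer, Double> other, DoubleUnaryOperator f) {
        for (Map.Entry<Integer, Double> e : self.entrySet()) {
            Double otherValue = other.get(e.getKey());
            if (otherValue != null) {
                e.setValue(e.getValue() + f.applyAsDouble(otherValue));
            }
        }
    }

    public static void main(String[] args) {
        Map<Integer, Double> a = new TreeMap<>(Map.of(0, 1.0, 2, 2.0, 5, 3.0));
        Map<Integer, Double> b = new TreeMap<>(Map.of(2, 10.0, 5, 20.0, 7, 30.0));
        // Indices 2 and 5 are shared; index 0 is untouched and index 7 is not added.
        intersectAndAddInPlace(a, b, x -> 0.5 * x);
        System.out.println(a); // {0=1.0, 2=7.0, 5=13.0}
    }
}
```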
-
intersectAndAddInPlace
default void intersectAndAddInPlace(Tensor other)
Same as intersectAndAddInPlace(org.tribuo.math.la.Tensor, java.util.function.DoubleUnaryOperator), but applies the identity function. Each value is updated as value += otherValue.
- Parameters:
other - The other Tensor.
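One natural way to write such a one-argument default overload is to delegate to the two-argument method with DoubleUnaryOperator.identity(). The mini-interface AddableSketch and class DenseSketch below are hypothetical and only model that delegation pattern; Tribuo's own default method may be written differently.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

// Hypothetical mini-interface modelled on the two overloads; not Tribuo's Tensor.
interface AddableSketch {
    void intersectAndAddInPlace(AddableSketch other, DoubleUnaryOperator f);

    double get(int i);

    // The one-argument overload supplies the identity function,
    // so each value is updated as value += otherValue.
    default void intersectAndAddInPlace(AddableSketch other) {
        intersectAndAddInPlace(other, DoubleUnaryOperator.identity());
    }
}

// Dense implementation: every index is stored, so (for equal lengths)
// the intersection is every index.
final class DenseSketch implements AddableSketch {
    final double[] values;

    DenseSketch(double... values) {
        this.values = values;
    }

    @Override
    public double get(int i) {
        return values[i];
    }

    @Override
    public void intersectAndAddInPlace(AddableSketch other, DoubleUnaryOperator f) {
        for (int i = 0; i < values.length; i++) {
            values[i] += f.applyAsDouble(other.get(i));
        }
    }

    public static void main(String[] args) {
        DenseSketch v = new DenseSketch(1, 2, 3);
        v.intersectAndAddInPlace(new DenseSketch(10, 20, 30));
        System.out.println(Arrays.toString(v.values)); // [11.0, 22.0, 33.0]
    }
}
```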
-
hadamardProductInPlace
void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f)
Updates this Tensor with the Hadamard product (i.e., a term-by-term multiply) of this and other. The function f is applied to all values from other before the multiplication. Each value is updated as value *= f(otherValue).
- Parameters:
other - The other Tensor.
f - A function to apply.
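A dense sketch of the update value *= f(otherValue) on plain double arrays (HadamardSketch is a hypothetical name, not a Tribuo class). It also shows one use of f: passing x -> 1.0 / x turns the term-by-term multiply into term-by-term division. The length check stands in for the exact-shape rule described under shapeCheck.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

final class HadamardSketch {
    // Hypothetical dense stand-in: each value is updated as value *= f(otherValue).
    static void hadamardProductInPlace(double[] self, double[] other, DoubleUnaryOperator f) {
        if (self.length != other.length) {
            throw new IllegalArgumentException("Shapes differ: " + self.length + " vs " + other.length);
        }
        for (int i = 0; i < self.length; i++) {
            self[i] *= f.applyAsDouble(other[i]);
        }
    }

    public static void main(String[] args) {
        double[] a = {2.0, 6.0, 12.0};
        double[] b = {2.0, 4.0, 8.0};
        // f = 1/x: elementwise division a[i] / b[i].
        hadamardProductInPlace(a, b, x -> 1.0 / x);
        System.out.println(Arrays.toString(a)); // [1.0, 1.5, 1.5]
    }
}
```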
-
hadamardProductInPlace
default void hadamardProductInPlace(Tensor other)
Same as hadamardProductInPlace(org.tribuo.math.la.Tensor, java.util.function.DoubleUnaryOperator), but applies the identity function. Each value is updated as value *= otherValue.
- Parameters:
other - The other Tensor.
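With the identity function the update is the plain Hadamard product, which is easy to confuse with the matrix product. A worked example for two 2x2 matrices of identical shape (A, B and the \odot symbol are standard notation chosen here, not names from Tribuo), followed by the general update performed by the two-argument overload:

```latex
(A \odot B)_{ij} = A_{ij}\,B_{ij}, \qquad
\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix} \odot
\begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix} =
\begin{pmatrix} 5 & 12 \\ 21 & 32 \end{pmatrix}

A_{ij} \leftarrow A_{ij}\, f(B_{ij})
```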
-
foreachInPlace
void foreachInPlace(DoubleUnaryOperator f)
Applies a DoubleUnaryOperator elementwise to this Tensor.
- Parameters:
f - The function to apply.
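A dense sketch of the elementwise contract, replacing each element x with f(x) in place (ForeachSketch is a hypothetical name). It visits every array slot; how a sparse implementation treats its implicit zeros is not specified by the text above, so this sketch makes no claim about that case.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

final class ForeachSketch {
    // Hypothetical dense stand-in: replaces every element x with f(x).
    static void foreachInPlace(double[] values, DoubleUnaryOperator f) {
        for (int i = 0; i < values.length; i++) {
            values[i] = f.applyAsDouble(values[i]);
        }
    }

    public static void main(String[] args) {
        double[] v = {-1.5, 0.0, 2.5};
        // Example: a ReLU-style clamp applied elementwise.
        foreachInPlace(v, x -> Math.max(0.0, x));
        System.out.println(Arrays.toString(v)); // [0.0, 0.0, 2.5]
    }
}
```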
-
scaleInPlace
default void scaleInPlace(double coefficient)
Scales each element of this Tensor by coefficient.
- Parameters:
coefficient - The coefficient of scaling.
-
scalarAddInPlace
default void scalarAddInPlace(double scalar)
Adds scalar to each element of this Tensor.
- Parameters:
scalar - The scalar to add.
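Since scaleInPlace and scalarAddInPlace are default methods, one plausible design expresses both through foreachInPlace with a capturing lambda. The interface ElementwiseSketch and class DenseElementwise are hypothetical and show that design only; they are not taken from Tribuo's source.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

// Hypothetical mini-interface: both default methods delegate to foreachInPlace.
interface ElementwiseSketch {
    void foreachInPlace(DoubleUnaryOperator f);

    default void scaleInPlace(double coefficient) {
        foreachInPlace(x -> x * coefficient);
    }

    default void scalarAddInPlace(double scalar) {
        foreachInPlace(x -> x + scalar);
    }
}

final class DenseElementwise implements ElementwiseSketch {
    final double[] values;

    DenseElementwise(double... values) {
        this.values = values;
    }

    @Override
    public void foreachInPlace(DoubleUnaryOperator f) {
        for (int i = 0; i < values.length; i++) {
            values[i] = f.applyAsDouble(values[i]);
        }
    }

    public static void main(String[] args) {
        DenseElementwise v = new DenseElementwise(1.0, 2.0, 3.0);
        v.scalarAddInPlace(-2.0); // now [-1.0, 0.0, 1.0]
        v.scaleInPlace(4.0);      // now [-4.0, 0.0, 4.0]
        System.out.println(Arrays.toString(v.values)); // [-4.0, 0.0, 4.0]
    }
}
```

Implementations backed by contiguous arrays may still override these defaults with a direct loop; the delegation keeps the interface small.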
-
twoNorm
double twoNorm()
Calculates the Euclidean norm of this Tensor.
- Returns:
The Euclidean norm.
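A minimal dense sketch of the Euclidean (two) norm, the square root of the sum of squared elements, on a plain double array (TwoNormSketch is a hypothetical name, not Tribuo's implementation). This naive form can overflow for very large values; a production version might rescale first.

```java
final class TwoNormSketch {
    // Hypothetical dense stand-in: sqrt of the sum of squared elements.
    static double twoNorm(double[] values) {
        double sumOfSquares = 0.0;
        for (double v : values) {
            sumOfSquares += v * v;
        }
        return Math.sqrt(sumOfSquares);
    }

    public static void main(String[] args) {
        System.out.println(twoNorm(new double[]{3.0, 4.0})); // 5.0
    }
}
```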
-