Package org.tribuo.math.optimisers.util
Class ShrinkingVector
java.lang.Object
org.tribuo.math.la.DenseVector
org.tribuo.math.optimisers.util.ShrinkingVector
- All Implemented Interfaces:
Serializable, Iterable<VectorTuple>, SGDVector, Tensor, ShrinkingTensor, ProtoSerializable<org.tribuo.math.protos.TensorProto>
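A subclass of DenseVector which shrinks its values every time a new value is added, that is, on each call to intersectAndAddInPlace(Tensor, DoubleUnaryOperator).
Be careful when modifying this class or DenseVector: this class overrides many DenseVector methods, so changes to DenseVector's internals may need matching changes here.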
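The description above does not say how the shrinking is carried out. One common way to shrink a dense vector cheaply, assumed here purely for illustration, is to keep a pending scalar multiplier: scaling the whole vector only updates that scalar, and each logical value is the stored value times the multiplier. The class `LazyScaledVector` and its `TOLERANCE` constant are hypothetical; the inherited `tolerance` field may play a similar role in the real class, but that is a guess. This is a sketch, not Tribuo's implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not Tribuo's code): logical value i is
 * stored[i] * multiplier, so shrinking the whole vector costs O(1).
 */
class LazyScaledVector {
    // Hypothetical threshold below which the multiplier is folded into the array.
    private static final double TOLERANCE = 1e-6;

    private final double[] stored;
    private double multiplier = 1.0;

    LazyScaledVector(double[] values) {
        this.stored = Arrays.copyOf(values, values.length);
    }

    /** Returns the logical value: the stored value times the pending multiplier. */
    double get(int index) {
        return stored[index] * multiplier;
    }

    /** Shrinks (or grows) every logical element in O(1). */
    void scaleInPlace(double factor) {
        multiplier *= factor;
        if (Math.abs(multiplier) < TOLERANCE) {
            reify(); // avoid underflow and later division by a tiny multiplier
        }
    }

    /** Adds delta to one logical element, compensating for the pending multiplier. */
    void add(int index, double delta) {
        stored[index] += delta / multiplier;
    }

    /** Folds the multiplier into the stored values in O(n) and resets it to 1. */
    private void reify() {
        for (int i = 0; i < stored.length; i++) {
            stored[i] *= multiplier;
        }
        multiplier = 1.0;
    }

    /** Copies the logical values into a new array. */
    double[] toArray() {
        double[] out = new double[stored.length];
        for (int i = 0; i < stored.length; i++) {
            out[i] = get(i);
        }
        return out;
    }
}
```

The trade-off is that every read pays one multiplication, while a full-vector shrink no longer touches every element.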
Field Summary
Fields inherited from class org.tribuo.math.la.DenseVector
CURRENT_VERSION, elements
Fields inherited from interface org.tribuo.protos.ProtoSerializable
DESERIALIZATION_METHOD_NAME, PROVENANCE_SERIALIZER
Fields inherited from interface org.tribuo.math.optimisers.util.ShrinkingTensor
tolerance

Constructor Summary
- ShrinkingVector(DenseVector v, double baseRate, boolean scaleShrinking): Constructs a shrinking vector copy of the supplied dense vector.
- ShrinkingVector(DenseVector v, double baseRate, double lambda): Constructs a shrinking vector copy of the supplied dense vector.

Method Summary
- DenseVector convertToDense(): Converts the tensor into a dense tensor.
- ShrinkingVector copy(): Returns a deep copy of this vector.
- static ShrinkingVector deserializeFromProto(int version, String className, com.google.protobuf.Any message): Deserialization factory.
- double dot(SGDVector other): Calculates the dot product between this vector and other.
- double get(int index): Gets an element from this vector.
- int indexOfMax(): Returns the index of the maximum value.
- void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f): Updates this Tensor by adding all the values from the intersection with other.
- Iterator<VectorTuple> iterator()
- double maxValue(): Returns the maximum value.
- double minValue(): Returns the minimum value.
- void scaleInPlace(double value): Scales each element of this Tensor by the supplied coefficient.
- org.tribuo.math.protos.TensorProto serialize(): Serializes this object to a protobuf.
- double sum(): Calculates the sum of this vector.
- double[] toArray(): Generates a copy of the values in this DenseVector.
- double twoNorm(): Calculates the Euclidean norm of this vector.
Methods inherited from class org.tribuo.math.la.DenseVector
add, add, createDenseVector, createDenseVector, equals, euclideanDistance, expNormalize, fill, foreachIndexedInPlace, foreachInPlace, getShape, hadamardProductInPlace, hashCode, l1Distance, meanVariance, normalize, numActiveElements, oneNorm, outer, reduce, reduce, reshape, scale, set, setElements, size, sparsify, sparsify, subtract, sum, toString, unpackProto, variance
Methods inherited from class java.lang.Object
clone, finalize, getClass, notify, notifyAll, wait, wait, wait
Methods inherited from interface java.lang.Iterable
forEach, spliterator
Methods inherited from interface org.tribuo.math.la.SGDVector
cosineDistance, cosineSimilarity, l2Distance, variance
Methods inherited from interface org.tribuo.math.la.Tensor
hadamardProductInPlace, intersectAndAddInPlace, scalarAddInPlace

Constructor Details

ShrinkingVector
public ShrinkingVector(DenseVector v, double baseRate, boolean scaleShrinking)
Constructs a shrinking vector copy of the supplied dense vector. This vector shrinks during each call to intersectAndAddInPlace(Tensor, DoubleUnaryOperator).
- Parameters:
v - The vector to copy.
baseRate - The base amount of shrinking to apply after each update.
scaleShrinking - If true, reduce the amount of shrinking over time in proportion to the number of updates.
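The parameter descriptions leave the exact schedule open. A minimal sketch, assuming the vector is multiplied by `1 - baseRate` after every update, or by `1 - baseRate / t` on the t-th update when `scaleShrinking` is true, so that later updates shrink less. The names `ShrinkSchedule`, `shrinkFactor` and `cumulativeFactor` are hypothetical.

```java
/**
 * Hypothetical shrinkage schedule, assumed for illustration only. After the
 * t-th update (t starting at 1) the vector is multiplied by
 *   1 - baseRate       when scaleShrinking is false, or
 *   1 - baseRate / t   when scaleShrinking is true.
 * baseRate is assumed to lie in [0, 1).
 */
final class ShrinkSchedule {
    private ShrinkSchedule() {}

    /** Per-update multiplicative shrink factor for update number t. */
    static double shrinkFactor(double baseRate, boolean scaleShrinking, int t) {
        if (t < 1) {
            throw new IllegalArgumentException("t must be >= 1, got " + t);
        }
        return scaleShrinking ? 1.0 - baseRate / t : 1.0 - baseRate;
    }

    /** Total factor after n updates: the product of the per-update factors. */
    static double cumulativeFactor(double baseRate, boolean scaleShrinking, int n) {
        double product = 1.0;
        for (int t = 1; t <= n; t++) {
            product *= shrinkFactor(baseRate, scaleShrinking, t);
        }
        return product;
    }
}
```

With a fixed rate the total shrinkage decays geometrically, `(1 - baseRate)^n`; the scaled variant shrinks aggressively at first and then approaches no shrinking at all.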

ShrinkingVector
public ShrinkingVector(DenseVector v, double baseRate, double lambda)
Constructs a shrinking vector copy of the supplied dense vector. This vector shrinks during each call to intersectAndAddInPlace(Tensor, DoubleUnaryOperator), and then reprojects the vector so it has the same twoNorm.
- Parameters:
v - The vector to copy.
baseRate - The base rate of shrinkage.
lambda - The lambda value (see Pegasos).
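The lambda parameter refers to Pegasos. In the Pegasos algorithm, the projection step rescales the weight vector onto the L2 ball of radius `1/sqrt(lambda)` whenever it lies outside that ball. The sketch below shows that standard step on a plain `double[]`; whether `ShrinkingVector` uses exactly this radius is an assumption, and `PegasosProjection` is a hypothetical helper.

```java
/**
 * Hypothetical helper showing the Pegasos projection step: if ||w|| exceeds
 * 1 / sqrt(lambda), rescale w so that it lies on that ball; otherwise leave it alone.
 */
final class PegasosProjection {
    private PegasosProjection() {}

    /** Euclidean (L2) norm of w. */
    static double twoNorm(double[] w) {
        double squared = 0.0;
        for (double x : w) {
            squared += x * x;
        }
        return Math.sqrt(squared);
    }

    /** Projects w in place onto the L2 ball of radius 1 / sqrt(lambda); lambda must be positive. */
    static void project(double[] w, double lambda) {
        if (lambda <= 0.0) {
            throw new IllegalArgumentException("lambda must be positive, got " + lambda);
        }
        double radius = 1.0 / Math.sqrt(lambda);
        double norm = twoNorm(w);
        if (norm > radius) {
            double scale = radius / norm; // in (0, 1), so the vector only ever shrinks
            for (int i = 0; i < w.length; i++) {
                w[i] *= scale;
            }
        }
    }
}
```

For example, with `lambda = 0.25` the radius is 2, so `{3.0, 4.0}` (norm 5) is rescaled to roughly `{1.2, 1.6}`, while `{0.6, 0.8}` (norm 1) is left untouched.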

Method Details

deserializeFromProto
public static ShrinkingVector deserializeFromProto(int version, String className, com.google.protobuf.Any message) throws com.google.protobuf.InvalidProtocolBufferException
Deserialization factory.
- Parameters:
version - The serialized object version.
className - The class name.
message - The serialized data.
- Returns:
The deserialized object.
- Throws:
com.google.protobuf.InvalidProtocolBufferException - If the protobuf could not be parsed from the message.

serialize
public org.tribuo.math.protos.TensorProto serialize()
Description copied from interface: ProtoSerializable
Serializes this object to a protobuf.
- Specified by:
serialize in interface ProtoSerializable<org.tribuo.math.protos.TensorProto>
- Overrides:
serialize in class DenseVector
- Returns:
The protobuf.

convertToDense
public DenseVector convertToDense()
Description copied from interface: ShrinkingTensor
Converts the tensor into a dense tensor.
- Specified by:
convertToDense in interface ShrinkingTensor
- Returns:
A dense tensor copy of this shrinking tensor.

copy
public ShrinkingVector copy()
Description copied from interface: SGDVector
Returns a deep copy of this vector.

toArray
public double[] toArray()
Description copied from class: DenseVector
Generates a copy of the values in this DenseVector. This implementation uses Arrays.copyOf, and should be overridden if the get function has been modified.
- Specified by:
toArray in interface SGDVector
- Overrides:
toArray in class DenseVector
- Returns:
A copy of the values in this DenseVector.
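The inherited description warns that an `Arrays.copyOf`-based `toArray` is only correct while `get(i)` returns the raw stored element. The sketch below, using hypothetical classes `BaseDense` and `ScaledDense`, shows why a subclass whose `get` applies a multiplier must override `toArray` too; it illustrates the warning rather than reproducing Tribuo's classes.

```java
import java.util.Arrays;

/** Hypothetical base class: get() and toArray() both read the raw array. */
class BaseDense {
    protected final double[] elements;

    BaseDense(double[] elements) {
        this.elements = Arrays.copyOf(elements, elements.length);
    }

    double get(int index) {
        return elements[index];
    }

    /** Fast raw copy; only correct while get(i) == elements[i]. */
    double[] toArray() {
        return Arrays.copyOf(elements, elements.length);
    }
}

/** Hypothetical subclass whose get() applies a multiplier, so toArray() must follow suit. */
class ScaledDense extends BaseDense {
    private final double multiplier;

    ScaledDense(double[] elements, double multiplier) {
        super(elements);
        this.multiplier = multiplier;
    }

    @Override
    double get(int index) {
        return elements[index] * multiplier;
    }

    /** Overridden so the copy matches get(); the inherited version would skip the multiplier. */
    @Override
    double[] toArray() {
        double[] out = new double[elements.length];
        for (int i = 0; i < elements.length; i++) {
            out[i] = get(i);
        }
        return out;
    }
}
```

Without the override, `new ScaledDense(new double[] {1.0, 3.0}, 2.0).toArray()` would return `{1.0, 3.0}` even though `get` reports `2.0` and `6.0`.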

get
public double get(int index)
Description copied from interface: SGDVector
Gets an element from this vector.
- Specified by:
get in interface SGDVector
- Overrides:
get in class DenseVector
- Parameters:
index - The index of the element.
- Returns:
The value at that index.

sum
public double sum()
Description copied from interface: SGDVector
Calculates the sum of this vector.
- Specified by:
sum in interface SGDVector
- Overrides:
sum in class DenseVector
- Returns:
The sum.

intersectAndAddInPlace
public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f)
Description copied from interface: Tensor
Updates this Tensor by adding all the values from the intersection with other. The function f is applied to all values from other before the addition. Each value is updated as value += f(otherValue).
- Specified by:
intersectAndAddInPlace in interface Tensor
- Overrides:
intersectAndAddInPlace in class DenseVector
- Parameters:
other - The other Tensor.
f - A function to apply.
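The update rule `value += f(otherValue)` can be sketched on plain arrays, with a sparse `other` represented as parallel index and value arrays. For example, `f` could be `x -> -0.5 * x`, a gradient step with learning rate 0.5. The helper `shrinkThenAdd` shows one plausible ordering for the shrinking that the constructor details mention; that shrink-before-add ordering is an assumption, and `IntersectAndAdd` is a hypothetical name.

```java
import java.util.function.DoubleUnaryOperator;

/**
 * Hypothetical sketch of value += f(otherValue) over the intersection of a dense
 * array and a sparse "other" given as parallel index/value arrays.
 */
final class IntersectAndAdd {
    private IntersectAndAdd() {}

    static void intersectAndAddInPlace(double[] values, int[] otherIndices,
                                       double[] otherValues, DoubleUnaryOperator f) {
        if (otherIndices.length != otherValues.length) {
            throw new IllegalArgumentException("index and value arrays differ in length");
        }
        for (int k = 0; k < otherIndices.length; k++) {
            // Only positions present in other are touched; the rest keep their value.
            values[otherIndices[k]] += f.applyAsDouble(otherValues[k]);
        }
    }

    /** Assumed ordering: shrink every element by shrinkFactor, then apply the sparse update. */
    static void shrinkThenAdd(double[] values, double shrinkFactor, int[] otherIndices,
                              double[] otherValues, DoubleUnaryOperator f) {
        for (int i = 0; i < values.length; i++) {
            values[i] *= shrinkFactor;
        }
        intersectAndAddInPlace(values, otherIndices, otherValues, f);
    }
}
```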

indexOfMax
public int indexOfMax()
Description copied from interface: SGDVector
Returns the index of the maximum value. Requires probing the array.
- Specified by:
indexOfMax in interface SGDVector
- Overrides:
indexOfMax in class DenseVector
- Returns:
The index of the maximum value.
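"Requires probing the array" means a full O(n) scan, which also applies to maxValue and minValue. A minimal sketch on a plain array; `Probe` is a hypothetical name, and breaking ties toward the lowest index is an assumption rather than documented behaviour.

```java
/**
 * Hypothetical sketch of the O(n) probe behind indexOfMax, maxValue and minValue.
 * Ties resolve to the lowest index here; that rule is an assumption.
 */
final class Probe {
    private Probe() {}

    static int indexOfMax(double[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("empty vector");
        }
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) { // strict > keeps the first maximum
                best = i;
            }
        }
        return best;
    }

    static double maxValue(double[] values) {
        return values[indexOfMax(values)];
    }

    static double minValue(double[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("empty vector");
        }
        double min = values[0];
        for (int i = 1; i < values.length; i++) {
            min = Math.min(min, values[i]);
        }
        return min;
    }
}
```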

dot
public double dot(SGDVector other)
Description copied from interface: SGDVector
Calculates the dot product between this vector and other.
- Specified by:
dot in interface SGDVector
- Overrides:
dot in class DenseVector
- Parameters:
other - The other vector.
- Returns:
The dot product.
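When `other` is sparse, only its active entries contribute to the dot product, so the work is proportional to its number of non-zeros rather than to the dimension. A minimal sketch with a sparse vector stored as parallel index and value arrays; `SparseDot` is a hypothetical name, and this is not how Tribuo represents sparse vectors.

```java
/**
 * Hypothetical sketch: dense-dense and dense-sparse dot products on plain arrays.
 * The sparse vector is stored as parallel index/value arrays.
 */
final class SparseDot {
    private SparseDot() {}

    /** Dense-sparse: sums dense[indices[k]] * values[k] over the sparse entries only. */
    static double dot(double[] dense, int[] indices, double[] values) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("index and value arrays differ in length");
        }
        double total = 0.0;
        for (int k = 0; k < indices.length; k++) {
            total += dense[indices[k]] * values[k];
        }
        return total;
    }

    /** Dense-dense: both vectors must have the same dimension. */
    static double dot(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch: " + a.length + " vs " + b.length);
        }
        double total = 0.0;
        for (int i = 0; i < a.length; i++) {
            total += a[i] * b[i];
        }
        return total;
    }
}
```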

scaleInPlace
public void scaleInPlace(double value)
Description copied from interface: Tensor
Scales each element of this Tensor by the supplied coefficient.
- Specified by:
scaleInPlace in interface Tensor
- Parameters:
value - The coefficient of scaling.

twoNorm
public double twoNorm()
Description copied from interface: SGDVector
Calculates the Euclidean norm of this vector.
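The Euclidean norm is the square root of the sum of squared elements. A vector that is updated a few elements at a time can keep the squared norm up to date on each write instead of rescanning every element. This sketch assumes that technique purely for illustration; whether `ShrinkingVector` caches its norm this way is not stated above, and `NormTrackingVector` is a hypothetical class.

```java
/**
 * Hypothetical sketch: a dense vector that keeps its squared Euclidean norm up to
 * date on every write, so twoNorm() costs O(1) instead of an O(n) scan.
 */
final class NormTrackingVector {
    private final double[] values;
    private double squaredNorm;

    NormTrackingVector(double[] initial) {
        values = initial.clone();
        for (double v : values) {
            squaredNorm += v * v;
        }
    }

    /** Overwrites one element, swapping its old contribution to the norm for the new one. */
    void set(int index, double newValue) {
        double old = values[index];
        squaredNorm += newValue * newValue - old * old;
        values[index] = newValue;
    }

    /** Scales all elements; uses ||c * w||^2 = c^2 * ||w||^2 instead of rescanning. */
    void scaleInPlace(double factor) {
        for (int i = 0; i < values.length; i++) {
            values[i] *= factor;
        }
        squaredNorm *= factor * factor;
    }

    /** O(1) norm from the tracked value; clamps tiny negative rounding drift to zero. */
    double twoNorm() {
        return Math.sqrt(Math.max(squaredNorm, 0.0));
    }

    /** O(n) reference computation, useful for checking the tracked value. */
    double recomputedTwoNorm() {
        double squared = 0.0;
        for (double v : values) {
            squared += v * v;
        }
        return Math.sqrt(squared);
    }
}
```

Incremental tracking accumulates rounding error over many writes, so an occasional full recomputation is a reasonable safeguard.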

maxValue
public double maxValue()
Description copied from interface: SGDVector
Returns the maximum value. Requires probing the array.
- Specified by:
maxValue in interface SGDVector
- Overrides:
maxValue in class DenseVector
- Returns:
The maximum value.

minValue
public double minValue()
Description copied from interface: SGDVector
Returns the minimum value. Requires probing the array.
- Specified by:
minValue in interface SGDVector
- Overrides:
minValue in class DenseVector
- Returns:
The minimum value.

iterator
public Iterator<VectorTuple> iterator()
- Specified by:
iterator in interface Iterable<VectorTuple>
- Overrides:
iterator in class DenseVector
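The iterator yields `VectorTuple` index/value pairs. For a vector whose stored values carry a pending multiplier, as in the lazy-scaling idea assumed earlier, the iterator has to report the scaled values. This sketch mimics that shape with hypothetical classes `IndexValue` (a stand-in for `VectorTuple`) and `ScaledIterable`. It yields every index, including zeros, which is an assumption.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

/** Hypothetical stand-in for an (index, value) tuple. */
final class IndexValue {
    final int index;
    final double value;

    IndexValue(int index, double value) {
        this.index = index;
        this.value = value;
    }
}

/** Hypothetical dense vector with a pending multiplier that iterates over logical values. */
final class ScaledIterable implements Iterable<IndexValue> {
    private final double[] stored;
    private final double multiplier;

    ScaledIterable(double[] stored, double multiplier) {
        this.stored = stored.clone();
        this.multiplier = multiplier;
    }

    @Override
    public Iterator<IndexValue> iterator() {
        return new Iterator<IndexValue>() {
            private int next = 0;

            @Override
            public boolean hasNext() {
                return next < stored.length;
            }

            @Override
            public IndexValue next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                // Report the logical (scaled) value, not the raw stored one.
                IndexValue t = new IndexValue(next, stored[next] * multiplier);
                next++;
                return t;
            }
        };
    }
}
```

Used in a for-each loop, `new ScaledIterable(new double[] {1.0, -2.0}, 4.0)` yields the pairs (0, 4.0) and (1, -8.0).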