001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.la;
018
019import java.io.Serializable;
020import java.util.Arrays;
021import java.util.function.DoubleUnaryOperator;
022
/**
 * An interface for Tensors, currently Vectors and Matrices.
 * <p>
 * The {@code *InPlace} methods mutate this {@link Tensor} rather than returning a new one.
 */
public interface Tensor extends Serializable {

    /**
     * Computes the number of elements described by a shape array, i.e., the product
     * of all its dimensions.
     * <p>
     * Note: despite the name this is a product, not a sum. An empty shape array yields 1.
     * @param shape The shape array.
     * @return The product of the dimensions in {@code shape}.
     * @throws ArithmeticException If the product overflows an {@code int}.
     * @throws NullPointerException If {@code shape} is null.
     */
    public static int shapeSum(int[] shape) {
        int product = 1;
        for (int dim : shape) {
            // Fail loudly instead of silently wrapping around to a bogus (possibly negative) size.
            product = Math.multiplyExact(product, dim);
        }
        return product;
    }

    /**
     * Checks whether two tensors have identical shapes.
     * @param first The first {@link Tensor}, may be null.
     * @param second The second {@link Tensor}, may be null.
     * @return True if both tensors are non-null and their {@link #getShape()} arrays are equal
     * element-wise, false otherwise.
     */
    public static boolean shapeCheck(Tensor first, Tensor second) {
        return (first != null) && (second != null) && Arrays.equals(first.getShape(), second.getShape());
    }

    /**
     * Returns an int array specifying the shape of this {@link Tensor}.
     * @return An int array.
     */
    public int[] getShape();

    /**
     * Reshapes the Tensor to the supplied shape.
     * @param shape The desired shape.
     * @return A Tensor of the desired shape.
     * @throws IllegalArgumentException If the shape isn't compatible.
     */
    public Tensor reshape(int[] shape);

    /**
     * Updates this {@link Tensor} by adding all the values from the intersection with {@code other}.
     * <p>
     * The function {@code f} is applied to all values from {@code other} before the
     * addition.
     * <p>
     * Each value is updated as value += f(otherValue).
     * @param other The other {@link Tensor}.
     * @param f A function to apply.
     */
    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f);

    /**
     * Same as {@link Tensor#intersectAndAddInPlace}, but applies the identity function.
     * <p>
     * Each value is updated as value += otherValue.
     * @param other The other {@link Tensor}.
     */
    default public void intersectAndAddInPlace(Tensor other) {
        intersectAndAddInPlace(other, DoubleUnaryOperator.identity());
    }

    /**
     * Updates this {@link Tensor} with the Hadamard product
     * (i.e., a term by term multiply) of this and {@code other}.
     * <p>
     * The function {@code f} is applied to all values from {@code other} before the multiplication.
     * <p>
     * Each value is updated as value *= f(otherValue).
     * @param other The other {@link Tensor}.
     * @param f A function to apply.
     */
    public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f);

    /**
     * Same as {@link Tensor#hadamardProductInPlace}, but applies the identity function.
     * <p>
     * Each value is updated as value *= otherValue.
     * @param other The other {@link Tensor}.
     */
    default public void hadamardProductInPlace(Tensor other) {
        hadamardProductInPlace(other, DoubleUnaryOperator.identity());
    }

    /**
     * Applies a {@link DoubleUnaryOperator} elementwise to this {@link Tensor},
     * replacing each value with {@code f(value)}.
     * @param f The function to apply.
     */
    public void foreachInPlace(DoubleUnaryOperator f);

    /**
     * Scales each element of this {@link Tensor} by {@code coefficient}.
     * <p>
     * Each value is updated as value *= coefficient.
     * @param coefficient The coefficient of scaling.
     */
    default public void scaleInPlace(double coefficient) {
        foreachInPlace(d -> d * coefficient);
    }

    /**
     * Adds {@code scalar} to each element of this {@link Tensor}.
     * <p>
     * Each value is updated as value += scalar.
     * @param scalar The scalar to add.
     */
    default public void scalarAddInPlace(double scalar) {
        foreachInPlace(d -> d + scalar);
    }

    /**
     * Calculates the euclidean norm for this {@link Tensor}.
     * @return The euclidean norm.
     */
    public double twoNorm();
}