001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.math.la; 018 019import java.io.Serializable; 020import java.util.Arrays; 021import java.util.function.DoubleUnaryOperator; 022 023/** 024 * An interface for Tensors, currently Vectors and Matrices. 025 */ 026public interface Tensor extends Serializable { 027 028 public static int shapeSum(int[] shape) { 029 int sum = 1; 030 for (int i = 0; i < shape.length; i++) { 031 sum *= shape[i]; 032 } 033 return sum; 034 } 035 036 public static boolean shapeCheck(Tensor first, Tensor second) { 037 if ((first != null) && (second != null)) { 038 return Arrays.equals(first.getShape(),second.getShape()); 039 } else { 040 return false; 041 } 042 } 043 044 /** 045 * Returns an int array specifying the shape of this {@link Tensor}. 046 * @return An int array. 047 */ 048 public int[] getShape(); 049 050 /** 051 * Reshapes the Tensor to the supplied shape. Throws {@link IllegalArgumentException} if the shape isn't compatible. 052 * @param shape The desired shape. 053 * @return A Tensor of the desired shape. 054 */ 055 public Tensor reshape(int[] shape); 056 057 /** 058 * Updates this {@link Tensor} by adding all the values from the intersection with {@code other}. 059 * <p> 060 * The function {@code f} is applied to all values from {@code other} before the 061 * addition. 062 * <p> 063 * Each value is updated as value += f(otherValue). 064 * @param other The other {@link Tensor}. 065 * @param f A function to apply. 066 */ 067 public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f); 068 069 /** 070 * Same as {@link Tensor#intersectAndAddInPlace}, but applies the identity function. 071 * <p> 072 * Each value is updated as value += otherValue. 073 * @param other The other {@link Tensor}. 074 */ 075 default public void intersectAndAddInPlace(Tensor other) { 076 intersectAndAddInPlace(other, DoubleUnaryOperator.identity()); 077 } 078 079 /** 080 * Updates this {@link Tensor} with the Hadamard product 081 * (i.e., a term by term multiply) of this and {@code other}. 082 * <p> 083 * The function {@code f} is applied to all values from {@code other} before the addition. 084 * <p> 085 * Each value is updated as value *= f(otherValue). 086 * @param other The other {@link Tensor}. 087 * @param f A function to apply. 088 */ 089 public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f); 090 091 /** 092 * Same as {@link Tensor#hadamardProductInPlace}, but applies the identity function. 093 * <p> 094 * Each value is updated as value *= otherValue. 095 * @param other The other {@link Tensor}. 096 */ 097 default public void hadamardProductInPlace(Tensor other) { 098 hadamardProductInPlace(other, DoubleUnaryOperator.identity()); 099 } 100 101 /** 102 * Applies a {@link DoubleUnaryOperator} elementwise to this {@link Tensor}. 103 * @param f The function to apply. 104 */ 105 public void foreachInPlace(DoubleUnaryOperator f); 106 107 /** 108 * Scales each element of this {@link Tensor} by {@code coefficient}. 109 * @param coefficient The coefficient of scaling. 110 */ 111 default public void scaleInPlace(double coefficient) { 112 foreachInPlace(d -> d * coefficient); 113 } 114 115 /** 116 * Adds {@code scalar} to each element of this {@link Tensor}. 117 * @param scalar The scalar to add. 118 */ 119 default public void scalarAddInPlace(double scalar) { 120 foreachInPlace(d -> d + scalar); 121 } 122 123 /** 124 * Calculates the euclidean norm for this vector. 125 * @return The euclidean norm. 126 */ 127 public double twoNorm(); 128}