001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.math.la; 018 019/** 020 * Interface for 2 dimensional {@link Tensor}s. 021 * <p> 022 * Matrices have immutable sizes and immutable indices (so {@link DenseSparseMatrix} can't grow). 023 */ 024public interface Matrix extends Tensor, Iterable<MatrixTuple> { 025 026 /** 027 * Gets an element from this {@link Matrix}. 028 * @param i The index for the first dimension. 029 * @param j The index for the second dimension. 030 * @return The value at matrix[i][j]. 031 */ 032 public double get(int i, int j); 033 034 /** 035 * Sets an element at the supplied location. 036 * @param i The index for the first dimension. 037 * @param j The index for the second dimension. 038 * @param value The value to be stored at matrix[i][j]. 039 */ 040 public void set(int i, int j, double value); 041 042 /** 043 * Adds the argument value to the value at the supplied index. 044 * @param i The index for the first dimension. 045 * @param j The index for the second dimension. 046 * @param value The value to add. 047 */ 048 public void add(int i, int j, double value); 049 050 /** 051 * The size of the first dimension. 052 * @return The size of the first dimension. 053 */ 054 public int getDimension1Size(); 055 056 /** 057 * The size of the second dimension. 058 * @return The size of the second dimension. 059 */ 060 public int getDimension2Size(); 061 062 /** 063 * The number of non-zero elements in that row. 064 * <p> 065 * An element could be active and zero, if it was active on construction. 066 * 067 * @param row The index of the row. 068 * @return The number of non-zero elements. 069 */ 070 public int numActiveElements(int row); 071 072 /** 073 * Multiplies this Matrix by a {@link SGDVector} returning a vector of the appropriate size. 074 * <p> 075 * The input must have dimension equal to {@link Matrix#getDimension2Size()}. 076 * @param input The input vector. 077 * @return A new {@link SGDVector} of size {@link Matrix#getDimension1Size()}. 078 */ 079 public SGDVector leftMultiply(SGDVector input); 080 081 /** 082 * Multiplies this Matrix by a {@link SGDVector} returning a vector of the appropriate size. 083 * <p> 084 * The input must have dimension equal to {@link Matrix#getDimension1Size()}. 085 * @param input The input vector. 086 * @return A new {@link SGDVector} of size {@link Matrix#getDimension2Size()}. 087 */ 088 public SGDVector rightMultiply(SGDVector input); 089 090 /** 091 * Multiplies this Matrix by another {@link Matrix} returning a matrix of the appropriate size. 092 * <p> 093 * The input must have dimension 1 equal to {@link Matrix#getDimension2Size()}. 094 * @param input The input matrix. 095 * @return A new {@link Matrix} of size {@link Matrix#getDimension1Size()}, {@code input.getDimension2Size()}. 096 */ 097 public Matrix matrixMultiply(Matrix input); 098 099 /** 100 * Multiplies this Matrix by another {@link Matrix} returning a matrix of the appropriate size. 101 * <p> 102 * Must obey the rules of matrix multiplication after the transposes are applied. 103 * @param input The input matrix. 104 * @param transposeThis Implicitly transposes this matrix just for the multiplication. 105 * @param transposeOther Implicitly transposes other just for the multiplication. 106 * @return A new {@link Matrix}. 107 */ 108 public Matrix matrixMultiply(Matrix input, boolean transposeThis, boolean transposeOther); 109 110 /** 111 * Generates a {@link DenseVector} representing the sum of each row. 112 * @return A new {@link DenseVector} of size {@link Matrix#getDimension1Size()}. 113 */ 114 public DenseVector rowSum(); 115 116 /** 117 * Scales each row by the appropriate value in the {@link DenseVector}. 118 * @param scalingCoefficients A {@link DenseVector} with size {@link Matrix#getDimension1Size()}. 119 */ 120 public void rowScaleInPlace(DenseVector scalingCoefficients); 121 122 /** 123 * Extract a row as an {@link SGDVector}. 124 * <p> 125 * This refers to the same values as the matrix, so updating this vector will update the matrix. 126 * @param i The index of the row to extract. 127 * @return An {@link SGDVector}. 128 */ 129 public SGDVector getRow(int i); 130 131}