001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.math.la; 018 019import java.util.Arrays; 020import java.util.Iterator; 021import java.util.List; 022import java.util.Objects; 023import java.util.function.DoubleUnaryOperator; 024 025/** 026 * A matrix which is dense in the first dimension and sparse in the second. 027 * <p> 028 * Backed by an array of {@link SparseVector}. 029 */ 030public class DenseSparseMatrix implements Matrix { 031 private static final long serialVersionUID = 1L; 032 033 private final SparseVector[] values; 034 private final int dim1; 035 private final int dim2; 036 private final int[] shape; 037 038 DenseSparseMatrix(SparseVector[] values) { 039 this.values = values; 040 this.dim1 = values.length; 041 this.dim2 = values[0].size(); 042 this.shape = new int[]{dim1,dim2}; 043 } 044 045 public DenseSparseMatrix(List<SparseVector> values) { 046 this.values = new SparseVector[values.size()]; 047 this.dim1 = values.size(); 048 this.dim2 = values.get(0).size(); 049 this.shape = new int[]{dim1,dim2}; 050 for (int i = 0; i < values.size(); i++) { 051 this.values[i] = values.get(i); 052 } 053 } 054 055 public DenseSparseMatrix(DenseSparseMatrix other) { 056 this.dim1 = other.dim1; 057 this.dim2 = other.dim2; 058 this.values = new SparseVector[other.values.length]; 059 this.shape = new int[]{dim1,dim2}; 060 for (int i = 0; i < values.length; i++) { 061 values[i] = other.values[i].copy(); 062 } 063 } 064 065 /** 066 * Defensively copies the values. 067 * @param values The sparse vectors to use. 068 * @return A DenseSparseMatrix containing the supplied vectors. 069 */ 070 public static DenseSparseMatrix createFromSparseVectors(SparseVector[] values) { 071 SparseVector[] newValues = new SparseVector[values.length]; 072 for (int i = 0; i < values.length; i++) { 073 newValues[i] = values[i].copy(); 074 } 075 return new DenseSparseMatrix(newValues); 076 } 077 078 @Override 079 public int[] getShape() { 080 return shape; 081 } 082 083 @Override 084 public Tensor reshape(int[] newShape) { 085 throw new UnsupportedOperationException("Reshape not supported on sparse Tensors."); 086 } 087 088 @Override 089 public double get(int i, int j) { 090 return values[i].get(j); 091 } 092 093 @Override 094 public void set(int i, int j, double value) { 095 values[i].set(j,value); 096 } 097 098 @Override 099 public int getDimension1Size() { 100 return dim1; 101 } 102 103 @Override 104 public int getDimension2Size() { 105 return dim2; 106 } 107 108 @Override 109 public DenseVector leftMultiply(SGDVector input) { 110 if (input.size() == dim2) { 111 double[] output = new double[dim1]; 112 113 for (int i = 0; i < output.length; i++) { 114 output[i] = values[i].dot(input); 115 } 116 117 return new DenseVector(output); 118 } else { 119 throw new IllegalArgumentException("input.size() != dim2"); 120 } 121 } 122 123 /** 124 * rightMultiply is very inefficient on DenseSparseMatrix due to the storage format. 125 * @param input The input vector. 126 * @return A*input. 127 */ 128 @Override 129 public DenseVector rightMultiply(SGDVector input) { 130 if (input.size() == dim1) { 131 double[] output = new double[dim2]; 132 133 for (int j = 0; j < values.length; j++) { 134 for (int i = 0; i < output.length; i++) { 135 output[i] = values[j].get(i) * input.get(i); 136 } 137 } 138 139 return new DenseVector(output); 140 } else { 141 throw new IllegalArgumentException("input.size() != dim1"); 142 } 143 } 144 145 @Override 146 public void add(int i, int j, double value) { 147 values[i].add(j,value); 148 } 149 150 /** 151 * Only implemented for {@link DenseMatrix}. 152 * @param other The other {@link Tensor}. 153 * @param f A function to apply. 154 */ 155 @Override 156 public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) { 157 if (other instanceof Matrix) { 158 Matrix otherMat = (Matrix) other; 159 if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) { 160 if (otherMat instanceof DenseMatrix) { 161 DenseMatrix otherDenseMat = (DenseMatrix) other; 162 for (int i = 0; i < dim1; i++) { 163 values[i].intersectAndAddInPlace(otherDenseMat.getRow(i),f); 164 } 165 } else { 166 throw new UnsupportedOperationException("Not implemented intersectAndAddInPlace in DenseSparseMatrix for types other than DenseMatrix"); 167 } 168 } else { 169 throw new IllegalArgumentException("Matrices are not the same size, this("+dim1+","+dim2+"), other("+otherMat.getDimension1Size()+","+otherMat.getDimension2Size()+")"); 170 } 171 } else { 172 throw new IllegalArgumentException("Adding a non-Matrix to a Matrix"); 173 } 174 } 175 176 /** 177 * Only implemented for {@link DenseMatrix}. 178 * @param other The other {@link Tensor}. 179 * @param f A function to apply. 180 */ 181 @Override 182 public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f) { 183 if (other instanceof Matrix) { 184 Matrix otherMat = (Matrix) other; 185 if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) { 186 if (otherMat instanceof DenseMatrix) { 187 DenseMatrix otherDenseMat = (DenseMatrix) other; 188 for (int i = 0; i < dim1; i++) { 189 values[i].hadamardProductInPlace(otherDenseMat.getRow(i),f); 190 } 191 } else { 192 throw new UnsupportedOperationException("Not implemented hadamardProductInPlace in DenseSparseMatrix for types other than DenseMatrix"); 193 } 194 } else { 195 throw new IllegalArgumentException("Matrices are not the same size, this("+dim1+","+dim2+"), other("+otherMat.getDimension1Size()+","+otherMat.getDimension2Size()+")"); 196 } 197 } else { 198 throw new IllegalArgumentException("Scaling a Matrix by a non-Matrix"); 199 } 200 } 201 202 @Override 203 public void foreachInPlace(DoubleUnaryOperator f) { 204 for (int i = 0; i < values.length; i++) { 205 values[i].foreachInPlace(f); 206 } 207 } 208 209 @Override 210 public int numActiveElements(int row) { 211 return values[row].numActiveElements(); 212 } 213 214 @Override 215 public SparseVector getRow(int i) { 216 return values[i]; 217 } 218 219 @Override 220 public boolean equals(Object other) { 221 if (other instanceof Matrix) { 222 Iterator<MatrixTuple> ourItr = iterator(); 223 Iterator<MatrixTuple> otherItr = ((Matrix)other).iterator(); 224 MatrixTuple ourTuple; 225 MatrixTuple otherTuple; 226 227 while (ourItr.hasNext() && otherItr.hasNext()) { 228 ourTuple = ourItr.next(); 229 otherTuple = otherItr.next(); 230 if (!ourTuple.equals(otherTuple)) { 231 return false; 232 } 233 } 234 235 // If one of the iterators still has elements then they are not the same. 236 return !(ourItr.hasNext() || otherItr.hasNext()); 237 } else { 238 return false; 239 } 240 } 241 242 @Override 243 public int hashCode() { 244 int result = Objects.hash(dim1, dim2); 245 result = 31 * result + Arrays.hashCode(values); 246 return result; 247 } 248 249 @Override 250 public double twoNorm() { 251 double output = 0.0; 252 for (int i = 0; i < dim1; i++) { 253 double value = values[i].twoNorm(); 254 output += value * value; 255 } 256 return Math.sqrt(output); 257 } 258 259 @Override 260 public DenseMatrix matrixMultiply(Matrix other) { 261 if (dim2 == other.getDimension1Size()) { 262 if (other instanceof DenseMatrix) { 263 DenseMatrix otherDense = (DenseMatrix) other; 264 double[][] output = new double[dim1][otherDense.dim2]; 265 266 for (int i = 0; i < dim1; i++) { 267 for (int j = 0; j < otherDense.dim2; j++) { 268 output[i][j] = columnRowDot(i,j,otherDense); 269 } 270 } 271 272 return new DenseMatrix(output); 273 } else if (other instanceof DenseSparseMatrix) { 274 DenseSparseMatrix otherSparse = (DenseSparseMatrix) other; 275 int otherDim2 = otherSparse.getDimension2Size(); 276 double[][] output = new double[dim1][otherDim2]; 277 278 for (int i = 0; i < dim1; i++) { 279 for (int j = 0; j < otherDim2; j++) { 280 output[i][j] = columnRowDot(i,j,otherSparse); 281 } 282 } 283 284 return new DenseMatrix(output); 285 } else { 286 throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName()); 287 } 288 } else { 289 throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape())); 290 } 291 } 292 293 @Override 294 public DenseMatrix matrixMultiply(Matrix other, boolean transposeThis, boolean transposeOther) { 295 if (transposeThis && transposeOther) { 296 return matrixMultiplyTransposeBoth(other); 297 } else if (transposeThis) { 298 return matrixMultiplyTransposeThis(other); 299 } else if (transposeOther) { 300 return matrixMultiplyTransposeOther(other); 301 } else { 302 return matrixMultiply(other); 303 } 304 } 305 306 private DenseMatrix matrixMultiplyTransposeBoth(Matrix other) { 307 if (dim1 == other.getDimension2Size()) { 308 if (other instanceof DenseMatrix) { 309 DenseMatrix otherDense = (DenseMatrix) other; 310 double[][] output = new double[dim2][otherDense.dim1]; 311 312 for (int i = 0; i < dim2; i++) { 313 for (int j = 0; j < otherDense.dim1; j++) { 314 output[i][j] = rowColumnDot(i,j,otherDense); 315 } 316 } 317 318 return new DenseMatrix(output); 319 } else if (other instanceof DenseSparseMatrix) { 320 DenseSparseMatrix otherSparse = (DenseSparseMatrix) other; 321 int otherDim1 = otherSparse.getDimension1Size(); 322 double[][] output = new double[dim2][otherDim1]; 323 324 for (int i = 0; i < dim2; i++) { 325 for (int j = 0; j < otherDim1; j++) { 326 output[i][j] = rowColumnDot(i,j,otherSparse); 327 } 328 } 329 330 return new DenseMatrix(output); 331 } else { 332 throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName()); 333 } 334 } else { 335 throw new IllegalArgumentException("Invalid matrix dimensions, dim1 = " + dim1 + ", other.dim2 = " + other.getDimension2Size()); 336 } 337 } 338 339 private DenseMatrix matrixMultiplyTransposeThis(Matrix other) { 340 if (dim1 == other.getDimension1Size()) { 341 if (other instanceof DenseMatrix) { 342 DenseMatrix otherDense = (DenseMatrix) other; 343 double[][] output = new double[dim2][otherDense.dim2]; 344 345 for (int i = 0; i < dim2; i++) { 346 for (int j = 0; j < otherDense.dim2; j++) { 347 output[i][j] = columnColumnDot(i,j,otherDense); 348 } 349 } 350 351 return new DenseMatrix(output); 352 } else if (other instanceof DenseSparseMatrix) { 353 DenseSparseMatrix otherSparse = (DenseSparseMatrix) other; 354 int otherDim2 = otherSparse.getDimension2Size(); 355 double[][] output = new double[dim2][otherDim2]; 356 357 for (int i = 0; i < dim2; i++) { 358 for (int j = 0; j < otherDim2; j++) { 359 output[i][j] = columnColumnDot(i,j,otherSparse); 360 } 361 } 362 363 return new DenseMatrix(output); 364 } else { 365 throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName()); 366 } 367 } else { 368 throw new IllegalArgumentException("Invalid matrix dimensions, dim1 = " + dim1 + ", other.dim1 = " + other.getDimension1Size()); 369 } 370 } 371 372 private DenseMatrix matrixMultiplyTransposeOther(Matrix other) { 373 if (dim2 == other.getDimension2Size()) { 374 if (other instanceof DenseMatrix) { 375 DenseMatrix otherDense = (DenseMatrix) other; 376 double[][] output = new double[dim1][otherDense.dim1]; 377 378 for (int i = 0; i < dim1; i++) { 379 for (int j = 0; j < otherDense.dim1; j++) { 380 output[i][j] = rowRowDot(i,j,otherDense); 381 } 382 } 383 384 return new DenseMatrix(output); 385 } else if (other instanceof DenseSparseMatrix) { 386 DenseSparseMatrix otherSparse = (DenseSparseMatrix) other; 387 int otherDim1 = otherSparse.getDimension1Size(); 388 double[][] output = new double[dim1][otherDim1]; 389 390 for (int i = 0; i < dim1; i++) { 391 for (int j = 0; j < otherDim1; j++) { 392 output[i][j] = rowRowDot(i,j,otherSparse); 393 } 394 } 395 396 return new DenseMatrix(output); 397 } else { 398 throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName()); 399 } 400 } else { 401 throw new IllegalArgumentException("Invalid matrix dimensions, dim2 = " + dim2 + ", other.dim2 = " + other.getDimension2Size()); 402 } 403 } 404 405 private double columnRowDot(int rowIndex, int otherColIndex, Matrix other) { 406 double sum = 0.0; 407 for (VectorTuple tuple : values[rowIndex]) { 408 sum += tuple.value * other.get(tuple.index,otherColIndex); 409 } 410 return sum; 411 } 412 413 private double rowColumnDot(int colIndex, int otherRowIndex, Matrix other) { 414 double sum = 0.0; 415 for (int i = 0; i < dim1; i++) { 416 sum += get(i,colIndex) * other.get(otherRowIndex,i); 417 } 418 return sum; 419 } 420 421 private double columnColumnDot(int colIndex, int otherColIndex, Matrix other) { 422 double sum = 0.0; 423 for (int i = 0; i < dim1; i++) { 424 sum += get(i,colIndex) * other.get(i,otherColIndex); 425 } 426 return sum; 427 } 428 429 private double rowRowDot(int rowIndex, int otherRowIndex, Matrix other) { 430 double sum = 0.0; 431 for (VectorTuple tuple : values[rowIndex]) { 432 sum += tuple.value * other.get(otherRowIndex,tuple.index); 433 } 434 return sum; 435 } 436 437 @Override 438 public DenseVector rowSum() { 439 double[] rowSum = new double[dim1]; 440 for (int i = 0; i < dim1; i++) { 441 rowSum[i] = values[i].sum(); 442 } 443 return new DenseVector(rowSum); 444 } 445 446 @Override 447 public void rowScaleInPlace(DenseVector scalingCoefficients) { 448 for (int i = 0; i < dim1; i++) { 449 values[i].scaleInPlace(scalingCoefficients.get(i)); 450 } 451 } 452 453 @Override 454 public String toString() { 455 StringBuilder buffer = new StringBuilder(); 456 457 buffer.append("DenseSparseMatrix(\n"); 458 for (int i = 0; i < values.length; i++) { 459 buffer.append("\t"); 460 buffer.append(values[i].toString()); 461 buffer.append(";\n"); 462 } 463 buffer.append(")"); 464 465 return buffer.toString(); 466 } 467 468 @Override 469 public MatrixIterator iterator() { 470 return new DenseSparseMatrixIterator(this); 471 } 472 473 private static class DenseSparseMatrixIterator implements MatrixIterator { 474 private final DenseSparseMatrix matrix; 475 private final MatrixTuple tuple; 476 private int i; 477 private Iterator<VectorTuple> itr; 478 private VectorTuple vecTuple; 479 480 public DenseSparseMatrixIterator(DenseSparseMatrix matrix) { 481 this.matrix = matrix; 482 this.tuple = new MatrixTuple(); 483 this.i = 0; 484 this.itr = matrix.values[0].iterator(); 485 } 486 487 @Override 488 public String toString() { 489 return "DenseSparseMatrixIterator(position="+i+",tuple="+ tuple.toString()+")"; 490 } 491 492 @Override 493 public MatrixTuple getReference() { 494 return tuple; 495 } 496 497 @Override 498 public boolean hasNext() { 499 if (itr.hasNext()) { 500 return true; 501 } else { 502 while ((i < matrix.dim1) && (!itr.hasNext())) { 503 i++; 504 if (i < matrix.dim1) { 505 itr = matrix.values[i].iterator(); 506 } 507 } 508 } 509 return (i < matrix.dim1) && itr.hasNext(); 510 } 511 512 @Override 513 public MatrixTuple next() { 514 vecTuple = itr.next(); 515 tuple.i = i; 516 tuple.j = vecTuple.index; 517 tuple.value = vecTuple.value; 518 return tuple; 519 } 520 } 521}