001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.math.la; 018 019import org.tribuo.math.util.VectorNormalizer; 020 021import java.util.Arrays; 022import java.util.NoSuchElementException; 023import java.util.Objects; 024import java.util.function.DoubleUnaryOperator; 025 026/** 027 * A dense matrix, backed by a primitive array. 028 */ 029public class DenseMatrix implements Matrix { 030 private static final long serialVersionUID = 1L; 031 032 private static final double DELTA = 1e-10; 033 034 protected final double[][] values; 035 protected final int dim1; 036 protected final int dim2; 037 038 private final int[] shape; 039 040 private final int numElements; 041 042 public DenseMatrix(int dim1, int dim2) { 043 this.values = new double[dim1][dim2]; 044 this.dim1 = dim1; 045 this.dim2 = dim2; 046 this.shape = new int[]{dim1,dim2}; 047 this.numElements = dim1*dim2; 048 } 049 050 public DenseMatrix(DenseMatrix other) { 051 this.values = new double[other.values.length][]; 052 for (int i = 0; i < values.length; i++) { 053 this.values[i] = new double[other.values[i].length]; 054 for (int j = 0; j < values[i].length; j++) { 055 this.values[i][j] = other.get(i,j); 056 } 057 } 058 this.dim1 = other.dim1; 059 this.dim2 = other.dim2; 060 this.shape = new int[]{dim1,dim2}; 061 this.numElements = dim1*dim2; 062 } 063 064 public DenseMatrix(Matrix other) { 065 this.dim1 = other.getDimension1Size(); 066 this.dim2 = other.getDimension2Size(); 067 this.values = new double[dim1][dim2]; 068 for (MatrixTuple t : other) { 069 this.values[t.i][t.j] = t.value; 070 } 071 this.shape = new int[]{dim1,dim2}; 072 this.numElements = dim1*dim2; 073 } 074 075 /** 076 * Creates a DenseMatrix without defensive copying. 077 * @param values The values of the matrix. 078 */ 079 DenseMatrix(double[][] values) { 080 this.values = values; 081 this.dim1 = values.length; 082 this.dim2 = values[0].length; 083 this.shape = new int[]{dim1,dim2}; 084 this.numElements = dim1*dim2; 085 } 086 087 /** 088 * Defensively copies the values before construction. 089 * <p> 090 * Throws IllegalArgumentException if the supplied values are a ragged array. 091 * @param values The values of this dense matrix. 092 * @return A new dense matrix. 093 */ 094 public static DenseMatrix createDenseMatrix(double[][] values) { 095 double[][] newValues = new double[values.length][]; 096 int sizeCounter = -1; 097 for (int i = 0; i < newValues.length; i++) { 098 if (sizeCounter == -1) { 099 sizeCounter = values[i].length; 100 } 101 if (sizeCounter != values[i].length) { 102 throw new IllegalArgumentException("DenseMatrix must not be ragged. Expected dim2 = " + sizeCounter + ", but found " + values[i].length + " at index " + i); 103 } 104 newValues[i] = Arrays.copyOf(values[i],values[i].length); 105 } 106 return new DenseMatrix(newValues); 107 } 108 109 @Override 110 public int[] getShape() { 111 return shape; 112 } 113 114 @Override 115 public Tensor reshape(int[] newShape) { 116 int sum = Tensor.shapeSum(newShape); 117 if (sum != numElements) { 118 throw new IllegalArgumentException("Invalid shape " + Arrays.toString(newShape) + ", expected something with " + numElements + " elements."); 119 } 120 121 if (newShape.length == 2) { 122 DenseMatrix matrix = new DenseMatrix(newShape[0],newShape[1]); 123 124 for (int a = 0; a < numElements; a++) { 125 int oldI = a % dim1; 126 int oldJ = a % dim2; 127 int i = a % newShape[0]; 128 int j = a / newShape[0]; 129 matrix.set(i,j,get(oldI,oldJ)); 130 } 131 132 return matrix; 133 } else if (newShape.length == 1) { 134 DenseVector vector = new DenseVector(numElements); 135 int a = 0; 136 for (int i = 0; i < dim1; i++) { 137 for (int j = 0; j < dim2; j++) { 138 vector.set(a,get(i,j)); 139 a++; 140 } 141 } 142 return vector; 143 } else { 144 throw new IllegalArgumentException("Only supports 1 or 2 dimensional tensors."); 145 } 146 } 147 148 /** 149 * Copies the matrix. 150 * @return A deep copy of the matrix. 151 */ 152 public DenseMatrix copy() { 153 return new DenseMatrix(this); 154 } 155 156 @Override 157 public double get(int i, int j) { 158 return values[i][j]; 159 } 160 161 public DenseVector gatherAcrossDim1(int[] elements) { 162 if (elements.length != dim2) { 163 throw new IllegalArgumentException("Invalid number of elements to gather, must select one per value of dim2"); 164 } 165 double[] outputValues = new double[dim2]; 166 167 for (int i = 0; i < elements.length; i++) { 168 outputValues[i] = values[elements[i]][i]; 169 } 170 171 return new DenseVector(outputValues); 172 } 173 174 public DenseVector gatherAcrossDim2(int[] elements) { 175 if (elements.length != dim1) { 176 throw new IllegalArgumentException("Invalid number of elements to gather, must select one per value of dim1"); 177 } 178 double[] outputValues = new double[dim1]; 179 180 for (int i = 0; i < elements.length; i++) { 181 outputValues[i] = values[i][elements[i]]; 182 } 183 184 return new DenseVector(outputValues); 185 } 186 187 public DenseMatrix transpose() { 188 double[][] newValues = new double[dim2][dim1]; 189 190 for (int i = 0; i < dim1; i++) { 191 for (int j = 0; j < dim2; j++) { 192 newValues[j][i] = values[i][j]; 193 } 194 } 195 196 return new DenseMatrix(newValues); 197 } 198 199 @Override 200 public boolean equals(Object o) { 201 if (this == o) return true; 202 if (!(o instanceof DenseMatrix)) return false; 203 DenseMatrix that = (DenseMatrix) o; 204 if ((dim1 == that.dim1) && (dim2 == that.dim2) && (numElements == that.numElements) && Arrays.equals(getShape(),that.getShape())) { 205 for (int i = 0; i < dim1; i++) { 206 for (int j = 0; j < dim2; j++) { 207 if (Math.abs(get(i,j) - that.get(i,j)) > DELTA) { 208 return false; 209 } 210 } 211 } 212 return true; 213 } else { 214 return false; 215 } 216 } 217 218 @Override 219 public int hashCode() { 220 int result = Objects.hash(dim1, dim2, numElements); 221 result = 31 * result + Arrays.hashCode(values); 222 result = 31 * result + Arrays.hashCode(getShape()); 223 return result; 224 } 225 226 @Override 227 public void set(int i, int j, double value) { 228 values[i][j] = value; 229 } 230 231 @Override 232 public int getDimension1Size() { 233 return dim1; 234 } 235 236 @Override 237 public int getDimension2Size() { 238 return dim2; 239 } 240 241 @Override 242 public DenseVector leftMultiply(SGDVector input) { 243 if (input.size() == dim2) { 244 double[] output = new double[dim1]; 245 246 for (VectorTuple tuple : input) { 247 for (int i = 0; i < output.length; i++) { 248 output[i] += values[i][tuple.index] * tuple.value; 249 } 250 } 251 252 return new DenseVector(output); 253 } else { 254 throw new IllegalArgumentException("input.size() != dim2, input.size() = " + input.size() + ", dim1,dim2 = " + dim1+","+dim2); 255 } 256 } 257 258 @Override 259 public DenseVector rightMultiply(SGDVector input) { 260 if (input.size() == dim1) { 261 double[] output = new double[dim2]; 262 263 for (VectorTuple tuple : input) { 264 for (int i = 0; i < output.length; i++) { 265 output[i] += values[tuple.index][i] * tuple.value; 266 } 267 } 268 269 return new DenseVector(output); 270 } else { 271 throw new IllegalArgumentException("input.size() != dim1"); 272 } 273 } 274 275 @Override 276 public DenseMatrix matrixMultiply(Matrix other) { 277 if (dim2 == other.getDimension1Size()) { 278 if (other instanceof DenseMatrix) { 279 DenseMatrix otherDense = (DenseMatrix) other; 280 double[][] output = new double[dim1][otherDense.dim2]; 281 282 for (int i = 0; i < dim1; i++) { 283 for (int j = 0; j < otherDense.dim2; j++) { 284 output[i][j] = columnRowDot(i,j,otherDense); 285 } 286 } 287 288 return new DenseMatrix(output); 289 } else if (other instanceof DenseSparseMatrix) { 290 DenseSparseMatrix otherSparse = (DenseSparseMatrix) other; 291 int otherDim2 = otherSparse.getDimension2Size(); 292 double[][] output = new double[dim1][otherDim2]; 293 294 for (int i = 0; i < dim1; i++) { 295 for (int j = 0; j < otherDim2; j++) { 296 output[i][j] = columnRowDot(i,j,otherSparse); 297 } 298 } 299 300 return new DenseMatrix(output); 301 } else { 302 throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName()); 303 } 304 } else { 305 throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape())); 306 } 307 } 308 309 @Override 310 public DenseMatrix matrixMultiply(Matrix other, boolean transposeThis, boolean transposeOther) { 311 if (transposeThis && transposeOther) { 312 return matrixMultiplyTransposeBoth(other); 313 } else if (transposeThis) { 314 return matrixMultiplyTransposeThis(other); 315 } else if (transposeOther) { 316 return matrixMultiplyTransposeOther(other); 317 } else { 318 return matrixMultiply(other); 319 } 320 } 321 322 private DenseMatrix matrixMultiplyTransposeBoth(Matrix other) { 323 if (dim1 == other.getDimension2Size()) { 324 if (other instanceof DenseMatrix) { 325 DenseMatrix otherDense = (DenseMatrix) other; 326 double[][] output = new double[dim2][otherDense.dim1]; 327 328 for (int i = 0; i < dim2; i++) { 329 for (int j = 0; j < otherDense.dim1; j++) { 330 output[i][j] = rowColumnDot(i,j,otherDense); 331 } 332 } 333 334 return new DenseMatrix(output); 335 } else if (other instanceof DenseSparseMatrix) { 336 DenseSparseMatrix otherSparse = (DenseSparseMatrix) other; 337 int otherDim1 = otherSparse.getDimension1Size(); 338 double[][] output = new double[dim2][otherDim1]; 339 340 for (int i = 0; i < dim2; i++) { 341 for (int j = 0; j < otherDim1; j++) { 342 output[i][j] = rowColumnDot(i,j,otherSparse); 343 } 344 } 345 346 return new DenseMatrix(output); 347 } else { 348 throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName()); 349 } 350 } else { 351 throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape())); 352 } 353 } 354 355 private DenseMatrix matrixMultiplyTransposeThis(Matrix other) { 356 if (dim1 == other.getDimension1Size()) { 357 if (other instanceof DenseMatrix) { 358 DenseMatrix otherDense = (DenseMatrix) other; 359 double[][] output = new double[dim2][otherDense.dim2]; 360 361 for (int i = 0; i < dim2; i++) { 362 for (int j = 0; j < otherDense.dim2; j++) { 363 output[i][j] = columnColumnDot(i,j,otherDense); 364 } 365 } 366 367 return new DenseMatrix(output); 368 } else if (other instanceof DenseSparseMatrix) { 369 DenseSparseMatrix otherSparse = (DenseSparseMatrix) other; 370 int otherDim2 = otherSparse.getDimension2Size(); 371 double[][] output = new double[dim2][otherDim2]; 372 373 for (int i = 0; i < dim2; i++) { 374 for (int j = 0; j < otherDim2; j++) { 375 output[i][j] = columnColumnDot(i,j,otherSparse); 376 } 377 } 378 379 return new DenseMatrix(output); 380 } else { 381 throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName()); 382 } 383 } else { 384 throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape())); 385 } 386 } 387 388 private DenseMatrix matrixMultiplyTransposeOther(Matrix other) { 389 if (dim2 == other.getDimension2Size()) { 390 if (other instanceof DenseMatrix) { 391 DenseMatrix otherDense = (DenseMatrix) other; 392 double[][] output = new double[dim1][otherDense.dim1]; 393 394 for (int i = 0; i < dim1; i++) { 395 for (int j = 0; j < otherDense.dim1; j++) { 396 output[i][j] = rowRowDot(i,j,otherDense); 397 } 398 } 399 400 return new DenseMatrix(output); 401 } else if (other instanceof DenseSparseMatrix) { 402 DenseSparseMatrix otherSparse = (DenseSparseMatrix) other; 403 int otherDim1 = otherSparse.getDimension1Size(); 404 double[][] output = new double[dim1][otherDim1]; 405 406 for (int i = 0; i < dim1; i++) { 407 for (int j = 0; j < otherDim1; j++) { 408 output[i][j] = rowRowDot(i,j,otherSparse); 409 } 410 } 411 412 return new DenseMatrix(output); 413 } else { 414 throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName()); 415 } 416 } else { 417 throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape())); 418 } 419 } 420 421 private double columnRowDot(int rowIndex, int otherColIndex, Matrix other) { 422 double sum = 0.0; 423 for (int i = 0; i < dim2; i++) { 424 sum += get(rowIndex,i) * other.get(i,otherColIndex); 425 } 426 return sum; 427 } 428 429 private double rowColumnDot(int colIndex, int otherRowIndex, Matrix other) { 430 double sum = 0.0; 431 for (int i = 0; i < dim1; i++) { 432 sum += get(i,colIndex) * other.get(otherRowIndex,i); 433 } 434 return sum; 435 } 436 437 private double columnColumnDot(int colIndex, int otherColIndex, Matrix other) { 438 double sum = 0.0; 439 for (int i = 0; i < dim1; i++) { 440 sum += get(i,colIndex) * other.get(i,otherColIndex); 441 } 442 return sum; 443 } 444 445 private double rowRowDot(int rowIndex, int otherRowIndex, Matrix other) { 446 double sum = 0.0; 447 for (int i = 0; i < dim2; i++) { 448 sum += get(rowIndex,i) * other.get(otherRowIndex,i); 449 } 450 return sum; 451 } 452 453 @Override 454 public DenseVector rowSum() { 455 double[] rowSum = new double[dim1]; 456 for (int i = 0; i < dim1; i++) { 457 double tmp = 0.0; 458 for (int j = 0; j < dim2; j++) { 459 tmp += values[i][j]; 460 } 461 rowSum[i] = tmp; 462 } 463 return new DenseVector(rowSum); 464 } 465 466 @Override 467 public void rowScaleInPlace(DenseVector scalingCoefficients) { 468 for (int i = 0; i < dim1; i++) { 469 double scalar = scalingCoefficients.get(i); 470 for (int j = 0; j < dim2; j++) { 471 values[i][j] *= scalar; 472 } 473 } 474 } 475 476 @Override 477 public void add(int i, int j, double value) { 478 values[i][j] += value; 479 } 480 481 public void addAcrossDim1(int[] indices, double value) { 482 if (indices.length != dim2) { 483 throw new IllegalArgumentException("Invalid number of elements to add, must select one per value of dim2"); 484 } 485 for (int i = 0; i < indices.length; i++) { 486 values[indices[i]][i] += value; 487 } 488 } 489 490 public void addAcrossDim2(int[] indices, double value) { 491 if (indices.length != dim1) { 492 throw new IllegalArgumentException("Invalid number of elements to indices, must select one per value of dim1"); 493 } 494 for (int i = 0; i < indices.length; i++) { 495 values[i][indices[i]] += value; 496 } 497 } 498 499 @Override 500 public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) { 501 if (other instanceof Matrix) { 502 Matrix otherMat = (Matrix) other; 503 if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) { 504 for (MatrixTuple tuple : otherMat) { 505 values[tuple.i][tuple.j] += f.applyAsDouble(tuple.value); 506 } 507 } else { 508 throw new IllegalArgumentException("Matrices are not the same size, this("+dim1+","+dim2+"), other("+otherMat.getDimension1Size()+","+otherMat.getDimension2Size()+")"); 509 } 510 } else { 511 throw new IllegalArgumentException("Adding a non-Matrix to a Matrix"); 512 } 513 } 514 515 @Override 516 public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f) { 517 if (other instanceof Matrix) { 518 Matrix otherMat = (Matrix) other; 519 if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) { 520 for (MatrixTuple tuple : otherMat) { 521 values[tuple.i][tuple.j] *= f.applyAsDouble(tuple.value); 522 } 523 } else { 524 throw new IllegalArgumentException("Matrices are not the same size, this("+dim1+","+dim2+"), other("+otherMat.getDimension1Size()+","+otherMat.getDimension2Size()+")"); 525 } 526 } else { 527 throw new IllegalArgumentException("Adding a non-Matrix to a Matrix"); 528 } 529 } 530 531 @Override 532 public void foreachInPlace(DoubleUnaryOperator f) { 533 for (int i = 0; i < values.length; i++) { 534 for (int j = 0; j < dim2; j++) { 535 values[i][j] = f.applyAsDouble(values[i][j]); 536 } 537 } 538 } 539 540 /** 541 * Broadcasts the input vector and adds it to each row/column of the matrix. 542 * @param input The input vector. 543 * @param broadcastOverDim1 If true broadcasts over the first dimension, else broadcasts over the second. 544 */ 545 public void broadcastIntersectAndAddInPlace(SGDVector input, boolean broadcastOverDim1) { 546 if (input instanceof DenseVector) { 547 if (broadcastOverDim1) { 548 if (input.size() == dim2) { 549 for (int i = 0; i < dim1; i++) { 550 for (int j = 0; j < dim2; j++) { 551 values[i][j] += input.get(j); 552 } 553 } 554 } else { 555 throw new IllegalArgumentException("Input vector must have dimension equal to dim 2, input.size() = " + input.size() + ", dim2 = " + dim2); 556 } 557 } else { 558 if (input.size() == dim1) { 559 for (int i = 0; i < dim1; i++) { 560 double ith = input.get(i); 561 for (int j = 0; j < dim2; j++) { 562 values[i][j] += ith; 563 } 564 } 565 } else { 566 throw new IllegalArgumentException("Input vector must have dimension equal to dim 1, input.size() = " + input.size() + ", dim1 = " + dim1); 567 } 568 } 569 } else if (input instanceof SparseVector) { 570 if (broadcastOverDim1) { 571 if (input.size() == dim2) { 572 for (int i = 0; i < dim1; i++) { 573 for (VectorTuple v : input) { 574 values[i][v.index] += v.value; 575 } 576 } 577 } else { 578 throw new IllegalArgumentException("Input vector must have dimension equal to dim 2, input.size() = " + input.size() + ", dim2 = " + dim2); 579 } 580 } else { 581 if (input.size() == dim1) { 582 for (VectorTuple v : input) { 583 for (int j = 0; j < dim2; j++) { 584 values[v.index][j] += v.value; 585 } 586 } 587 } else { 588 throw new IllegalArgumentException("Input vector must have dimension equal to dim 1, input.size() = " + input.size() + ", dim1 = " + dim1); 589 } 590 } 591 592 } else { 593 throw new IllegalArgumentException("Input vector was neither dense nor sparse."); 594 } 595 } 596 597 @Override 598 public int numActiveElements(int row) { 599 return dim2; 600 } 601 602 @Override 603 public DenseVector getRow(int i) { 604 return new DenseVector(values[i]); 605 } 606 607 public DenseVector getColumn(int index) { 608 double[] output = new double[dim1]; 609 for (int i = 0; i < dim1; i++) { 610 output[i] = values[i][index]; 611 } 612 return new DenseVector(output); 613 } 614 615 public double rowSum(int rowIndex) { 616 double[] row = values[rowIndex]; 617 double sum = 0d; 618 for (int i = 0; i < row.length; i++) { 619 sum += row[i]; 620 } 621 return sum; 622 } 623 624 public double columnSum(int columnIndex) { 625 double sum = 0d; 626 for (int i = 0; i < dim1; i++) { 627 sum += values[i][columnIndex]; 628 } 629 return sum; 630 } 631 632 @Override 633 public double twoNorm() { 634 double output = 0.0; 635 for (int i = 0; i < dim1; i++) { 636 for (int j = 0; j < dim2; j++) { 637 double value = get(i,j); 638 output += value * value; 639 } 640 } 641 return Math.sqrt(output); 642 } 643 644 @Override 645 public String toString() { 646 StringBuilder buffer = new StringBuilder(); 647 648 buffer.append("DenseMatrix(dim1="); 649 buffer.append(dim1); 650 buffer.append(",dim2="); 651 buffer.append(dim2); 652 buffer.append(",values=\n"); 653 for (int i = 0; i < dim1; i++) { 654 buffer.append("\trow "); 655 buffer.append(i); 656 buffer.append(" ["); 657 for (int j = 0; j < dim2; j++) { 658 if (values[i][j] < 0.0) { 659 buffer.append(String.format("%.15f", values[i][j])); 660 } else { 661 buffer.append(String.format(" %.15f", values[i][j])); 662 } 663 buffer.append(","); 664 } 665 buffer.deleteCharAt(buffer.length()-1); 666 buffer.append("];\n"); 667 } 668 buffer.append(")"); 669 670 return buffer.toString(); 671 } 672 673 @Override 674 public MatrixIterator iterator() { 675 return new DenseMatrixIterator(this); 676 } 677 678 public void normalizeRows(VectorNormalizer normalizer) { 679 for (int i = 0; i < dim1; i++) { 680 double[] normalizedRow = normalizer.normalize(values[i]); 681 System.arraycopy(normalizedRow, 0, values[i], 0, dim2); 682 } 683 } 684 685 public DenseVector columnSum() { 686 double[] columnSum = new double[dim2]; 687 for (int i = 0; i < dim1; i++) { 688 for (int j = 0; j < dim2; j++) { 689 columnSum[j] += values[i][j]; 690 } 691 } 692 return new DenseVector(columnSum); 693 } 694 695 private class DenseMatrixIterator implements MatrixIterator { 696 private final DenseMatrix matrix; 697 private final MatrixTuple tuple; 698 private int i; 699 private int j; 700 701 public DenseMatrixIterator(DenseMatrix matrix) { 702 this.matrix = matrix; 703 this.tuple = new MatrixTuple(); 704 this.i = 0; 705 this.j = 0; 706 } 707 708 @Override 709 public MatrixTuple getReference() { 710 return tuple; 711 } 712 713 @Override 714 public boolean hasNext() { 715 return (i < matrix.dim1) && (j < matrix.dim2); 716 } 717 718 @Override 719 public MatrixTuple next() { 720 if (!hasNext()) { 721 throw new NoSuchElementException("Off the end of the iterator."); 722 } 723 tuple.i = i; 724 tuple.j = j; 725 tuple.value = matrix.values[i][j]; 726 if (j < dim2-1) { 727 j++; 728 } else { 729 //Reached end of current vector, get next one 730 i++; 731 j = 0; 732 } 733 return tuple; 734 } 735 } 736 737}