001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.la;
018
019import org.tribuo.math.util.VectorNormalizer;
020
021import java.util.Arrays;
022import java.util.NoSuchElementException;
023import java.util.Objects;
024import java.util.function.DoubleUnaryOperator;
025
026/**
027 * A dense matrix, backed by a primitive array.
028 */
029public class DenseMatrix implements Matrix {
030    private static final long serialVersionUID = 1L;
031
032    private static final double DELTA = 1e-10;
033
034    protected final double[][] values;
035    protected final int dim1;
036    protected final int dim2;
037
038    private final int[] shape;
039
040    private final int numElements;
041
042    public DenseMatrix(int dim1, int dim2) {
043        this.values = new double[dim1][dim2];
044        this.dim1 = dim1;
045        this.dim2 = dim2;
046        this.shape = new int[]{dim1,dim2};
047        this.numElements = dim1*dim2;
048    }
049
050    public DenseMatrix(DenseMatrix other) {
051        this.values = new double[other.values.length][];
052        for (int i = 0; i < values.length; i++) {
053            this.values[i] = new double[other.values[i].length];
054            for (int j = 0; j < values[i].length; j++) {
055                this.values[i][j] = other.get(i,j);
056            }
057        }
058        this.dim1 = other.dim1;
059        this.dim2 = other.dim2;
060        this.shape = new int[]{dim1,dim2};
061        this.numElements = dim1*dim2;
062    }
063
064    public DenseMatrix(Matrix other) {
065        this.dim1 = other.getDimension1Size();
066        this.dim2 = other.getDimension2Size();
067        this.values = new double[dim1][dim2];
068        for (MatrixTuple t : other) {
069            this.values[t.i][t.j] = t.value;
070        }
071        this.shape = new int[]{dim1,dim2};
072        this.numElements = dim1*dim2;
073    }
074
075    /**
076     * Creates a DenseMatrix without defensive copying.
077     * @param values The values of the matrix.
078     */
079    DenseMatrix(double[][] values) {
080        this.values = values;
081        this.dim1 = values.length;
082        this.dim2 = values[0].length;
083        this.shape = new int[]{dim1,dim2};
084        this.numElements = dim1*dim2;
085    }
086
087    /**
088     * Defensively copies the values before construction.
089     * <p>
090     * Throws IllegalArgumentException if the supplied values are a ragged array.
091     * @param values The values of this dense matrix.
092     * @return A new dense matrix.
093     */
094    public static DenseMatrix createDenseMatrix(double[][] values) {
095        double[][] newValues = new double[values.length][];
096        int sizeCounter = -1;
097        for (int i = 0; i < newValues.length; i++) {
098            if (sizeCounter == -1) {
099                sizeCounter = values[i].length;
100            }
101            if (sizeCounter != values[i].length) {
102                throw new IllegalArgumentException("DenseMatrix must not be ragged. Expected dim2 = " + sizeCounter + ", but found " + values[i].length + " at index " + i);
103            }
104            newValues[i] = Arrays.copyOf(values[i],values[i].length);
105        }
106        return new DenseMatrix(newValues);
107    }
108
109    @Override
110    public int[] getShape() {
111        return shape;
112    }
113
114    @Override
115    public Tensor reshape(int[] newShape) {
116        int sum = Tensor.shapeSum(newShape);
117        if (sum != numElements) {
118            throw new IllegalArgumentException("Invalid shape " + Arrays.toString(newShape) + ", expected something with " + numElements + " elements.");
119        }
120
121        if (newShape.length == 2) {
122            DenseMatrix matrix = new DenseMatrix(newShape[0],newShape[1]);
123
124            for (int a = 0; a < numElements; a++) {
125                int oldI = a % dim1;
126                int oldJ = a % dim2;
127                int i = a % newShape[0];
128                int j = a / newShape[0];
129                matrix.set(i,j,get(oldI,oldJ));
130            }
131
132            return matrix;
133        } else if (newShape.length == 1) {
134            DenseVector vector = new DenseVector(numElements);
135            int a = 0;
136            for (int i = 0; i < dim1; i++) {
137                for (int j = 0; j < dim2; j++) {
138                    vector.set(a,get(i,j));
139                    a++;
140                }
141            }
142            return vector;
143        } else {
144            throw new IllegalArgumentException("Only supports 1 or 2 dimensional tensors.");
145        }
146    }
147
148    /**
149     * Copies the matrix.
150     * @return A deep copy of the matrix.
151     */
152    public DenseMatrix copy() {
153        return new DenseMatrix(this);
154    }
155
156    @Override
157    public double get(int i, int j) {
158        return values[i][j];
159    }
160
161    public DenseVector gatherAcrossDim1(int[] elements) {
162        if (elements.length != dim2) {
163            throw new IllegalArgumentException("Invalid number of elements to gather, must select one per value of dim2");
164        }
165        double[] outputValues = new double[dim2];
166
167        for (int i = 0; i < elements.length; i++) {
168            outputValues[i] = values[elements[i]][i];
169        }
170
171        return new DenseVector(outputValues);
172    }
173
174    public DenseVector gatherAcrossDim2(int[] elements) {
175        if (elements.length != dim1) {
176            throw new IllegalArgumentException("Invalid number of elements to gather, must select one per value of dim1");
177        }
178        double[] outputValues = new double[dim1];
179
180        for (int i = 0; i < elements.length; i++) {
181            outputValues[i] = values[i][elements[i]];
182        }
183
184        return new DenseVector(outputValues);
185    }
186
187    public DenseMatrix transpose() {
188        double[][] newValues = new double[dim2][dim1];
189
190        for (int i = 0; i < dim1; i++) {
191            for (int j = 0; j < dim2; j++) {
192                newValues[j][i] = values[i][j];
193            }
194        }
195
196        return new DenseMatrix(newValues);
197    }
198
199    @Override
200    public boolean equals(Object o) {
201        if (this == o) return true;
202        if (!(o instanceof DenseMatrix)) return false;
203        DenseMatrix that = (DenseMatrix) o;
204        if ((dim1 == that.dim1) && (dim2 == that.dim2) && (numElements == that.numElements) && Arrays.equals(getShape(),that.getShape())) {
205            for (int i = 0; i < dim1; i++) {
206                for (int j = 0; j < dim2; j++) {
207                    if (Math.abs(get(i,j) - that.get(i,j)) > DELTA) {
208                        return false;
209                    }
210                }
211            }
212            return true;
213        } else {
214            return false;
215        }
216    }
217
218    @Override
219    public int hashCode() {
220        int result = Objects.hash(dim1, dim2, numElements);
221        result = 31 * result + Arrays.hashCode(values);
222        result = 31 * result + Arrays.hashCode(getShape());
223        return result;
224    }
225
226    @Override
227    public void set(int i, int j, double value) {
228        values[i][j] = value;
229    }
230
231    @Override
232    public int getDimension1Size() {
233        return dim1;
234    }
235
236    @Override
237    public int getDimension2Size() {
238        return dim2;
239    }
240
241    @Override
242    public DenseVector leftMultiply(SGDVector input) {
243        if (input.size() == dim2) {
244            double[] output = new double[dim1];
245
246            for (VectorTuple tuple : input) {
247                for (int i = 0; i < output.length; i++) {
248                    output[i] += values[i][tuple.index] * tuple.value;
249                }
250            }
251
252            return new DenseVector(output);
253        } else {
254            throw new IllegalArgumentException("input.size() != dim2, input.size() = " + input.size() + ", dim1,dim2 = " + dim1+","+dim2);
255        }
256    }
257
258    @Override
259    public DenseVector rightMultiply(SGDVector input) {
260        if (input.size() == dim1) {
261            double[] output = new double[dim2];
262
263            for (VectorTuple tuple : input) {
264                for (int i = 0; i < output.length; i++) {
265                    output[i] += values[tuple.index][i] * tuple.value;
266                }
267            }
268
269            return new DenseVector(output);
270        } else {
271            throw new IllegalArgumentException("input.size() != dim1");
272        }
273    }
274
275    @Override
276    public DenseMatrix matrixMultiply(Matrix other) {
277        if (dim2 == other.getDimension1Size()) {
278            if (other instanceof DenseMatrix) {
279                DenseMatrix otherDense = (DenseMatrix) other;
280                double[][] output = new double[dim1][otherDense.dim2];
281
282                for (int i = 0; i < dim1; i++) {
283                    for (int j = 0; j < otherDense.dim2; j++) {
284                        output[i][j] = columnRowDot(i,j,otherDense);
285                    }
286                }
287
288                return new DenseMatrix(output);
289            } else if (other instanceof DenseSparseMatrix) {
290                DenseSparseMatrix otherSparse = (DenseSparseMatrix) other;
291                int otherDim2 = otherSparse.getDimension2Size();
292                double[][] output = new double[dim1][otherDim2];
293
294                for (int i = 0; i < dim1; i++) {
295                    for (int j = 0; j < otherDim2; j++) {
296                        output[i][j] = columnRowDot(i,j,otherSparse);
297                    }
298                }
299
300                return new DenseMatrix(output);
301            } else {
302                throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName());
303            }
304        } else {
305            throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape()));
306        }
307    }
308
309    @Override
310    public DenseMatrix matrixMultiply(Matrix other, boolean transposeThis, boolean transposeOther) {
311        if (transposeThis && transposeOther) {
312            return matrixMultiplyTransposeBoth(other);
313        } else if (transposeThis) {
314            return matrixMultiplyTransposeThis(other);
315        } else if (transposeOther) {
316            return matrixMultiplyTransposeOther(other);
317        } else {
318            return matrixMultiply(other);
319        }
320    }
321
322    private DenseMatrix matrixMultiplyTransposeBoth(Matrix other) {
323        if (dim1 == other.getDimension2Size()) {
324            if (other instanceof DenseMatrix) {
325                DenseMatrix otherDense = (DenseMatrix) other;
326                double[][] output = new double[dim2][otherDense.dim1];
327
328                for (int i = 0; i < dim2; i++) {
329                    for (int j = 0; j < otherDense.dim1; j++) {
330                        output[i][j] = rowColumnDot(i,j,otherDense);
331                    }
332                }
333
334                return new DenseMatrix(output);
335            } else if (other instanceof DenseSparseMatrix) {
336                DenseSparseMatrix otherSparse = (DenseSparseMatrix) other;
337                int otherDim1 = otherSparse.getDimension1Size();
338                double[][] output = new double[dim2][otherDim1];
339
340                for (int i = 0; i < dim2; i++) {
341                    for (int j = 0; j < otherDim1; j++) {
342                        output[i][j] = rowColumnDot(i,j,otherSparse);
343                    }
344                }
345
346                return new DenseMatrix(output);
347            } else {
348                throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName());
349            }
350        } else {
351            throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape()));
352        }
353    }
354
355    private DenseMatrix matrixMultiplyTransposeThis(Matrix other) {
356        if (dim1 == other.getDimension1Size()) {
357            if (other instanceof DenseMatrix) {
358                DenseMatrix otherDense = (DenseMatrix) other;
359                double[][] output = new double[dim2][otherDense.dim2];
360
361                for (int i = 0; i < dim2; i++) {
362                    for (int j = 0; j < otherDense.dim2; j++) {
363                        output[i][j] = columnColumnDot(i,j,otherDense);
364                    }
365                }
366
367                return new DenseMatrix(output);
368            } else if (other instanceof DenseSparseMatrix) {
369                DenseSparseMatrix otherSparse = (DenseSparseMatrix) other;
370                int otherDim2 = otherSparse.getDimension2Size();
371                double[][] output = new double[dim2][otherDim2];
372
373                for (int i = 0; i < dim2; i++) {
374                    for (int j = 0; j < otherDim2; j++) {
375                        output[i][j] = columnColumnDot(i,j,otherSparse);
376                    }
377                }
378
379                return new DenseMatrix(output);
380            } else {
381                throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName());
382            }
383        } else {
384            throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape()));
385        }
386    }
387
388    private DenseMatrix matrixMultiplyTransposeOther(Matrix other) {
389        if (dim2 == other.getDimension2Size()) {
390            if (other instanceof DenseMatrix) {
391                DenseMatrix otherDense = (DenseMatrix) other;
392                double[][] output = new double[dim1][otherDense.dim1];
393
394                for (int i = 0; i < dim1; i++) {
395                    for (int j = 0; j < otherDense.dim1; j++) {
396                        output[i][j] = rowRowDot(i,j,otherDense);
397                    }
398                }
399
400                return new DenseMatrix(output);
401            } else if (other instanceof DenseSparseMatrix) {
402                DenseSparseMatrix otherSparse = (DenseSparseMatrix) other;
403                int otherDim1 = otherSparse.getDimension1Size();
404                double[][] output = new double[dim1][otherDim1];
405
406                for (int i = 0; i < dim1; i++) {
407                    for (int j = 0; j < otherDim1; j++) {
408                        output[i][j] = rowRowDot(i,j,otherSparse);
409                    }
410                }
411
412                return new DenseMatrix(output);
413            } else {
414                throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName());
415            }
416        } else {
417            throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape()));
418        }
419    }
420
421    private double columnRowDot(int rowIndex, int otherColIndex, Matrix other) {
422        double sum = 0.0;
423        for (int i = 0; i < dim2; i++) {
424            sum += get(rowIndex,i) * other.get(i,otherColIndex);
425        }
426        return sum;
427    }
428
429    private double rowColumnDot(int colIndex, int otherRowIndex, Matrix other) {
430        double sum = 0.0;
431        for (int i = 0; i < dim1; i++) {
432            sum += get(i,colIndex) * other.get(otherRowIndex,i);
433        }
434        return sum;
435    }
436
437    private double columnColumnDot(int colIndex, int otherColIndex, Matrix other) {
438        double sum = 0.0;
439        for (int i = 0; i < dim1; i++) {
440            sum += get(i,colIndex) * other.get(i,otherColIndex);
441        }
442        return sum;
443    }
444
445    private double rowRowDot(int rowIndex, int otherRowIndex, Matrix other) {
446        double sum = 0.0;
447        for (int i = 0; i < dim2; i++) {
448            sum += get(rowIndex,i) * other.get(otherRowIndex,i);
449        }
450        return sum;
451    }
452
453    @Override
454    public DenseVector rowSum() {
455        double[] rowSum = new double[dim1];
456        for (int i = 0; i < dim1; i++) {
457            double tmp = 0.0;
458            for (int j = 0; j < dim2; j++) {
459                tmp += values[i][j];
460            }
461            rowSum[i] = tmp;
462        }
463        return new DenseVector(rowSum);
464    }
465
466    @Override
467    public void rowScaleInPlace(DenseVector scalingCoefficients) {
468        for (int i = 0; i < dim1; i++) {
469            double scalar = scalingCoefficients.get(i);
470            for (int j = 0; j < dim2; j++) {
471                values[i][j] *= scalar;
472            }
473        }
474    }
475
476    @Override
477    public void add(int i, int j, double value) {
478        values[i][j] += value;
479    }
480
481    public void addAcrossDim1(int[] indices, double value) {
482        if (indices.length != dim2) {
483            throw new IllegalArgumentException("Invalid number of elements to add, must select one per value of dim2");
484        }
485        for (int i = 0; i < indices.length; i++) {
486            values[indices[i]][i] += value;
487        }
488    }
489
490    public void addAcrossDim2(int[] indices, double value) {
491        if (indices.length != dim1) {
492            throw new IllegalArgumentException("Invalid number of elements to indices, must select one per value of dim1");
493        }
494        for (int i = 0; i < indices.length; i++) {
495            values[i][indices[i]] += value;
496        }
497    }
498
499    @Override
500    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
501        if (other instanceof Matrix) {
502            Matrix otherMat = (Matrix) other;
503            if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) {
504                for (MatrixTuple tuple : otherMat) {
505                    values[tuple.i][tuple.j] += f.applyAsDouble(tuple.value);
506                }
507            } else {
508                throw new IllegalArgumentException("Matrices are not the same size, this("+dim1+","+dim2+"), other("+otherMat.getDimension1Size()+","+otherMat.getDimension2Size()+")");
509            }
510        } else {
511            throw new IllegalArgumentException("Adding a non-Matrix to a Matrix");
512        }
513    }
514
515    @Override
516    public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f) {
517        if (other instanceof Matrix) {
518            Matrix otherMat = (Matrix) other;
519            if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) {
520                for (MatrixTuple tuple : otherMat) {
521                    values[tuple.i][tuple.j] *= f.applyAsDouble(tuple.value);
522                }
523            } else {
524                throw new IllegalArgumentException("Matrices are not the same size, this("+dim1+","+dim2+"), other("+otherMat.getDimension1Size()+","+otherMat.getDimension2Size()+")");
525            }
526        } else {
527            throw new IllegalArgumentException("Adding a non-Matrix to a Matrix");
528        }
529    }
530
531    @Override
532    public void foreachInPlace(DoubleUnaryOperator f) {
533        for (int i = 0; i < values.length; i++) {
534            for (int j = 0; j < dim2; j++) {
535                values[i][j] = f.applyAsDouble(values[i][j]);
536            }
537        }
538    }
539
540    /**
541     * Broadcasts the input vector and adds it to each row/column of the matrix.
542     * @param input The input vector.
543     * @param broadcastOverDim1 If true broadcasts over the first dimension, else broadcasts over the second.
544     */
545    public void broadcastIntersectAndAddInPlace(SGDVector input, boolean broadcastOverDim1) {
546        if (input instanceof DenseVector) {
547            if (broadcastOverDim1) {
548                if (input.size() == dim2) {
549                    for (int i = 0; i < dim1; i++) {
550                        for (int j = 0; j < dim2; j++) {
551                            values[i][j] += input.get(j);
552                        }
553                    }
554                } else {
555                    throw new IllegalArgumentException("Input vector must have dimension equal to dim 2, input.size() = " + input.size() + ", dim2 = " + dim2);
556                }
557            } else {
558                if (input.size() == dim1) {
559                    for (int i = 0; i < dim1; i++) {
560                        double ith = input.get(i);
561                        for (int j = 0; j < dim2; j++) {
562                            values[i][j] += ith;
563                        }
564                    }
565                } else {
566                    throw new IllegalArgumentException("Input vector must have dimension equal to dim 1, input.size() = " + input.size() + ", dim1 = " + dim1);
567                }
568            }
569        } else if (input instanceof SparseVector) {
570            if (broadcastOverDim1) {
571                if (input.size() == dim2) {
572                    for (int i = 0; i < dim1; i++) {
573                        for (VectorTuple v : input) {
574                            values[i][v.index] += v.value;
575                        }
576                    }
577                } else {
578                    throw new IllegalArgumentException("Input vector must have dimension equal to dim 2, input.size() = " + input.size() + ", dim2 = " + dim2);
579                }
580            } else {
581                if (input.size() == dim1) {
582                    for (VectorTuple v : input) {
583                        for (int j = 0; j < dim2; j++) {
584                            values[v.index][j] += v.value;
585                        }
586                    }
587                } else {
588                    throw new IllegalArgumentException("Input vector must have dimension equal to dim 1, input.size() = " + input.size() + ", dim1 = " + dim1);
589                }
590            }
591
592        } else {
593            throw new IllegalArgumentException("Input vector was neither dense nor sparse.");
594        }
595    }
596
597    @Override
598    public int numActiveElements(int row) {
599        return dim2;
600    }
601
602    @Override
603    public DenseVector getRow(int i) {
604        return new DenseVector(values[i]);
605    }
606
607    public DenseVector getColumn(int index) {
608        double[] output = new double[dim1];
609        for (int i = 0; i < dim1; i++) {
610            output[i] = values[i][index];
611        }
612        return new DenseVector(output);
613    }
614
615    public double rowSum(int rowIndex) {
616        double[] row = values[rowIndex];
617        double sum = 0d;
618        for (int i = 0; i < row.length; i++) {
619            sum += row[i];
620        }
621        return sum;
622    }
623
624    public double columnSum(int columnIndex) {
625        double sum = 0d;
626        for (int i = 0; i < dim1; i++) {
627            sum += values[i][columnIndex];
628        }
629        return sum;
630    }
631
632    @Override
633    public double twoNorm() {
634        double output = 0.0;
635        for (int i = 0; i < dim1; i++) {
636            for (int j = 0; j < dim2; j++) {
637                double value = get(i,j);
638                output += value * value;
639            }
640        }
641        return Math.sqrt(output);
642    }
643
644    @Override
645    public String toString() {
646        StringBuilder buffer = new StringBuilder();
647
648        buffer.append("DenseMatrix(dim1=");
649        buffer.append(dim1);
650        buffer.append(",dim2=");
651        buffer.append(dim2);
652        buffer.append(",values=\n");
653        for (int i = 0; i < dim1; i++) {
654            buffer.append("\trow ");
655            buffer.append(i);
656            buffer.append(" [");
657            for (int j = 0; j < dim2; j++) {
658                if (values[i][j] < 0.0) {
659                    buffer.append(String.format("%.15f", values[i][j]));
660                } else {
661                    buffer.append(String.format(" %.15f", values[i][j]));
662                }
663                buffer.append(",");
664            }
665            buffer.deleteCharAt(buffer.length()-1);
666            buffer.append("];\n");
667        }
668        buffer.append(")");
669
670        return buffer.toString();
671    }
672
673    @Override
674    public MatrixIterator iterator() {
675        return new DenseMatrixIterator(this);
676    }
677
678    public void normalizeRows(VectorNormalizer normalizer) {
679        for (int i = 0; i < dim1; i++) {
680            double[] normalizedRow = normalizer.normalize(values[i]);
681            System.arraycopy(normalizedRow, 0, values[i], 0, dim2);
682        }
683    }
684
685    public DenseVector columnSum() {
686        double[] columnSum = new double[dim2];
687        for (int i = 0; i < dim1; i++) {
688            for (int j = 0; j < dim2; j++) {
689                columnSum[j] += values[i][j];
690            }
691        }
692        return new DenseVector(columnSum);
693    }
694
695    private class DenseMatrixIterator implements MatrixIterator {
696        private final DenseMatrix matrix;
697        private final MatrixTuple tuple;
698        private int i;
699        private int j;
700
701        public DenseMatrixIterator(DenseMatrix matrix) {
702            this.matrix = matrix;
703            this.tuple = new MatrixTuple();
704            this.i = 0;
705            this.j = 0;
706        }
707
708        @Override
709        public MatrixTuple getReference() {
710            return tuple;
711        }
712
713        @Override
714        public boolean hasNext() {
715            return (i < matrix.dim1) && (j < matrix.dim2);
716        }
717
718        @Override
719        public MatrixTuple next() {
720            if (!hasNext()) {
721                throw new NoSuchElementException("Off the end of the iterator.");
722            }
723            tuple.i = i;
724            tuple.j = j;
725            tuple.value = matrix.values[i][j];
726            if (j < dim2-1) {
727                j++;
728            } else {
729                //Reached end of current vector, get next one
730                i++;
731                j = 0;
732            }
733            return tuple;
734        }
735    }
736
737}