001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.la;
018
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.DoubleUnaryOperator;
024
025/**
026 * A matrix which is dense in the first dimension and sparse in the second.
027 * <p>
028 * Backed by an array of {@link SparseVector}.
029 */
030public class DenseSparseMatrix implements Matrix {
031    private static final long serialVersionUID = 1L;
032
033    private final SparseVector[] values;
034    private final int dim1;
035    private final int dim2;
036    private final int[] shape;
037
038    DenseSparseMatrix(SparseVector[] values) {
039        this.values = values;
040        this.dim1 = values.length;
041        this.dim2 = values[0].size();
042        this.shape = new int[]{dim1,dim2};
043    }
044
045    public DenseSparseMatrix(List<SparseVector> values) {
046        this.values = new SparseVector[values.size()];
047        this.dim1 = values.size();
048        this.dim2 = values.get(0).size();
049        this.shape = new int[]{dim1,dim2};
050        for (int i = 0; i < values.size(); i++) {
051            this.values[i] = values.get(i);
052        }
053    }
054
055    public DenseSparseMatrix(DenseSparseMatrix other) {
056        this.dim1 = other.dim1;
057        this.dim2 = other.dim2;
058        this.values = new SparseVector[other.values.length];
059        this.shape = new int[]{dim1,dim2};
060        for (int i = 0; i < values.length; i++) {
061            values[i] = other.values[i].copy();
062        }
063    }
064
065    /**
066     * Defensively copies the values.
067     * @param values The sparse vectors to use.
068     * @return A DenseSparseMatrix containing the supplied vectors.
069     */
070    public static DenseSparseMatrix createFromSparseVectors(SparseVector[] values) {
071        SparseVector[] newValues = new SparseVector[values.length];
072        for (int i = 0; i < values.length; i++) {
073            newValues[i] = values[i].copy();
074        }
075        return new DenseSparseMatrix(newValues);
076    }
077
078    @Override
079    public int[] getShape() {
080        return shape;
081    }
082
083    @Override
084    public Tensor reshape(int[] newShape) {
085        throw new UnsupportedOperationException("Reshape not supported on sparse Tensors.");
086    }
087
088    @Override
089    public double get(int i, int j) {
090        return values[i].get(j);
091    }
092
093    @Override
094    public void set(int i, int j, double value) {
095        values[i].set(j,value);
096    }
097
098    @Override
099    public int getDimension1Size() {
100        return dim1;
101    }
102
103    @Override
104    public int getDimension2Size() {
105        return dim2;
106    }
107
108    @Override
109    public DenseVector leftMultiply(SGDVector input) {
110        if (input.size() == dim2) {
111            double[] output = new double[dim1];
112
113            for (int i = 0; i < output.length; i++) {
114                output[i] = values[i].dot(input);
115            }
116
117            return new DenseVector(output);
118        } else {
119            throw new IllegalArgumentException("input.size() != dim2");
120        }
121    }
122
123    /**
124     * rightMultiply is very inefficient on DenseSparseMatrix due to the storage format.
125     * @param input The input vector.
126     * @return A*input.
127     */
128    @Override
129    public DenseVector rightMultiply(SGDVector input) {
130        if (input.size() == dim1) {
131            double[] output = new double[dim2];
132
133            for (int j = 0; j < values.length; j++) {
134                for (int i = 0; i < output.length; i++) {
135                    output[i] = values[j].get(i) * input.get(i);
136                }
137            }
138
139            return new DenseVector(output);
140        } else {
141            throw new IllegalArgumentException("input.size() != dim1");
142        }
143    }
144
145    @Override
146    public void add(int i, int j, double value) {
147        values[i].add(j,value);
148    }
149
150    /**
151     * Only implemented for {@link DenseMatrix}.
152     * @param other The other {@link Tensor}.
153     * @param f A function to apply.
154     */
155    @Override
156    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
157        if (other instanceof Matrix) {
158            Matrix otherMat = (Matrix) other;
159            if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) {
160                if (otherMat instanceof DenseMatrix) {
161                    DenseMatrix otherDenseMat = (DenseMatrix) other;
162                    for (int i = 0; i < dim1; i++) {
163                        values[i].intersectAndAddInPlace(otherDenseMat.getRow(i),f);
164                    }
165                } else {
166                    throw new UnsupportedOperationException("Not implemented intersectAndAddInPlace in DenseSparseMatrix for types other than DenseMatrix");
167                }
168            } else {
169                throw new IllegalArgumentException("Matrices are not the same size, this("+dim1+","+dim2+"), other("+otherMat.getDimension1Size()+","+otherMat.getDimension2Size()+")");
170            }
171        } else {
172            throw new IllegalArgumentException("Adding a non-Matrix to a Matrix");
173        }
174    }
175
176    /**
177     * Only implemented for {@link DenseMatrix}.
178     * @param other The other {@link Tensor}.
179     * @param f A function to apply.
180     */
181    @Override
182    public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f) {
183        if (other instanceof Matrix) {
184            Matrix otherMat = (Matrix) other;
185            if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) {
186                if (otherMat instanceof DenseMatrix) {
187                    DenseMatrix otherDenseMat = (DenseMatrix) other;
188                    for (int i = 0; i < dim1; i++) {
189                        values[i].hadamardProductInPlace(otherDenseMat.getRow(i),f);
190                    }
191                } else {
192                    throw new UnsupportedOperationException("Not implemented hadamardProductInPlace in DenseSparseMatrix for types other than DenseMatrix");
193                }
194            } else {
195                throw new IllegalArgumentException("Matrices are not the same size, this("+dim1+","+dim2+"), other("+otherMat.getDimension1Size()+","+otherMat.getDimension2Size()+")");
196            }
197        } else {
198            throw new IllegalArgumentException("Scaling a Matrix by a non-Matrix");
199        }
200    }
201
202    @Override
203    public void foreachInPlace(DoubleUnaryOperator f) {
204        for (int i = 0; i < values.length; i++) {
205            values[i].foreachInPlace(f);
206        }
207    }
208
209    @Override
210    public int numActiveElements(int row) {
211        return values[row].numActiveElements();
212    }
213
214    @Override
215    public SparseVector getRow(int i) {
216        return values[i];
217    }
218
219    @Override
220    public boolean equals(Object other) {
221        if (other instanceof Matrix) {
222            Iterator<MatrixTuple> ourItr = iterator();
223            Iterator<MatrixTuple> otherItr = ((Matrix)other).iterator();
224            MatrixTuple ourTuple;
225            MatrixTuple otherTuple;
226
227            while (ourItr.hasNext() && otherItr.hasNext()) {
228                ourTuple = ourItr.next();
229                otherTuple = otherItr.next();
230                if (!ourTuple.equals(otherTuple)) {
231                    return false;
232                }
233            }
234
235            // If one of the iterators still has elements then they are not the same.
236            return !(ourItr.hasNext() || otherItr.hasNext());
237        } else {
238            return false;
239        }
240    }
241
242    @Override
243    public int hashCode() {
244        int result = Objects.hash(dim1, dim2);
245        result = 31 * result + Arrays.hashCode(values);
246        return result;
247    }
248
249    @Override
250    public double twoNorm() {
251        double output = 0.0;
252        for (int i = 0; i < dim1; i++) {
253            double value = values[i].twoNorm();
254            output += value * value;
255        }
256        return Math.sqrt(output);
257    }
258
259    @Override
260    public DenseMatrix matrixMultiply(Matrix other) {
261        if (dim2 == other.getDimension1Size()) {
262            if (other instanceof DenseMatrix) {
263                DenseMatrix otherDense = (DenseMatrix) other;
264                double[][] output = new double[dim1][otherDense.dim2];
265
266                for (int i = 0; i < dim1; i++) {
267                    for (int j = 0; j < otherDense.dim2; j++) {
268                        output[i][j] = columnRowDot(i,j,otherDense);
269                    }
270                }
271
272                return new DenseMatrix(output);
273            } else if (other instanceof DenseSparseMatrix) {
274                DenseSparseMatrix otherSparse = (DenseSparseMatrix) other;
275                int otherDim2 = otherSparse.getDimension2Size();
276                double[][] output = new double[dim1][otherDim2];
277
278                for (int i = 0; i < dim1; i++) {
279                    for (int j = 0; j < otherDim2; j++) {
280                        output[i][j] = columnRowDot(i,j,otherSparse);
281                    }
282                }
283
284                return new DenseMatrix(output);
285            } else {
286                throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName());
287            }
288        } else {
289            throw new IllegalArgumentException("Invalid matrix dimensions, this.shape=" + Arrays.toString(shape) + ", other.shape = " + Arrays.toString(other.getShape()));
290        }
291    }
292
293    @Override
294    public DenseMatrix matrixMultiply(Matrix other, boolean transposeThis, boolean transposeOther) {
295        if (transposeThis && transposeOther) {
296            return matrixMultiplyTransposeBoth(other);
297        } else if (transposeThis) {
298            return matrixMultiplyTransposeThis(other);
299        } else if (transposeOther) {
300            return matrixMultiplyTransposeOther(other);
301        } else {
302            return matrixMultiply(other);
303        }
304    }
305
306    private DenseMatrix matrixMultiplyTransposeBoth(Matrix other) {
307        if (dim1 == other.getDimension2Size()) {
308            if (other instanceof DenseMatrix) {
309                DenseMatrix otherDense = (DenseMatrix) other;
310                double[][] output = new double[dim2][otherDense.dim1];
311
312                for (int i = 0; i < dim2; i++) {
313                    for (int j = 0; j < otherDense.dim1; j++) {
314                        output[i][j] = rowColumnDot(i,j,otherDense);
315                    }
316                }
317
318                return new DenseMatrix(output);
319            } else if (other instanceof DenseSparseMatrix) {
320                DenseSparseMatrix otherSparse = (DenseSparseMatrix) other;
321                int otherDim1 = otherSparse.getDimension1Size();
322                double[][] output = new double[dim2][otherDim1];
323
324                for (int i = 0; i < dim2; i++) {
325                    for (int j = 0; j < otherDim1; j++) {
326                        output[i][j] = rowColumnDot(i,j,otherSparse);
327                    }
328                }
329
330                return new DenseMatrix(output);
331            } else {
332                throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName());
333            }
334        } else {
335            throw new IllegalArgumentException("Invalid matrix dimensions, dim1 = " + dim1 + ", other.dim2 = " + other.getDimension2Size());
336        }
337    }
338
339    private DenseMatrix matrixMultiplyTransposeThis(Matrix other) {
340        if (dim1 == other.getDimension1Size()) {
341            if (other instanceof DenseMatrix) {
342                DenseMatrix otherDense = (DenseMatrix) other;
343                double[][] output = new double[dim2][otherDense.dim2];
344
345                for (int i = 0; i < dim2; i++) {
346                    for (int j = 0; j < otherDense.dim2; j++) {
347                        output[i][j] = columnColumnDot(i,j,otherDense);
348                    }
349                }
350
351                return new DenseMatrix(output);
352            } else if (other instanceof DenseSparseMatrix) {
353                DenseSparseMatrix otherSparse = (DenseSparseMatrix) other;
354                int otherDim2 = otherSparse.getDimension2Size();
355                double[][] output = new double[dim2][otherDim2];
356
357                for (int i = 0; i < dim2; i++) {
358                    for (int j = 0; j < otherDim2; j++) {
359                        output[i][j] = columnColumnDot(i,j,otherSparse);
360                    }
361                }
362
363                return new DenseMatrix(output);
364            } else {
365                throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName());
366            }
367        } else {
368            throw new IllegalArgumentException("Invalid matrix dimensions, dim1 = " + dim1 + ", other.dim1 = " + other.getDimension1Size());
369        }
370    }
371
372    private DenseMatrix matrixMultiplyTransposeOther(Matrix other) {
373        if (dim2 == other.getDimension2Size()) {
374            if (other instanceof DenseMatrix) {
375                DenseMatrix otherDense = (DenseMatrix) other;
376                double[][] output = new double[dim1][otherDense.dim1];
377
378                for (int i = 0; i < dim1; i++) {
379                    for (int j = 0; j < otherDense.dim1; j++) {
380                        output[i][j] = rowRowDot(i,j,otherDense);
381                    }
382                }
383
384                return new DenseMatrix(output);
385            } else if (other instanceof DenseSparseMatrix) {
386                DenseSparseMatrix otherSparse = (DenseSparseMatrix) other;
387                int otherDim1 = otherSparse.getDimension1Size();
388                double[][] output = new double[dim1][otherDim1];
389
390                for (int i = 0; i < dim1; i++) {
391                    for (int j = 0; j < otherDim1; j++) {
392                        output[i][j] = rowRowDot(i,j,otherSparse);
393                    }
394                }
395
396                return new DenseMatrix(output);
397            } else {
398                throw new IllegalArgumentException("Unknown matrix type " + other.getClass().getName());
399            }
400        } else {
401            throw new IllegalArgumentException("Invalid matrix dimensions, dim2 = " + dim2 + ", other.dim2 = " + other.getDimension2Size());
402        }
403    }
404
405    private double columnRowDot(int rowIndex, int otherColIndex, Matrix other) {
406        double sum = 0.0;
407        for (VectorTuple tuple : values[rowIndex]) {
408            sum += tuple.value * other.get(tuple.index,otherColIndex);
409        }
410        return sum;
411    }
412
413    private double rowColumnDot(int colIndex, int otherRowIndex, Matrix other) {
414        double sum = 0.0;
415        for (int i = 0; i < dim1; i++) {
416            sum += get(i,colIndex) * other.get(otherRowIndex,i);
417        }
418        return sum;
419    }
420
421    private double columnColumnDot(int colIndex, int otherColIndex, Matrix other) {
422        double sum = 0.0;
423        for (int i = 0; i < dim1; i++) {
424            sum += get(i,colIndex) * other.get(i,otherColIndex);
425        }
426        return sum;
427    }
428
429    private double rowRowDot(int rowIndex, int otherRowIndex, Matrix other) {
430        double sum = 0.0;
431        for (VectorTuple tuple : values[rowIndex]) {
432            sum += tuple.value * other.get(otherRowIndex,tuple.index);
433        }
434        return sum;
435    }
436
437    @Override
438    public DenseVector rowSum() {
439        double[] rowSum = new double[dim1];
440        for (int i = 0; i < dim1; i++) {
441            rowSum[i] = values[i].sum();
442        }
443        return new DenseVector(rowSum);
444    }
445
446    @Override
447    public void rowScaleInPlace(DenseVector scalingCoefficients) {
448        for (int i = 0; i < dim1; i++) {
449            values[i].scaleInPlace(scalingCoefficients.get(i));
450        }
451    }
452
453    @Override
454    public String toString() {
455        StringBuilder buffer = new StringBuilder();
456
457        buffer.append("DenseSparseMatrix(\n");
458        for (int i = 0; i < values.length; i++) {
459            buffer.append("\t");
460            buffer.append(values[i].toString());
461            buffer.append(";\n");
462        }
463        buffer.append(")");
464
465        return buffer.toString();
466    }
467
468    @Override
469    public MatrixIterator iterator() {
470        return new DenseSparseMatrixIterator(this);
471    }
472
473    private static class DenseSparseMatrixIterator implements MatrixIterator {
474        private final DenseSparseMatrix matrix;
475        private final MatrixTuple tuple;
476        private int i;
477        private Iterator<VectorTuple> itr;
478        private VectorTuple vecTuple;
479
480        public DenseSparseMatrixIterator(DenseSparseMatrix matrix) {
481            this.matrix = matrix;
482            this.tuple = new MatrixTuple();
483            this.i = 0;
484            this.itr = matrix.values[0].iterator();
485        }
486
487        @Override
488        public String toString() {
489            return "DenseSparseMatrixIterator(position="+i+",tuple="+ tuple.toString()+")";
490        }
491
492        @Override
493        public MatrixTuple getReference() {
494            return tuple;
495        }
496
497        @Override
498        public boolean hasNext() {
499            if (itr.hasNext()) {
500                return true;
501            } else {
502                while ((i < matrix.dim1) && (!itr.hasNext())) {
503                    i++;
504                    if (i < matrix.dim1) {
505                        itr = matrix.values[i].iterator();
506                    }
507                }
508            }
509            return (i < matrix.dim1) && itr.hasNext();
510        }
511
512        @Override
513        public MatrixTuple next() {
514            vecTuple = itr.next();
515            tuple.i = i;
516            tuple.j = vecTuple.index;
517            tuple.value = vecTuple.value;
518            return tuple;
519        }
520    }
521}