001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.optimisers.util;
018
019import org.tribuo.math.la.DenseMatrix;
020import org.tribuo.math.la.DenseVector;
021import org.tribuo.math.la.Matrix;
022import org.tribuo.math.la.MatrixIterator;
023import org.tribuo.math.la.MatrixTuple;
024import org.tribuo.math.la.SGDVector;
025import org.tribuo.math.la.Tensor;
026import org.tribuo.math.la.VectorTuple;
027
028import java.util.function.DoubleUnaryOperator;
029
030/**
031 * A subclass of {@link DenseMatrix} which shrinks the value every time a new value is added.
032 * <p>
033 * Be careful when modifying this or {@link DenseMatrix}.
034 */
035public class ShrinkingMatrix extends DenseMatrix implements ShrinkingTensor {
036    private final double baseRate;
037    private final double lambdaSqrt;
038    private final boolean scaleShrinking;
039    private final boolean reproject;
040    private double squaredTwoNorm;
041    private int iteration;
042    private double multiplier;
043
044    public ShrinkingMatrix(DenseMatrix v, double baseRate, boolean scaleShrinking) {
045        super(v);
046        this.baseRate = baseRate;
047        this.scaleShrinking = scaleShrinking;
048        this.lambdaSqrt = 0.0;
049        this.reproject = false;
050        this.squaredTwoNorm = 0.0;
051        this.iteration = 1;
052        this.multiplier = 1.0;
053    }
054
055    public ShrinkingMatrix(DenseMatrix v, double baseRate, double lambda) {
056        super(v);
057        this.baseRate = baseRate;
058        this.scaleShrinking = true;
059        this.lambdaSqrt = Math.sqrt(lambda);
060        this.reproject = true;
061        this.squaredTwoNorm = 0.0;
062        this.iteration = 1;
063        this.multiplier = 1.0;
064    }
065
066    @Override
067    public DenseMatrix convertToDense() {
068        return new DenseMatrix(this);
069    }
070
071    @Override
072    public DenseVector leftMultiply(SGDVector input) {
073        if (input.size() == dim2) {
074            double[] output = new double[dim1];
075            for (VectorTuple tuple : input) {
076                for (int i = 0; i < output.length; i++) {
077                    output[i] += get(i, tuple.index) * tuple.value;
078                }
079            }
080
081            return DenseVector.createDenseVector(output);
082        } else {
083            throw new IllegalArgumentException("input.size() != dim2");
084        }
085    }
086
087    @Override
088    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
089        if (other instanceof Matrix) {
090            Matrix otherMat = (Matrix) other;
091            if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) {
092                double shrinkage = scaleShrinking ? 1.0 - (baseRate / iteration) : 1.0 - baseRate;
093                scaleInPlace(shrinkage);
094                for (MatrixTuple tuple : otherMat) {
095                    double update = f.applyAsDouble(tuple.value);
096                    double oldValue = values[tuple.i][tuple.j] * multiplier;
097                    double newValue = oldValue + update;
098                    squaredTwoNorm -= oldValue * oldValue;
099                    squaredTwoNorm += newValue * newValue;
100                    values[tuple.i][tuple.j] = newValue / multiplier;
101                }
102                if (reproject) {
103                    double projectionNormaliser = (1.0 / lambdaSqrt) / twoNorm();
104                    if (projectionNormaliser < 1.0) {
105                        scaleInPlace(projectionNormaliser);
106                    }
107                }
108                iteration++;
109            } else {
110                throw new IllegalStateException("Matrices are not the same size, this(" + dim1 + "," + dim2 + "), other(" + otherMat.getDimension1Size() + "," + otherMat.getDimension2Size() + ")");
111            }
112        } else {
113            throw new IllegalStateException("Adding a non-Matrix to a Matrix");
114        }
115    }
116
117    @Override
118    public double get(int i, int j) {
119        return values[i][j] * multiplier;
120    }
121
122    @Override
123    public void scaleInPlace(double value) {
124        multiplier *= value;
125        if (Math.abs(multiplier) < tolerance) {
126            reifyMultiplier();
127        }
128    }
129
130    private void reifyMultiplier() {
131        for (int i = 0; i < dim1; i++) {
132            for (int j = 0; j < dim2; j++) {
133                values[i][j] *= multiplier;
134            }
135        }
136        multiplier = 1.0;
137    }
138
139    @Override
140    public double twoNorm() {
141        return Math.sqrt(squaredTwoNorm);
142    }
143
144    @Override
145    public MatrixIterator iterator() {
146        return new ShrinkingMatrixIterator(this);
147    }
148
149    private class ShrinkingMatrixIterator implements MatrixIterator {
150        private final ShrinkingMatrix matrix;
151        private final MatrixTuple tuple;
152        private int i;
153        private int j;
154
155        public ShrinkingMatrixIterator(ShrinkingMatrix matrix) {
156            this.matrix = matrix;
157            this.tuple = new MatrixTuple();
158            this.i = 0;
159            this.j = 0;
160        }
161
162        @Override
163        public MatrixTuple getReference() {
164            return tuple;
165        }
166
167        @Override
168        public boolean hasNext() {
169            return (i < matrix.dim1) && (j < matrix.dim2);
170        }
171
172        @Override
173        public MatrixTuple next() {
174            tuple.i = i;
175            tuple.j = j;
176            tuple.value = matrix.get(i, j);
177            if (j < dim2 - 1) {
178                j++;
179            } else {
180                //Reached end of current vector, get next one
181                i++;
182                j = 0;
183            }
184            return tuple;
185        }
186    }
187
188}
189