001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.math.optimisers.util; 018 019import org.tribuo.math.la.DenseMatrix; 020import org.tribuo.math.la.DenseVector; 021import org.tribuo.math.la.Matrix; 022import org.tribuo.math.la.MatrixIterator; 023import org.tribuo.math.la.MatrixTuple; 024import org.tribuo.math.la.SGDVector; 025import org.tribuo.math.la.Tensor; 026import org.tribuo.math.la.VectorTuple; 027 028import java.util.function.DoubleUnaryOperator; 029 030/** 031 * A subclass of {@link DenseMatrix} which shrinks the value every time a new value is added. 032 * <p> 033 * Be careful when modifying this or {@link DenseMatrix}. 034 */ 035public class ShrinkingMatrix extends DenseMatrix implements ShrinkingTensor { 036 private final double baseRate; 037 private final double lambdaSqrt; 038 private final boolean scaleShrinking; 039 private final boolean reproject; 040 private double squaredTwoNorm; 041 private int iteration; 042 private double multiplier; 043 044 public ShrinkingMatrix(DenseMatrix v, double baseRate, boolean scaleShrinking) { 045 super(v); 046 this.baseRate = baseRate; 047 this.scaleShrinking = scaleShrinking; 048 this.lambdaSqrt = 0.0; 049 this.reproject = false; 050 this.squaredTwoNorm = 0.0; 051 this.iteration = 1; 052 this.multiplier = 1.0; 053 } 054 055 public ShrinkingMatrix(DenseMatrix v, double baseRate, double lambda) { 056 super(v); 057 this.baseRate = baseRate; 058 this.scaleShrinking = true; 059 this.lambdaSqrt = Math.sqrt(lambda); 060 this.reproject = true; 061 this.squaredTwoNorm = 0.0; 062 this.iteration = 1; 063 this.multiplier = 1.0; 064 } 065 066 @Override 067 public DenseMatrix convertToDense() { 068 return new DenseMatrix(this); 069 } 070 071 @Override 072 public DenseVector leftMultiply(SGDVector input) { 073 if (input.size() == dim2) { 074 double[] output = new double[dim1]; 075 for (VectorTuple tuple : input) { 076 for (int i = 0; i < output.length; i++) { 077 output[i] += get(i, tuple.index) * tuple.value; 078 } 079 } 080 081 return DenseVector.createDenseVector(output); 082 } else { 083 throw new IllegalArgumentException("input.size() != dim2"); 084 } 085 } 086 087 @Override 088 public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) { 089 if (other instanceof Matrix) { 090 Matrix otherMat = (Matrix) other; 091 if ((dim1 == otherMat.getDimension1Size()) && (dim2 == otherMat.getDimension2Size())) { 092 double shrinkage = scaleShrinking ? 1.0 - (baseRate / iteration) : 1.0 - baseRate; 093 scaleInPlace(shrinkage); 094 for (MatrixTuple tuple : otherMat) { 095 double update = f.applyAsDouble(tuple.value); 096 double oldValue = values[tuple.i][tuple.j] * multiplier; 097 double newValue = oldValue + update; 098 squaredTwoNorm -= oldValue * oldValue; 099 squaredTwoNorm += newValue * newValue; 100 values[tuple.i][tuple.j] = newValue / multiplier; 101 } 102 if (reproject) { 103 double projectionNormaliser = (1.0 / lambdaSqrt) / twoNorm(); 104 if (projectionNormaliser < 1.0) { 105 scaleInPlace(projectionNormaliser); 106 } 107 } 108 iteration++; 109 } else { 110 throw new IllegalStateException("Matrices are not the same size, this(" + dim1 + "," + dim2 + "), other(" + otherMat.getDimension1Size() + "," + otherMat.getDimension2Size() + ")"); 111 } 112 } else { 113 throw new IllegalStateException("Adding a non-Matrix to a Matrix"); 114 } 115 } 116 117 @Override 118 public double get(int i, int j) { 119 return values[i][j] * multiplier; 120 } 121 122 @Override 123 public void scaleInPlace(double value) { 124 multiplier *= value; 125 if (Math.abs(multiplier) < tolerance) { 126 reifyMultiplier(); 127 } 128 } 129 130 private void reifyMultiplier() { 131 for (int i = 0; i < dim1; i++) { 132 for (int j = 0; j < dim2; j++) { 133 values[i][j] *= multiplier; 134 } 135 } 136 multiplier = 1.0; 137 } 138 139 @Override 140 public double twoNorm() { 141 return Math.sqrt(squaredTwoNorm); 142 } 143 144 @Override 145 public MatrixIterator iterator() { 146 return new ShrinkingMatrixIterator(this); 147 } 148 149 private class ShrinkingMatrixIterator implements MatrixIterator { 150 private final ShrinkingMatrix matrix; 151 private final MatrixTuple tuple; 152 private int i; 153 private int j; 154 155 public ShrinkingMatrixIterator(ShrinkingMatrix matrix) { 156 this.matrix = matrix; 157 this.tuple = new MatrixTuple(); 158 this.i = 0; 159 this.j = 0; 160 } 161 162 @Override 163 public MatrixTuple getReference() { 164 return tuple; 165 } 166 167 @Override 168 public boolean hasNext() { 169 return (i < matrix.dim1) && (j < matrix.dim2); 170 } 171 172 @Override 173 public MatrixTuple next() { 174 tuple.i = i; 175 tuple.j = j; 176 tuple.value = matrix.get(i, j); 177 if (j < dim2 - 1) { 178 j++; 179 } else { 180 //Reached end of current vector, get next one 181 i++; 182 j = 0; 183 } 184 return tuple; 185 } 186 } 187 188} 189