001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.math.optimisers.util; 018 019import org.tribuo.math.la.DenseVector; 020import org.tribuo.math.la.SGDVector; 021import org.tribuo.math.la.Tensor; 022import org.tribuo.math.la.VectorIterator; 023import org.tribuo.math.la.VectorTuple; 024 025import java.util.Arrays; 026import java.util.function.DoubleUnaryOperator; 027 028/** 029 * A subclass of {@link DenseVector} which shrinks the value every time a new value is added. 030 * <p> 031 * Be careful when modifying this or {@link DenseVector}. 032 */ 033public class ShrinkingVector extends DenseVector implements ShrinkingTensor { 034 private final double baseRate; 035 private final boolean scaleShrinking; 036 private final double lambdaSqrt; 037 private final boolean reproject; 038 private double squaredTwoNorm; 039 private int iteration; 040 private double multiplier; 041 042 public ShrinkingVector(DenseVector v, double baseRate, boolean scaleShrinking) { 043 super(v); 044 this.baseRate = baseRate; 045 this.scaleShrinking = scaleShrinking; 046 this.lambdaSqrt = 0.0; 047 this.reproject = false; 048 this.iteration = 1; 049 this.multiplier = 1.0; 050 } 051 052 public ShrinkingVector(DenseVector v, double baseRate, double lambda) { 053 super(v); 054 this.baseRate = baseRate; 055 this.scaleShrinking = true; 056 this.lambdaSqrt = Math.sqrt(lambda); 057 this.reproject = true; 058 this.squaredTwoNorm = 0.0; 059 this.iteration = 1; 060 this.multiplier = 1.0; 061 } 062 063 private ShrinkingVector(double[] values, double baseRate, boolean scaleShrinking, double lambdaSqrt, boolean reproject, double squaredTwoNorm, int iteration, double multiplier) { 064 super(values); 065 this.baseRate = baseRate; 066 this.scaleShrinking = scaleShrinking; 067 this.lambdaSqrt = lambdaSqrt; 068 this.reproject = reproject; 069 this.squaredTwoNorm = squaredTwoNorm; 070 this.iteration = iteration; 071 this.multiplier = multiplier; 072 } 073 074 @Override 075 public DenseVector convertToDense() { 076 return DenseVector.createDenseVector(toArray()); 077 } 078 079 @Override 080 public ShrinkingVector copy() { 081 return new ShrinkingVector(Arrays.copyOf(elements,elements.length),baseRate,scaleShrinking,lambdaSqrt,reproject,squaredTwoNorm,iteration,multiplier); 082 } 083 084 @Override 085 public double[] toArray() { 086 double[] newValues = new double[elements.length]; 087 for (int i = 0; i < newValues.length; i++) { 088 newValues[i] = get(i); 089 } 090 return newValues; 091 } 092 093 @Override 094 public double get(int index) { 095 return elements[index] * multiplier; 096 } 097 098 @Override 099 public double sum() { 100 double sum = 0.0; 101 for (int i = 0; i < elements.length; i++) { 102 sum += get(i); 103 } 104 return sum; 105 } 106 107 @Override 108 public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) { 109 double shrinkage = scaleShrinking ? 1.0 - (baseRate / iteration) : 1.0 - baseRate; 110 scaleInPlace(shrinkage); 111 SGDVector otherVec = (SGDVector) other; 112 for (VectorTuple tuple : otherVec) { 113 double update = f.applyAsDouble(tuple.value); 114 double oldValue = elements[tuple.index] * multiplier; 115 double newValue = oldValue + update; 116 squaredTwoNorm -= oldValue * oldValue; 117 squaredTwoNorm += newValue * newValue; 118 elements[tuple.index] = newValue / multiplier; 119 } 120 if (reproject) { 121 double projectionNormaliser = (1.0 / lambdaSqrt) / twoNorm(); 122 if (projectionNormaliser < 1.0) { 123 scaleInPlace(projectionNormaliser); 124 } 125 } 126 iteration++; 127 } 128 129 @Override 130 public int indexOfMax() { 131 int index = 0; 132 double value = Double.NEGATIVE_INFINITY; 133 for (int i = 0; i < elements.length; i++) { 134 double tmp = get(i); 135 if (tmp > value) { 136 index = i; 137 value = tmp; 138 } 139 } 140 return index; 141 } 142 143 @Override 144 public double dot(SGDVector other) { 145 double score = 0.0; 146 147 for (VectorTuple tuple : other) { 148 score += get(tuple.index) * tuple.value; 149 } 150 151 return score; 152 } 153 154 @Override 155 public void scaleInPlace(double value) { 156 multiplier *= value; 157 if (Math.abs(multiplier) < tolerance) { 158 reifyMultiplier(); 159 } 160 } 161 162 private void reifyMultiplier() { 163 for (int i = 0; i < elements.length; i++) { 164 elements[i] *= multiplier; 165 } 166 multiplier = 1.0; 167 } 168 169 @Override 170 public double twoNorm() { 171 return Math.sqrt(squaredTwoNorm); 172 } 173 174 @Override 175 public double maxValue() { 176 return multiplier * super.maxValue(); 177 } 178 179 @Override 180 public double minValue() { 181 return multiplier * super.minValue(); 182 } 183 184 @Override 185 public VectorIterator iterator() { 186 return new ShrinkingVectorIterator(this); 187 } 188 189 private static class ShrinkingVectorIterator implements VectorIterator { 190 private final ShrinkingVector vector; 191 private final VectorTuple tuple; 192 private int index; 193 194 public ShrinkingVectorIterator(ShrinkingVector vector) { 195 this.vector = vector; 196 this.tuple = new VectorTuple(); 197 this.index = 0; 198 } 199 200 @Override 201 public boolean hasNext() { 202 return index < vector.size(); 203 } 204 205 @Override 206 public VectorTuple next() { 207 tuple.index = index; 208 tuple.value = vector.get(index); 209 index++; 210 return tuple; 211 } 212 213 @Override 214 public VectorTuple getReference() { 215 return tuple; 216 } 217 } 218} 219 220