001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.optimisers.util;
018
019import org.tribuo.math.la.DenseVector;
020import org.tribuo.math.la.SGDVector;
021import org.tribuo.math.la.Tensor;
022import org.tribuo.math.la.VectorIterator;
023import org.tribuo.math.la.VectorTuple;
024
025import java.util.Arrays;
026import java.util.function.DoubleUnaryOperator;
027
028/**
029 * A subclass of {@link DenseVector} which shrinks the value every time a new value is added.
030 * <p>
031 * Be careful when modifying this or {@link DenseVector}.
032 */
033public class ShrinkingVector extends DenseVector implements ShrinkingTensor {
034    private final double baseRate;
035    private final boolean scaleShrinking;
036    private final double lambdaSqrt;
037    private final boolean reproject;
038    private double squaredTwoNorm;
039    private int iteration;
040    private double multiplier;
041
042    public ShrinkingVector(DenseVector v, double baseRate, boolean scaleShrinking) {
043        super(v);
044        this.baseRate = baseRate;
045        this.scaleShrinking = scaleShrinking;
046        this.lambdaSqrt = 0.0;
047        this.reproject = false;
048        this.iteration = 1;
049        this.multiplier = 1.0;
050    }
051
052    public ShrinkingVector(DenseVector v, double baseRate, double lambda) {
053        super(v);
054        this.baseRate = baseRate;
055        this.scaleShrinking = true;
056        this.lambdaSqrt = Math.sqrt(lambda);
057        this.reproject = true;
058        this.squaredTwoNorm = 0.0;
059        this.iteration = 1;
060        this.multiplier = 1.0;
061    }
062
063    private ShrinkingVector(double[] values, double baseRate, boolean scaleShrinking, double lambdaSqrt, boolean reproject, double squaredTwoNorm, int iteration, double multiplier) {
064        super(values);
065        this.baseRate = baseRate;
066        this.scaleShrinking = scaleShrinking;
067        this.lambdaSqrt = lambdaSqrt;
068        this.reproject = reproject;
069        this.squaredTwoNorm = squaredTwoNorm;
070        this.iteration = iteration;
071        this.multiplier = multiplier;
072    }
073
074    @Override
075    public DenseVector convertToDense() {
076        return DenseVector.createDenseVector(toArray());
077    }
078
079    @Override
080    public ShrinkingVector copy() {
081        return new ShrinkingVector(Arrays.copyOf(elements,elements.length),baseRate,scaleShrinking,lambdaSqrt,reproject,squaredTwoNorm,iteration,multiplier);
082    }
083
084    @Override
085    public double[] toArray() {
086        double[] newValues = new double[elements.length];
087        for (int i = 0; i < newValues.length; i++) {
088            newValues[i] = get(i);
089        }
090        return newValues;
091    }
092
093    @Override
094    public double get(int index) {
095        return elements[index] * multiplier;
096    }
097
098    @Override
099    public double sum() {
100        double sum = 0.0;
101        for (int i = 0; i < elements.length; i++) {
102            sum += get(i);
103        }
104        return sum;
105    }
106
107    @Override
108    public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
109        double shrinkage = scaleShrinking ? 1.0 - (baseRate / iteration) : 1.0 - baseRate;
110        scaleInPlace(shrinkage);
111        SGDVector otherVec = (SGDVector) other;
112        for (VectorTuple tuple : otherVec) {
113            double update = f.applyAsDouble(tuple.value);
114            double oldValue = elements[tuple.index] * multiplier;
115            double newValue = oldValue + update;
116            squaredTwoNorm -= oldValue * oldValue;
117            squaredTwoNorm += newValue * newValue;
118            elements[tuple.index] = newValue / multiplier;
119        }
120        if (reproject) {
121            double projectionNormaliser = (1.0 / lambdaSqrt) / twoNorm();
122            if (projectionNormaliser < 1.0) {
123                scaleInPlace(projectionNormaliser);
124            }
125        }
126        iteration++;
127    }
128
129    @Override
130    public int indexOfMax() {
131        int index = 0;
132        double value = Double.NEGATIVE_INFINITY;
133        for (int i = 0; i < elements.length; i++) {
134            double tmp = get(i);
135            if (tmp > value) {
136                index = i;
137                value = tmp;
138            }
139        }
140        return index;
141    }
142
143    @Override
144    public double dot(SGDVector other) {
145        double score = 0.0;
146
147        for (VectorTuple tuple : other) {
148            score += get(tuple.index) * tuple.value;
149        }
150
151        return score;
152    }
153
154    @Override
155    public void scaleInPlace(double value) {
156        multiplier *= value;
157        if (Math.abs(multiplier) < tolerance) {
158            reifyMultiplier();
159        }
160    }
161
162    private void reifyMultiplier() {
163        for (int i = 0; i < elements.length; i++) {
164            elements[i] *= multiplier;
165        }
166        multiplier = 1.0;
167    }
168
169    @Override
170    public double twoNorm() {
171        return Math.sqrt(squaredTwoNorm);
172    }
173
174    @Override
175    public double maxValue() {
176        return multiplier * super.maxValue();
177    }
178
179    @Override
180    public double minValue() {
181        return multiplier * super.minValue();
182    }
183
184    @Override
185    public VectorIterator iterator() {
186        return new ShrinkingVectorIterator(this);
187    }
188
189    private static class ShrinkingVectorIterator implements VectorIterator {
190        private final ShrinkingVector vector;
191        private final VectorTuple tuple;
192        private int index;
193
194        public ShrinkingVectorIterator(ShrinkingVector vector) {
195            this.vector = vector;
196            this.tuple = new VectorTuple();
197            this.index = 0;
198        }
199
200        @Override
201        public boolean hasNext() {
202            return index < vector.size();
203        }
204
205        @Override
206        public VectorTuple next() {
207            tuple.index = index;
208            tuple.value = vector.get(index);
209            index++;
210            return tuple;
211        }
212
213        @Override
214        public VectorTuple getReference() {
215            return tuple;
216        }
217    }
218}
219
220