001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.math.optimisers;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
022import org.tribuo.math.Parameters;
023import org.tribuo.math.StochasticGradientOptimiser;
024import org.tribuo.math.la.DenseMatrix;
025import org.tribuo.math.la.DenseVector;
026import org.tribuo.math.la.Tensor;
027import org.tribuo.math.optimisers.util.ShrinkingMatrix;
028import org.tribuo.math.optimisers.util.ShrinkingTensor;
029import org.tribuo.math.optimisers.util.ShrinkingVector;
030
031/**
032 * An implementation of the Pegasos gradient optimiser used primarily for solving the SVM problem.
033 * <p>
034 * This gradient optimiser rewrites all the {@link Tensor}s in the {@link Parameters}
035 * with {@link ShrinkingTensor}. This means it keeps a different value in the {@link Tensor}
036 * to the one produced when you call get(), so it can correctly apply regularisation to the parameters.
037 * When {@link Pegasos#finalise()} is called it rewrites the {@link Parameters} with standard dense {@link Tensor}s.
038 * Follows the implementation in Factorie.
039 * <p>
 * Pegasos is remarkably touchy about its learning rates. The defaults work on a couple of examples, but it
041 * requires tuning to work properly on a specific dataset.
042 * <p>
043 * See:
044 * <pre>
045 * Shalev-Shwartz S, Singer Y, Srebro N, Cotter A
046 * "Pegasos: Primal Estimated Sub-Gradient Solver for SVM"
047 * Mathematical Programming, 2011.
048 * </pre>
049 */
050public class Pegasos implements StochasticGradientOptimiser {
051
052    @Config(description="Step size shrinkage.")
053    private double lambda = 1e-2;
054
055    @Config(description="Base learning rate.")
056    private double baseRate = 0.1;
057
058    private int iteration = 1;
059    private Parameters parameters;
060
061    /**
062     * Added for olcut configuration.
063     */
064    private Pegasos() { }
065
066    public Pegasos(double baseRate, double lambda) {
067        this.baseRate = baseRate;
068        this.lambda = lambda;
069    }
070
071    @Override
072    public void initialise(Parameters parameters) {
073        this.parameters = parameters;
074        Tensor[] curParams = parameters.get();
075        Tensor[] newParams = new Tensor[curParams.length];
076        for (int i = 0; i < newParams.length; i++) {
077            if (curParams[i] instanceof DenseVector) {
078                newParams[i] = new ShrinkingVector(((DenseVector) curParams[i]), baseRate, lambda);
079            } else if (curParams[i] instanceof DenseMatrix) {
080                newParams[i] = new ShrinkingMatrix(((DenseMatrix) curParams[i]), baseRate, lambda);
081            } else {
082                throw new IllegalStateException("Unknown Tensor subclass");
083            }
084        }
085        parameters.set(newParams);
086    }
087
088    @Override
089    public Tensor[] step(Tensor[] updates, double weight) {
090        double eta_t = baseRate / (lambda * iteration);
091        for (Tensor t : updates) {
092            t.scaleInPlace(eta_t * weight);
093        }
094        iteration++;
095        return updates;
096    }
097
098    @Override
099    public String toString() {
100        return "Pegasos(baseRate=" + baseRate + ",lambda=" + lambda + ")";
101    }
102
103    @Override
104    public void finalise() {
105        Tensor[] curParams = parameters.get();
106        Tensor[] newParams = new Tensor[curParams.length];
107        for (int i = 0; i < newParams.length; i++) {
108            if (curParams[i] instanceof ShrinkingTensor) {
109                newParams[i] = ((ShrinkingTensor) curParams[i]).convertToDense();
110            } else {
111                throw new IllegalStateException("Finalising a Parameters which wasn't initialised with Pegasos");
112            }
113        }
114        parameters.set(newParams);
115    }
116
117    @Override
118    public void reset() {
119        iteration = 1;
120    }
121
122    @Override
123    public Pegasos copy() {
124        return new Pegasos(lambda,baseRate);
125    }
126
127    @Override
128    public ConfiguredObjectProvenance getProvenance() {
129        return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser");
130    }
131}
132