001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.math.optimisers; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 022import org.tribuo.math.Parameters; 023import org.tribuo.math.StochasticGradientOptimiser; 024import org.tribuo.math.la.DenseMatrix; 025import org.tribuo.math.la.DenseVector; 026import org.tribuo.math.la.Tensor; 027import org.tribuo.math.optimisers.util.ShrinkingMatrix; 028import org.tribuo.math.optimisers.util.ShrinkingTensor; 029import org.tribuo.math.optimisers.util.ShrinkingVector; 030 031/** 032 * An implementation of the Pegasos gradient optimiser used primarily for solving the SVM problem. 033 * <p> 034 * This gradient optimiser rewrites all the {@link Tensor}s in the {@link Parameters} 035 * with {@link ShrinkingTensor}. This means it keeps a different value in the {@link Tensor} 036 * to the one produced when you call get(), so it can correctly apply regularisation to the parameters. 037 * When {@link Pegasos#finalise()} is called it rewrites the {@link Parameters} with standard dense {@link Tensor}s. 038 * Follows the implementation in Factorie. 039 * <p> 040 * Pegasos is remarkably touchy about it's learning rates. The defaults work on a couple of examples, but it 041 * requires tuning to work properly on a specific dataset. 042 * <p> 043 * See: 044 * <pre> 045 * Shalev-Shwartz S, Singer Y, Srebro N, Cotter A 046 * "Pegasos: Primal Estimated Sub-Gradient Solver for SVM" 047 * Mathematical Programming, 2011. 048 * </pre> 049 */ 050public class Pegasos implements StochasticGradientOptimiser { 051 052 @Config(description="Step size shrinkage.") 053 private double lambda = 1e-2; 054 055 @Config(description="Base learning rate.") 056 private double baseRate = 0.1; 057 058 private int iteration = 1; 059 private Parameters parameters; 060 061 /** 062 * Added for olcut configuration. 063 */ 064 private Pegasos() { } 065 066 public Pegasos(double baseRate, double lambda) { 067 this.baseRate = baseRate; 068 this.lambda = lambda; 069 } 070 071 @Override 072 public void initialise(Parameters parameters) { 073 this.parameters = parameters; 074 Tensor[] curParams = parameters.get(); 075 Tensor[] newParams = new Tensor[curParams.length]; 076 for (int i = 0; i < newParams.length; i++) { 077 if (curParams[i] instanceof DenseVector) { 078 newParams[i] = new ShrinkingVector(((DenseVector) curParams[i]), baseRate, lambda); 079 } else if (curParams[i] instanceof DenseMatrix) { 080 newParams[i] = new ShrinkingMatrix(((DenseMatrix) curParams[i]), baseRate, lambda); 081 } else { 082 throw new IllegalStateException("Unknown Tensor subclass"); 083 } 084 } 085 parameters.set(newParams); 086 } 087 088 @Override 089 public Tensor[] step(Tensor[] updates, double weight) { 090 double eta_t = baseRate / (lambda * iteration); 091 for (Tensor t : updates) { 092 t.scaleInPlace(eta_t * weight); 093 } 094 iteration++; 095 return updates; 096 } 097 098 @Override 099 public String toString() { 100 return "Pegasos(baseRate=" + baseRate + ",lambda=" + lambda + ")"; 101 } 102 103 @Override 104 public void finalise() { 105 Tensor[] curParams = parameters.get(); 106 Tensor[] newParams = new Tensor[curParams.length]; 107 for (int i = 0; i < newParams.length; i++) { 108 if (curParams[i] instanceof ShrinkingTensor) { 109 newParams[i] = ((ShrinkingTensor) curParams[i]).convertToDense(); 110 } else { 111 throw new IllegalStateException("Finalising a Parameters which wasn't initialised with Pegasos"); 112 } 113 } 114 parameters.set(newParams); 115 } 116 117 @Override 118 public void reset() { 119 iteration = 1; 120 } 121 122 @Override 123 public Pegasos copy() { 124 return new Pegasos(lambda,baseRate); 125 } 126 127 @Override 128 public ConfiguredObjectProvenance getProvenance() { 129 return new ConfiguredObjectProvenanceImpl(this,"StochasticGradientOptimiser"); 130 } 131} 132