001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.libsvm;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import libsvm.svm_model;
023import libsvm.svm_node;
024import libsvm.svm_parameter;
025import org.tribuo.Dataset;
026import org.tribuo.Example;
027import org.tribuo.Feature;
028import org.tribuo.ImmutableFeatureMap;
029import org.tribuo.ImmutableOutputInfo;
030import org.tribuo.Output;
031import org.tribuo.Trainer;
032import org.tribuo.provenance.ModelProvenance;
033import org.tribuo.provenance.TrainerProvenance;
034import org.tribuo.provenance.impl.TrainerProvenanceImpl;
035import org.tribuo.util.Util;
036
037import java.time.OffsetDateTime;
038import java.util.ArrayList;
039import java.util.Arrays;
040import java.util.Collections;
041import java.util.List;
042import java.util.Map;
043import java.util.logging.Logger;
044
045/**
046 * A trainer that will train using libsvm's Java implementation.
047 * <p>
048 * See:
049 * <pre>
050 * Chang CC, Lin CJ.
051 * "LIBSVM: a library for Support Vector Machines"
052 * ACM transactions on intelligent systems and technology (TIST), 2011.
053 * </pre>
054 * for the nu-svm algorithm:
055 * <pre>
056 * Schölkopf B, Smola A, Williamson R, Bartlett P L.
057 * "New support vector algorithms"
058 * Neural Computation, 2000, 1207-1245.
059 * </pre>
060 * and for the original algorithm:
061 * <pre>
062 * Cortes C, Vapnik V.
063 * "Support-Vector Networks"
064 * Machine Learning, 1995.
065 * </pre>
066 */
067public abstract class LibSVMTrainer<T extends Output<T>> implements Trainer<T> {
068    
069    private static final Logger logger = Logger.getLogger(LibSVMTrainer.class.getName());
070
071    /**
072     * The SVM parameters suitable for use by LibSVM.
073     */
074    protected svm_parameter parameters;
075
076    /**
077     * The type of SVM algorithm.
078     */
079    @Config(mandatory=true,description="Type of SVM algorithm.")
080    protected SVMType<T> svmType;
081
082    @Config(description="Type of Kernel.")
083    private KernelType kernelType = KernelType.LINEAR;
084
085    @Config(description="Polynomial degree.")
086    private int degree = 3;
087
088    @Config(description="Width of the RBF kernel, or scalar on sigmoid kernel.")
089    private double gamma = 0.0;
090
091    @Config(description="Polynomial coefficient or shift in sigmoid kernel.")
092    private double coef0 = 0.0;
093
094    @Config(description="nu value in NU SVM.")
095    private double nu = 0.5;
096
097    @Config(description="Internal cache size, most of the time should be left at default.")
098    private double cache_size = 500;
099
100    @Config(description="Cost parameter for incorrect predictions.")
101    private double cost = 1.0; // aka svm_parameters.C
102
103    @Config(description="Tolerance of the termination criterion.")
104    private double eps = 1e-3;
105
106    @Config(description="Epsilon in EPSILON_SVR.")
107    private double p = 1e-3;
108
109    @Config(description="Regularise the weight parameters.")
110    private boolean shrinking = true;
111
112    @Config(description="Generate probability estimates.")
113    private boolean probability = false;
114
115    private int trainInvocationCounter = 0;
116
117    /**
118     * For olcut.
119     */
120    protected LibSVMTrainer() {}
121
122    /**
123     * Constructs a LibSVMTrainer from the parameters.
124     * @param parameters The SVM parameters.
125     */
126    protected LibSVMTrainer(SVMParameters<T> parameters) {
127        this.parameters = parameters.getParameters();
128        // Unpack the parameters for the provenance system.
129        this.svmType = parameters.getSvmType();
130        this.kernelType = parameters.getKernelType();
131        this.degree = this.parameters.degree;
132        this.gamma = parameters.getGamma();
133        this.coef0 = this.parameters.coef0;
134        this.nu = this.parameters.nu;
135        this.cache_size = this.parameters.cache_size;
136        this.cost = this.parameters.C;
137        this.eps = this.parameters.eps;
138        this.p = this.parameters.p;
139        this.shrinking = this.parameters.shrinking == 1;
140        this.probability = this.parameters.probability == 1;
141    }
142
143    /**
144     * Used by the OLCUT configuration system, and should not be called by external code.
145     */
146    @Override
147    public void postConfig() {
148        parameters = new svm_parameter();
149        parameters.svm_type = svmType.getNativeType();
150        parameters.kernel_type = kernelType.getNativeType();
151        parameters.degree = degree;
152        parameters.gamma = gamma;
153        parameters.coef0 = coef0;
154        parameters.nu = nu;
155        parameters.cache_size = cache_size;
156        parameters.C = cost;
157        parameters.eps = eps;
158        parameters.p = p;
159        parameters.shrinking = shrinking ? 1 : 0;
160        parameters.probability = probability ? 1 : 0;
161    }
162
163    @Override
164    public String toString() {
165        StringBuilder buffer = new StringBuilder();
166
167        buffer.append("LibSVMTrainer(");
168        buffer.append("svm_params=");
169        buffer.append(SVMParameters.svmParamsToString(parameters));
170        buffer.append(")");
171
172        return buffer.toString();
173    }
174
175    @Override
176    public LibSVMModel<T> train(Dataset<T> examples) {
177        return train(examples, Collections.emptyMap());
178    }
179
180    @Override
181    public LibSVMModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
182        if (examples.getOutputInfo().getUnknownCount() > 0) {
183            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
184        }
185        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
186        ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();
187
188        TrainerProvenance trainerProvenance = getProvenance();
189        ModelProvenance provenance = new ModelProvenance(LibSVMModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
190        trainInvocationCounter++;
191
192        svm_parameter curParams = setupParameters(outputIDInfo);
193
194        Pair<svm_node[][],double[][]> data = extractData(examples,outputIDInfo,featureIDMap);
195
196        List<svm_model> models = trainModels(curParams,featureIDMap.size()+1,data.getA(),data.getB());
197
198        return createModel(provenance,featureIDMap,outputIDInfo,models);
199    }
200
201    /**
202     * Construct the appropriate subtype of LibSVMModel for the prediction task.
203     * @param provenance The model provenance.
204     * @param featureIDMap The feature id map.
205     * @param outputIDInfo The output id info.
206     * @param models The svm models.
207     * @return An implementation of LibSVMModel.
208     */
209    protected abstract LibSVMModel<T> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<svm_model> models);
210
211    /**
212     * Train all the liblinear instances necessary for this dataset.
213     * @param curParams The LibLinear parameters.
214     * @param numFeatures The number of features in this dataset.
215     * @param features The features themselves.
216     * @param outputs The outputs.
217     * @return A list of liblinear models.
218     */
219    protected abstract List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs);
220
221    /**
222     * Extracts the features and {@link Output}s in LibLinear's format.
223     * @param data The input data.
224     * @param outputInfo The output info.
225     * @param featureMap The feature info.
226     * @return The features and outputs.
227     */
228    protected abstract Pair<svm_node[][],double[][]> extractData(Dataset<T> data, ImmutableOutputInfo<T> outputInfo, ImmutableFeatureMap featureMap);
229
230    /**
231     * Constructs the svm_parameter. Most of the time this is a no-op, but
232     * classification overrides it to incorporate label weights if they exist.
233     * @param info The output info.
234     * @return The svm_parameters to use for training.
235     */
236    protected svm_parameter setupParameters(ImmutableOutputInfo<T> info) {
237        return SVMParameters.copyParameters(parameters);
238    }
239
240    @Override
241    public int getInvocationCount() {
242        return trainInvocationCounter;
243    }
244
245    /**
246     * Convert the example into an array of svm_node which represents a sparse feature vector.
247     * <p>
248     * If there are collisions in the feature ids then the values are summed.
249     * @param example The example to convert.
250     * @param featureIDMap The feature id map which holds the indices.
251     * @param features A buffer to use.
252     * @param <T> The type of the ouput.
253     * @return A sparse feature vector.
254     */
255    public static <T extends Output<T>> svm_node[] exampleToNodes(Example<T> example, ImmutableFeatureMap featureIDMap, List<svm_node> features) {
256        if (features == null) {
257            features = new ArrayList<>();
258        }
259        features.clear();
260        int prevIdx = -1;
261        for (Feature f : example) {
262            int id = featureIDMap.getID(f.getName());
263            double value = f.getValue();
264            if (id > prevIdx){
265                prevIdx = id;
266                svm_node n = new svm_node();
267                n.index = id;
268                n.value = value;
269                features.add(n);
270            } else if (id > -1) {
271                //
272                // Collision, deal with it.
273                int collisionIdx = Util.binarySearch(features,id,(svm_node n) -> n.index);
274                if (collisionIdx < 0) {
275                    //
276                    // Collision but not present in features
277                    // move data and bump i
278                    collisionIdx = - (collisionIdx + 1);
279                    svm_node n = new svm_node();
280                    n.index = id;
281                    n.value = value;
282                    features.add(collisionIdx,n);
283                } else {
284                    //
285                    // Collision present in features
286                    // add the values.
287                    svm_node n = features.get(collisionIdx);
288                    n.value += value;
289                }
290            }
291        }
292        return features.toArray(new svm_node[0]);
293    }
294
295    @Override
296    public TrainerProvenance getProvenance() {
297        return new TrainerProvenanceImpl(this);
298    }
299}
300