/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.libsvm;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.WeightedLabels;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.common.libsvm.SVMParameters;
import org.tribuo.provenance.ModelProvenance;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * A trainer for classification models that uses LibSVM.
 * <p>
 * See:
 * <pre>
 * Chang CC, Lin CJ.
 * "LIBSVM: a library for Support Vector Machines"
 * ACM transactions on intelligent systems and technology (TIST), 2011.
 * </pre>
 * for the nu-svc algorithm:
 * <pre>
 * Schölkopf B, Smola A, Williamson R, Bartlett P L.
 * "New support vector algorithms"
 * Neural Computation, 2000, 1207-1245.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
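 * <p>
 * A minimal usage sketch. The parameter values, the {@code trainData} dataset, and the use of the
 * companion {@code SVMClassificationType} and {@code KernelType} classes are illustrative
 * assumptions, not recommendations:
 * <pre>{@code
 * SVMParameters<Label> params = new SVMParameters<>(
 *         new SVMClassificationType(SVMClassificationType.SVMMode.C_SVC), KernelType.RBF);
 * params.setGamma(0.5);
 * params.setCost(1.0);
 * LibSVMClassificationTrainer trainer = new LibSVMClassificationTrainer(params);
 * Model<Label> model = trainer.train(trainData); // trainData is a Dataset<Label>
 * }</pre>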
 */
public class LibSVMClassificationTrainer extends LibSVMTrainer<Label> implements WeightedLabels {
    private static final Logger logger = Logger.getLogger(LibSVMClassificationTrainer.class.getName());

    @Config(description="Use Label specific weights.")
    private Map<String,Float> labelWeights = Collections.emptyMap();

    /**
     * For OLCUT.
     */
    protected LibSVMClassificationTrainer() {}

    /**
     * Constructs a classification LibSVM trainer using the supplied parameters.
     * @param parameters The SVM training parameters.
     */
    public LibSVMClassificationTrainer(SVMParameters<Label> parameters) {
        super(parameters);
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        if (!svmType.isClassification()) {
            throw new IllegalArgumentException("Supplied regression or anomaly detection parameters to a classification SVM.");
        }
    }

    @Override
    protected LibSVMModel<Label> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, List<svm_model> models) {
        return new LibSVMClassificationModel("svm-classification-model", provenance, featureIDMap, outputIDInfo, models);
    }

    @Override
    protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) {
        svm_problem problem = new svm_problem();
        problem.l = outputs[0].length;
        problem.x = features;
        problem.y = outputs[0];
        if (curParams.gamma == 0) {
            // An unset gamma defaults to 1/numFeatures, following the LibSVM convention.
            curParams.gamma = 1.0 / numFeatures;
        }
        // Validate the parameters against the problem before training.
        String checkString = svm.svm_check_parameter(problem, curParams);
        if (checkString != null) {
            throw new IllegalArgumentException("Error checking SVM parameters: " + checkString);
        }
        return Collections.singletonList(svm.svm_train(problem, curParams));
    }

    @Override
    protected Pair<svm_node[][], double[][]> extractData(Dataset<Label> data, ImmutableOutputInfo<Label> outputInfo, ImmutableFeatureMap featureMap) {
        double[][] ys = new double[1][data.size()];
        svm_node[][] xs = new svm_node[data.size()][];
        List<svm_node> buffer = new ArrayList<>();
        int i = 0;
        for (Example<Label> example : data) {
            // The output is the label's id in the output domain, the features are converted into sparse LibSVM nodes.
            ys[0][i] = outputInfo.getID(example.getOutput());
            xs[i] = exampleToNodes(example, featureMap, buffer);
            i++;
        }
        return new Pair<>(xs,ys);
    }

    @Override
    protected svm_parameter setupParameters(ImmutableOutputInfo<Label> outputIDInfo) {
        svm_parameter curParams = SVMParameters.copyParameters(parameters);
        if (!labelWeights.isEmpty()) {
            // Convert the label weight map into the parallel weight and label id arrays that LibSVM expects.
            double[] weights = new double[outputIDInfo.size()];
            int[] indices = new int[outputIDInfo.size()];
            int i = 0;
            for (Pair<Integer,Label> label : outputIDInfo) {
                String labelName = label.getB().getLabel();
                Float weight = labelWeights.get(labelName);
                indices[i] = label.getA();
                if (weight != null) {
                    weights[i] = weight;
                } else {
                    // Labels without an explicit weight default to 1.0.
                    weights[i] = 1.0f;
                }
                i++;
            }
            curParams.nr_weight = weights.length;
            curParams.weight = weights;
            curParams.weight_label = indices;
            //logger.info("Weights = " + Arrays.toString(weights) + ", labels = " + Arrays.toString(indices) + ", outputIDInfo = " + outputIDInfo);
        }
        return curParams;
    }

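    /**
     * Sets the per-label weights used to scale the misclassification cost for each class.
     * <p>
     * Labels absent from the map default to a weight of 1.0. A hypothetical example,
     * assuming the output domain contains a label named "positive":
     * <pre>{@code
     * trainer.setLabelWeights(Collections.singletonMap(new Label("positive"), 2.0f));
     * }</pre>
     * @param weights A map from Label to the weight for that label.
     */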
    @Override
    public void setLabelWeights(Map<Label,Float> weights) {
        labelWeights = new HashMap<>();
        for (Map.Entry<Label,Float> e : weights.entrySet()) {
            labelWeights.put(e.getKey().getLabel(),e.getValue());
        }
    }
}