001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.libsvm;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Dataset;
021import org.tribuo.Example;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.common.libsvm.LibSVMModel;
025import org.tribuo.common.libsvm.LibSVMTrainer;
026import org.tribuo.common.libsvm.SVMParameters;
027import org.tribuo.provenance.ModelProvenance;
028import org.tribuo.regression.Regressor;
029import libsvm.svm;
030import libsvm.svm_model;
031import libsvm.svm_node;
032import libsvm.svm_parameter;
033import libsvm.svm_problem;
034
035import java.util.ArrayList;
036import java.util.Collections;
037import java.util.List;
038import java.util.logging.Logger;
039
040/**
041 * A trainer for regression models that uses LibSVM. Trains an independent model for each output dimension.
042 * <p>
043 * See:
044 * <pre>
045 * Chang CC, Lin CJ.
046 * "LIBSVM: a library for Support Vector Machines"
047 * ACM transactions on intelligent systems and technology (TIST), 2011.
048 * </pre>
049 * for the nu-svr algorithm:
050 * <pre>
051 * Schölkopf B, Smola A, Williamson R, Bartlett P L.
052 * "New support vector algorithms"
053 * Neural Computation, 2000, 1207-1245.
054 * </pre>
055 * and for the original algorithm:
056 * <pre>
057 * Cortes C, Vapnik V.
058 * "Support-Vector Networks"
059 * Machine Learning, 1995.
060 * </pre>
061 */
062public class LibSVMRegressionTrainer extends LibSVMTrainer<Regressor> {
063    private static final Logger logger = Logger.getLogger(LibSVMRegressionTrainer.class.getName());
064
065    /**
066     * For olcut.
067     */
068    protected LibSVMRegressionTrainer() {}
069
070    public LibSVMRegressionTrainer(SVMParameters<Regressor> parameters) {
071        super(parameters);
072    }
073
074    /**
075     * Used by the OLCUT configuration system, and should not be called by external code.
076     */
077    @Override
078    public void postConfig() {
079        super.postConfig();
080        if (!svmType.isRegression()) {
081            throw new IllegalArgumentException("Supplied classification or anomaly detection parameters to a regression SVM.");
082        }
083    }
084
085    @Override
086    protected LibSVMModel<Regressor> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, List<svm_model> models) {
087        return new LibSVMRegressionModel("svm-regression-model", provenance, featureIDMap, outputIDInfo, models);
088    }
089
090    @Override
091    protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) {
092        ArrayList<svm_model> models = new ArrayList<>();
093
094        for (int i = 0; i < outputs.length; i++) {
095            svm_problem problem = new svm_problem();
096            problem.l = outputs[i].length;
097            problem.x = features;
098            problem.y = outputs[i];
099            if (curParams.gamma == 0) {
100                curParams.gamma = 1.0 / numFeatures;
101            }
102            String checkString = svm.svm_check_parameter(problem, curParams);
103            if(checkString != null) {
104                throw new IllegalArgumentException("Error checking SVM parameters: " + checkString);
105            }
106            models.add(svm.svm_train(problem, curParams));
107        }
108
109        return Collections.unmodifiableList(models);
110    }
111
112    @Override
113    protected Pair<svm_node[][], double[][]> extractData(Dataset<Regressor> data, ImmutableOutputInfo<Regressor> outputInfo, ImmutableFeatureMap featureMap) {
114        int numOutputs = outputInfo.size();
115        ArrayList<svm_node> buffer = new ArrayList<>();
116        svm_node[][] features = new svm_node[data.size()][];
117        double[][] outputs = new double[numOutputs][data.size()];
118        int i = 0;
119        for (Example<Regressor> e : data) {
120            double[] curOutputs = e.getOutput().getValues();
121            for (int j = 0; j < curOutputs.length; j++) {
122                outputs[j][i] = curOutputs[j];
123            }
124            features[i] = exampleToNodes(e,featureMap,buffer);
125            i++;
126        }
127        return new Pair<>(features,outputs);
128    }
129}