001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.libsvm;
018
019import org.tribuo.Example;
020import org.tribuo.ImmutableFeatureMap;
021import org.tribuo.ImmutableOutputInfo;
022import org.tribuo.Prediction;
023import org.tribuo.common.libsvm.LibSVMModel;
024import org.tribuo.common.libsvm.LibSVMTrainer;
025import org.tribuo.provenance.ModelProvenance;
026import org.tribuo.regression.Regressor;
027import libsvm.svm;
028import libsvm.svm_model;
029import libsvm.svm_node;
030
031import java.util.ArrayList;
032import java.util.HashMap;
033import java.util.List;
034import java.util.Map;
035
036/**
037 * A regression model that uses an underlying libSVM model to make the
038 * predictions. Contains an independent model for each output dimension.
039 * <p>
040 * See:
041 * <pre>
042 * Chang CC, Lin CJ.
043 * "LIBSVM: a library for Support Vector Machines"
044 * ACM transactions on intelligent systems and technology (TIST), 2011.
045 * </pre>
046 * for the nu-svr algorithm:
047 * <pre>
048 * Schölkopf B, Smola A, Williamson R, Bartlett P L.
049 * "New support vector algorithms"
050 * Neural Computation, 2000, 1207-1245.
051 * </pre>
052 * and for the original algorithm:
053 * <pre>
054 * Cortes C, Vapnik V.
055 * "Support-Vector Networks"
056 * Machine Learning, 1995.
057 * </pre>
058 */
059public class LibSVMRegressionModel extends LibSVMModel<Regressor> {
060    private static final long serialVersionUID = 2L;
061
062    private final String[] dimensionNames;
063
064    LibSVMRegressionModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, List<svm_model> models) {
065        super(name, description, featureIDMap, outputIDInfo, false, models);
066        this.dimensionNames = Regressor.extractNames(outputIDInfo);
067    }
068
069    /**
070     * Returns the support vectors used for each dimension.
071     * @return The support vectors.
072     */
073    public Map<String,Integer> getNumberOfSupportVectors() {
074        Map<String,Integer> output = new HashMap<>();
075
076        for (int i = 0; i < dimensionNames.length; i++) {
077            output.put(dimensionNames[i],models.get(i).SV.length);
078        }
079
080        return output;
081    }
082
083    @Override
084    public Prediction<Regressor> predict(Example<Regressor> example) {
085        svm_node[] features = LibSVMTrainer.exampleToNodes(example, featureIDMap, null);
086        // Bias feature is always set
087        if (features.length == 0) {
088            throw new IllegalArgumentException("No features found in Example " + example.toString());
089        }
090        double[] scores = new double[1];
091        double[] regressedValues = new double[models.size()];
092
093        for (int i = 0; i < regressedValues.length; i++) {
094            regressedValues[i] = svm.svm_predict_values(models.get(i), features, scores);
095        }
096
097        Regressor regressor = new Regressor(dimensionNames,regressedValues);
098        return new Prediction<>(regressor, features.length, example);
099    }
100
101    @Override
102    protected LibSVMRegressionModel copy(String newName, ModelProvenance newProvenance) {
103        List<svm_model> newModels = new ArrayList<>();
104        for (svm_model m : models) {
105            newModels.add(copyModel(m));
106        }
107        return new LibSVMRegressionModel(newName,newProvenance,featureIDMap,outputIDInfo,newModels);
108    }
109
110}