001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.regression.libsvm; 018 019import org.tribuo.Example; 020import org.tribuo.ImmutableFeatureMap; 021import org.tribuo.ImmutableOutputInfo; 022import org.tribuo.Prediction; 023import org.tribuo.common.libsvm.LibSVMModel; 024import org.tribuo.common.libsvm.LibSVMTrainer; 025import org.tribuo.provenance.ModelProvenance; 026import org.tribuo.regression.Regressor; 027import libsvm.svm; 028import libsvm.svm_model; 029import libsvm.svm_node; 030 031import java.util.ArrayList; 032import java.util.HashMap; 033import java.util.List; 034import java.util.Map; 035 036/** 037 * A regression model that uses an underlying libSVM model to make the 038 * predictions. Contains an independent model for each output dimension. 039 * <p> 040 * See: 041 * <pre> 042 * Chang CC, Lin CJ. 043 * "LIBSVM: a library for Support Vector Machines" 044 * ACM transactions on intelligent systems and technology (TIST), 2011. 045 * </pre> 046 * for the nu-svr algorithm: 047 * <pre> 048 * Schölkopf B, Smola A, Williamson R, Bartlett P L. 049 * "New support vector algorithms" 050 * Neural Computation, 2000, 1207-1245. 051 * </pre> 052 * and for the original algorithm: 053 * <pre> 054 * Cortes C, Vapnik V. 055 * "Support-Vector Networks" 056 * Machine Learning, 1995. 057 * </pre> 058 */ 059public class LibSVMRegressionModel extends LibSVMModel<Regressor> { 060 private static final long serialVersionUID = 2L; 061 062 private final String[] dimensionNames; 063 064 LibSVMRegressionModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, List<svm_model> models) { 065 super(name, description, featureIDMap, outputIDInfo, false, models); 066 this.dimensionNames = Regressor.extractNames(outputIDInfo); 067 } 068 069 /** 070 * Returns the support vectors used for each dimension. 071 * @return The support vectors. 072 */ 073 public Map<String,Integer> getNumberOfSupportVectors() { 074 Map<String,Integer> output = new HashMap<>(); 075 076 for (int i = 0; i < dimensionNames.length; i++) { 077 output.put(dimensionNames[i],models.get(i).SV.length); 078 } 079 080 return output; 081 } 082 083 @Override 084 public Prediction<Regressor> predict(Example<Regressor> example) { 085 svm_node[] features = LibSVMTrainer.exampleToNodes(example, featureIDMap, null); 086 // Bias feature is always set 087 if (features.length == 0) { 088 throw new IllegalArgumentException("No features found in Example " + example.toString()); 089 } 090 double[] scores = new double[1]; 091 double[] regressedValues = new double[models.size()]; 092 093 for (int i = 0; i < regressedValues.length; i++) { 094 regressedValues[i] = svm.svm_predict_values(models.get(i), features, scores); 095 } 096 097 Regressor regressor = new Regressor(dimensionNames,regressedValues); 098 return new Prediction<>(regressor, features.length, example); 099 } 100 101 @Override 102 protected LibSVMRegressionModel copy(String newName, ModelProvenance newProvenance) { 103 List<svm_model> newModels = new ArrayList<>(); 104 for (svm_model m : models) { 105 newModels.add(copyModel(m)); 106 } 107 return new LibSVMRegressionModel(newName,newProvenance,featureIDMap,outputIDInfo,newModels); 108 } 109 110}