/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.libsvm;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.common.libsvm.SVMParameters;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;

/**
 * A trainer for regression models that uses LibSVM. Trains an independent model for each output dimension.
 * <p>
 * See:
 * <pre>
 * Chang CC, Lin CJ.
 * "LIBSVM: a library for Support Vector Machines"
 * ACM transactions on intelligent systems and technology (TIST), 2011.
 * </pre>
 * for the nu-svr algorithm:
 * <pre>
 * Schölkopf B, Smola A, Williamson R, Bartlett P L.
 * "New support vector algorithms"
 * Neural Computation, 2000, 1207-1245.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
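 * <p>
 * A minimal usage sketch (illustrative only; it assumes the {@code SVMRegressionType} and
 * {@code KernelType} classes from the surrounding Tribuo LibSVM packages, and a pre-built
 * {@code Dataset<Regressor>} named {@code regressionDataset}):
 * <pre>{@code
 * SVMParameters<Regressor> params = new SVMParameters<>(
 *         new SVMRegressionType(SVMRegressionType.SVMMode.EPSILON_SVR), KernelType.RBF);
 * params.setGamma(0.5);
 * params.setCost(1.0);
 * LibSVMRegressionTrainer trainer = new LibSVMRegressionTrainer(params);
 * Model<Regressor> model = trainer.train(regressionDataset);
 * }</pre>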
 */
public class LibSVMRegressionTrainer extends LibSVMTrainer<Regressor> {
    private static final Logger logger = Logger.getLogger(LibSVMRegressionTrainer.class.getName());

    /**
     * For OLCUT.
     */
    protected LibSVMRegressionTrainer() {}

    /**
     * Constructs a LibSVM trainer for regression using the supplied parameters.
     * @param parameters The SVM training parameters.
     */
    public LibSVMRegressionTrainer(SVMParameters<Regressor> parameters) {
        super(parameters);
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        if (!svmType.isRegression()) {
            throw new IllegalArgumentException("Supplied classification or anomaly detection parameters to a regression SVM.");
        }
    }

    @Override
    protected LibSVMModel<Regressor> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, List<svm_model> models) {
        return new LibSVMRegressionModel("svm-regression-model", provenance, featureIDMap, outputIDInfo, models);
    }

    @Override
    protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) {
        ArrayList<svm_model> models = new ArrayList<>();

        // Train one independent SVM per output dimension, sharing the same feature matrix.
        for (int i = 0; i < outputs.length; i++) {
            svm_problem problem = new svm_problem();
            problem.l = outputs[i].length;
            problem.x = features;
            problem.y = outputs[i];
            if (curParams.gamma == 0) {
                curParams.gamma = 1.0 / numFeatures;
            }
            String checkString = svm.svm_check_parameter(problem, curParams);
            if (checkString != null) {
                throw new IllegalArgumentException("Error checking SVM parameters: " + checkString);
            }
            models.add(svm.svm_train(problem, curParams));
        }

        return Collections.unmodifiableList(models);
    }

    @Override
    protected Pair<svm_node[][], double[][]> extractData(Dataset<Regressor> data, ImmutableOutputInfo<Regressor> outputInfo, ImmutableFeatureMap featureMap) {
        int numOutputs = outputInfo.size();
        ArrayList<svm_node> buffer = new ArrayList<>();
        svm_node[][] features = new svm_node[data.size()][];
        double[][] outputs = new double[numOutputs][data.size()];
        int i = 0;
        for (Example<Regressor> e : data) {
            double[] curOutputs = e.getOutput().getValues();
            for (int j = 0; j < curOutputs.length; j++) {
                outputs[j][i] = curOutputs[j];
            }
            features[i] = exampleToNodes(e, featureMap, buffer);
            i++;
        }
        return new Pair<>(features, outputs);
    }
}