/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.common.libsvm;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.provenance.ModelProvenance;
import libsvm.svm_model;
import libsvm.svm_node;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;

/**
 * A model that delegates its predictions to one or more underlying libSVM models.
 * <p>
 * See:
 * <pre>
 * Chang CC, Lin CJ.
 * "LIBSVM: a library for Support Vector Machines"
 * ACM transactions on intelligent systems and technology (TIST), 2011.
 * </pre>
 * for the nu-svm algorithm:
 * <pre>
 * Schölkopf B, Smola A, Williamson R, Bartlett P L.
 * "New support vector algorithms"
 * Neural Computation, 2000, 1207-1245.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
 *
 * @param <T> The output type produced by this model.
 */
public abstract class LibSVMModel<T extends Output<T>> extends Model<T> implements Serializable {
    private static final long serialVersionUID = 3L;

    private static final Logger logger = Logger.getLogger(LibSVMModel.class.getName());

    /**
     * The LibSVM models. Multiple models are used for multi-label or multidimensional regression outputs.
     */
    protected final List<svm_model> models;

    /**
     * Builds a LibSVMModel around the supplied libSVM models.
     * <p>
     * The supplied list is stored as-is; it is not copied.
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The features the model knows about.
     * @param outputIDInfo The outputs the model can produce.
     * @param generatesProbabilities Does the model generate probabilities or not?
     * @param models The svm models themselves.
     */
    protected LibSVMModel(String name,
                          ModelProvenance description,
                          ImmutableFeatureMap featureIDMap,
                          ImmutableOutputInfo<T> outputIDInfo,
                          boolean generatesProbabilities,
                          List<svm_model> models) {
        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities);
        this.models = models;
    }

    /**
     * Returns an unmodifiable copy of the underlying list of libsvm models.
     * <p>
     * Deprecated to unify the names across LibLinear, LibSVM and XGBoost.
     *
     * @return The underlying model list.
     * @deprecated Use {@link #getInnerModels()}, which this method delegates to.
     */
    @Deprecated
    public List<svm_model> getModel() {
        return getInnerModels();
    }

    /**
     * Returns an unmodifiable list holding a copy of each underlying libsvm model.
     *
     * @return The copied model list.
     */
    public List<svm_model> getInnerModels() {
        List<svm_model> duplicates = new ArrayList<>();
        models.forEach(inner -> duplicates.add(copyModel(inner)));
        return Collections.unmodifiableList(duplicates);
    }

    /**
     * This implementation always returns an empty map.
     *
     * @param n The number of features requested (unused here).
     * @return An empty map.
     */
    @Override
    public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * This implementation always returns an empty optional.
     *
     * @param example The example to explain (unused here).
     * @return An empty optional.
     */
    @Override
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    /**
     * Copies an svm_model, as it does not provide a copy method.
     * <p>
     * The parameters are copied via {@link SVMParameters#copyParameters}, and every
     * array and support vector node is duplicated rather than shared.
     *
     * @param model The svm_model to copy.
     * @return A deep copy of the model.
     */
    protected static svm_model copyModel(svm_model model) {
        svm_model duplicate = new svm_model();

        duplicate.param = SVMParameters.copyParameters(model.param);
        duplicate.l = model.l;
        duplicate.nr_class = model.nr_class;

        duplicate.rho = cloneOrNull(model.rho);
        duplicate.probA = cloneOrNull(model.probA);
        duplicate.probB = cloneOrNull(model.probB);
        duplicate.label = cloneOrNull(model.label);
        duplicate.sv_indices = cloneOrNull(model.sv_indices);
        duplicate.nSV = cloneOrNull(model.nSV);

        // These two fields are only assigned when the source has them.
        if (model.SV != null) {
            duplicate.SV = copySupportVectors(model.SV);
        }
        if (model.sv_coef != null) {
            duplicate.sv_coef = copyCoefficients(model.sv_coef);
        }

        return duplicate;
    }

    /**
     * Clones a double array, passing null through unchanged.
     *
     * @param source The array to clone, may be null.
     * @return A fresh array with the same contents, or null if the input was null.
     */
    private static double[] cloneOrNull(double[] source) {
        if (source == null) {
            return null;
        }
        return source.clone();
    }

    /**
     * Clones an int array, passing null through unchanged.
     *
     * @param source The array to clone, may be null.
     * @return A fresh array with the same contents, or null if the input was null.
     */
    private static int[] cloneOrNull(int[] source) {
        if (source == null) {
            return null;
        }
        return source.clone();
    }

    /**
     * Duplicates a non-null support vector matrix node by node.
     * Null rows and null nodes stay null in the result.
     *
     * @param vectors The support vectors to copy, must not be null.
     * @return A new matrix containing freshly allocated nodes.
     */
    private static svm_node[][] copySupportVectors(svm_node[][] vectors) {
        svm_node[][] result = new svm_node[vectors.length][];
        for (int row = 0; row < result.length; row++) {
            svm_node[] sourceRow = vectors[row];
            if (sourceRow == null) {
                continue;
            }
            svm_node[] targetRow = new svm_node[sourceRow.length];
            for (int col = 0; col < targetRow.length; col++) {
                svm_node original = sourceRow[col];
                if (original == null) {
                    continue;
                }
                svm_node fresh = new svm_node();
                fresh.index = original.index;
                fresh.value = original.value;
                targetRow[col] = fresh;
            }
            result[row] = targetRow;
        }
        return result;
    }

    /**
     * Duplicates a non-null coefficient matrix row by row.
     * Null rows stay null in the result.
     *
     * @param coefficients The coefficients to copy, must not be null.
     * @return A new matrix with each non-null row cloned.
     */
    private static double[][] copyCoefficients(double[][] coefficients) {
        double[][] result = new double[coefficients.length][];
        for (int row = 0; row < result.length; row++) {
            double[] sourceRow = coefficients[row];
            if (sourceRow != null) {
                result[row] = sourceRow.clone();
            }
        }
        return result;
    }

}