001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.libsvm;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Output;
026import org.tribuo.provenance.ModelProvenance;
027import libsvm.svm_model;
028import libsvm.svm_node;
029
030import java.io.Serializable;
031import java.util.ArrayList;
032import java.util.Arrays;
033import java.util.Collections;
034import java.util.List;
035import java.util.Map;
036import java.util.Optional;
037import java.util.logging.Logger;
038
039/**
040 * A model that uses an underlying libSVM model to make the
041 * predictions.
042 * <p>
043 * See:
044 * <pre>
045 * Chang CC, Lin CJ.
046 * "LIBSVM: a library for Support Vector Machines"
047 * ACM transactions on intelligent systems and technology (TIST), 2011.
048 * </pre>
049 * for the nu-svm algorithm:
050 * <pre>
051 * Schölkopf B, Smola A, Williamson R, Bartlett P L.
052 * "New support vector algorithms"
053 * Neural Computation, 2000, 1207-1245.
054 * </pre>
055 * and for the original algorithm:
056 * <pre>
057 * Cortes C, Vapnik V.
058 * "Support-Vector Networks"
059 * Machine Learning, 1995.
060 * </pre>
061 */
062public abstract class LibSVMModel<T extends Output<T>> extends Model<T> implements Serializable {
063    private static final long serialVersionUID = 3L;
064
065    private static final Logger logger = Logger.getLogger(LibSVMModel.class.getName());
066
067    /**
068     * The LibSVM models. Multiple models are used for multi-label or multidimensional regression outputs.
069     */
070    protected final List<svm_model> models;
071
072    /**
073     * Constructs a LibSVMModel from the supplied arguments.
074     * @param name The model name.
075     * @param description The model provenance.
076     * @param featureIDMap The features the model knows about.
077     * @param outputIDInfo The outputs the model can produce.
078     * @param generatesProbabilities Does the model generate probabilities or not?
079     * @param models The svm models themselves.
080     */
081    protected LibSVMModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities, List<svm_model> models) {
082        super(name, description, featureIDMap, outputIDInfo, generatesProbabilities);
083        this.models = models;
084    }
085
086    /**
087     * Returns an unmodifiable copy of the underlying list of libsvm models.
088     * <p>
089     * Deprecated to unify the names across LibLinear, LibSVM and XGBoost.
090     * @return The underlying model list.
091     */
092    @Deprecated
093    public List<svm_model> getModel() {
094        return getInnerModels();
095    }
096
097    /**
098     * Returns an unmodifiable copy of the underlying list of libsvm models.
099     * @return The underlying model list.
100     */
101    public List<svm_model> getInnerModels() {
102        List<svm_model> copy = new ArrayList<>();
103
104        for (svm_model m : models) {
105            copy.add(copyModel(m));
106        }
107
108        return Collections.unmodifiableList(copy);
109    }
110
111    @Override
112    public Map<String, List<Pair<String,Double>>> getTopFeatures(int n) {
113        return Collections.emptyMap();
114    }
115
116    @Override
117    public Optional<Excuse<T>> getExcuse(Example<T> example) {
118        return Optional.empty();
119    }
120
121    /**
122     * Copies an svm_model, as it does not provide a copy method.
123     *
124     * @param model The svm_model to copy.
125     * @return A deep copy of the model.
126     */
127    protected static svm_model copyModel(svm_model model) {
128        svm_model newModel = new svm_model();
129
130        newModel.param = SVMParameters.copyParameters(model.param);
131        newModel.l = model.l;
132        newModel.nr_class = model.nr_class;
133        newModel.rho = model.rho != null ? Arrays.copyOf(model.rho,model.rho.length) : null;
134        newModel.probA = model.probA != null ? Arrays.copyOf(model.probA,model.probA.length) : null;
135        newModel.probB = model.probB != null ? Arrays.copyOf(model.probB,model.probB.length) : null;
136        newModel.label = model.label != null ? Arrays.copyOf(model.label,model.label.length) : null;
137        newModel.sv_indices = model.sv_indices != null ? Arrays.copyOf(model.sv_indices,model.sv_indices.length) : null;
138        newModel.nSV = model.nSV != null ? Arrays.copyOf(model.nSV,model.nSV.length) : null;
139        if (model.SV != null) {
140            newModel.SV = new svm_node[model.SV.length][];
141            for (int i = 0; i < newModel.SV.length; i++) {
142                if (model.SV[i] != null) {
143                    svm_node[] copy = new svm_node[model.SV[i].length];
144                    for (int j = 0; j < copy.length; j++) {
145                        if (model.SV[i][j] != null) {
146                            svm_node curCopy = new svm_node();
147                            curCopy.index = model.SV[i][j].index;
148                            curCopy.value = model.SV[i][j].value;
149                            copy[j] = curCopy;
150                        }
151                    }
152                    newModel.SV[i] = copy;
153                }
154            }
155        }
156        if (model.sv_coef != null) {
157            newModel.sv_coef = new double[model.sv_coef.length][];
158            for (int i = 0; i < newModel.sv_coef.length; i++) {
159                if (model.sv_coef[i] != null) {
160                    newModel.sv_coef[i] = Arrays.copyOf(model.sv_coef[i],model.sv_coef[i].length);
161                }
162            }
163        }
164
165        return newModel;
166    }
167
168}