001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.liblinear;
018
019import de.bwaldvogel.liblinear.Linear;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Output;
026import org.tribuo.provenance.ModelProvenance;
027
028import java.io.IOException;
029import java.io.StringReader;
030import java.io.StringWriter;
031import java.util.ArrayList;
032import java.util.Collections;
033import java.util.List;
034import java.util.Optional;
035import java.util.logging.Logger;
036
037/**
038 * A {@link Model} which wraps a LibLinear-java model.
039 * <p>
040 * It disables the LibLinear debug output as it's very chatty.
041 * <p>
042 * See:
043 * <pre>
044 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
045 * "LIBLINEAR: A library for Large Linear Classification"
046 * Journal of Machine Learning Research, 2008.
047 * </pre>
048 * and for the original algorithm:
049 * <pre>
050 * Cortes C, Vapnik V.
051 * "Support-Vector Networks"
052 * Machine Learning, 1995.
053 * </pre>
054 */
055public abstract class LibLinearModel<T extends Output<T>> extends Model<T> {
056    private static final long serialVersionUID = 3L;
057
058    private static final Logger logger = Logger.getLogger(LibLinearModel.class.getName());
059
060    /**
061     * The list of LibLinear models. Multiple models are used by multi-label and multidimensional regression outputs.
062     */
063    protected final List<de.bwaldvogel.liblinear.Model> models;
064
065    /**
066     * Constructs a LibLinear model from the supplied arguments.
067     * @param name The model name.
068     * @param description The model provenance.
069     * @param featureIDMap The features this model knows about.
070     * @param labelIDMap The outputs this model produces.
071     * @param generatesProbabilities Does this model generate probabilities?
072     * @param models The liblinear models themselves.
073     */
074    protected LibLinearModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> labelIDMap, boolean generatesProbabilities, List<de.bwaldvogel.liblinear.Model> models) {
075        super(name, description, featureIDMap, labelIDMap, generatesProbabilities);
076        this.models = models;
077        Linear.disableDebugOutput();
078    }
079
080    /**
081     * Returns an unmodifiable list containing a copy of each model.
082     * <p>
083     * As liblinear-java models don't expose a copy constructor this requires
084     * serializing each model to a String and rebuilding it, and is thus quite expensive.
085     *
086     * @return A copy of all of the models.
087     */
088    public List<de.bwaldvogel.liblinear.Model> getInnerModels() {
089        List<de.bwaldvogel.liblinear.Model> copy = new ArrayList<>();
090
091        for (de.bwaldvogel.liblinear.Model m : models) {
092            copy.add(copyModel(m));
093        }
094
095        return Collections.unmodifiableList(copy);
096    }
097
098    /**
099     * This call is expensive as it copies out the weight matrix from the
100     * LibLinear model.
101     * <p>
102     * Prefer {@link LibLinearModel#getExcuses} to get multiple excuses.
103     * <p>
104     * @param e The example to excuse.
105     * @return An {@link Excuse} for this example.
106     */
107    @Override
108    public Optional<Excuse<T>> getExcuse(Example<T> e) {
109        double[][] featureWeights = getFeatureWeights();
110        return Optional.of(innerGetExcuse(e, featureWeights));
111    }
112
113    @Override
114    public Optional<List<Excuse<T>>> getExcuses(Iterable<Example<T>> examples) {
115        //This call copies out the weights, so it's better to do it once
116        double[][] featureWeights = getFeatureWeights();
117        List<Excuse<T>> excuses = new ArrayList<>();
118
119        for (Example<T> e : examples) {
120            excuses.add(innerGetExcuse(e, featureWeights));
121        }
122
123        return Optional.of(excuses);
124    }
125
126    /**
127     * Copies the model by writing it out to a String and loading it back in.
128     * <p>
129     * Unfortunately liblinear-java doesn't have a copy constructor on it's model.
130     * @param model The model to copy.
131     * @return A deep copy of the model.
132     */
133    protected static de.bwaldvogel.liblinear.Model copyModel(de.bwaldvogel.liblinear.Model model) {
134        try {
135            StringWriter writer = new StringWriter();
136            Linear.saveModel(writer,model);
137            String modelString = writer.toString();
138            StringReader reader = new StringReader(modelString);
139            return Linear.loadModel(reader);
140        } catch (IOException e) {
141            throw new IllegalStateException("IOException found when copying the model in memory via a String.",e);
142        }
143    }
144
145    /**
146     * Extracts the feature weights from the models.
147     * The first dimension corresponds to the model index.
148     * @return The feature weights.
149     */
150    protected abstract double[][] getFeatureWeights();
151
152    /**
153     * The call to getFeatureWeights in the public methods copies the
154     * weights array so this inner method exists to save the copy in getExcuses.
155     * <p>
156     * If it becomes a problem then we could cache the feature weights in the
157     * model.
158     * <p>
159     * @param e The example.
160     * @param featureWeights The per dimension feature weights.
161     * @return An excuse for this example.
162     */
163    protected abstract Excuse<T> innerGetExcuse(Example<T> e, double[][] featureWeights);
164}