/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.regression.liblinear;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.logging.Logger;

/**
 * A {@link Model} which wraps a LibLinear-java model.
 * <p>
 * It disables the LibLinear debug output as it's very chatty.
 * <p>
 * It contains an independent liblinear model for each regression dimension.
 * <p>
 * See:
 * <pre>
 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
 * "LIBLINEAR: A library for large linear classification"
 * Journal of Machine Learning Research, 2008.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
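 * <p>
 * A minimal usage sketch (illustrative only; it assumes a {@code LibLinearRegressionTrainer}
 * from this package, a {@code Dataset<Regressor>} named {@code trainData}, and an
 * {@code Example<Regressor>} named {@code testExample}):
 * <pre>{@code
 * LibLinearRegressionTrainer trainer = new LibLinearRegressionTrainer();
 * Model<Regressor> model = trainer.train(trainData);
 * Prediction<Regressor> prediction = model.predict(testExample);
 * }</pre>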
 */
public class LibLinearRegressionModel extends LibLinearModel<Regressor> {
    private static final long serialVersionUID = 2L;

    private static final Logger logger = Logger.getLogger(LibLinearRegressionModel.class.getName());

    private final String[] dimensionNames;

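    /**
     * Constructs a LibLinear regression model from the supplied arguments,
     * wrapping one liblinear model per regression dimension.
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The features this model knows about.
     * @param outputInfo The outputs this model produces.
     * @param models The liblinear models, one per regression dimension.
     */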
    LibLinearRegressionModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo, List<de.bwaldvogel.liblinear.Model> models) {
        super(name, description, featureIDMap, outputInfo, false, models);
        this.dimensionNames = Regressor.extractNames(outputInfo);
    }

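    /**
     * {@inheritDoc}
     * <p>
     * A minimal usage sketch (illustrative only; {@code model} and {@code testExample}
     * are assumed to exist):
     * <pre>{@code
     * Prediction<Regressor> pred = model.predict(testExample);
     * Regressor output = pred.getOutput();
     * for (int i = 0; i < output.size(); i++) {
     *     System.out.println(output.getNames()[i] + " = " + output.getValues()[i]);
     * }
     * }</pre>
     */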
    @Override
    public Prediction<Regressor> predict(Example<Regressor> example) {
        FeatureNode[] features = LibLinearTrainer.exampleToNodes(example, featureIDMap, null);
        // Bias feature is always set
        if (features.length == 1) {
            throw new IllegalArgumentException("No features found in Example " + example.toString());
        }

        // scores receives liblinear's raw decision values; the return value of
        // predictValues is the regressed value for that dimension's model.
        double[] scores = new double[models.get(0).getNrClass()];
        double[] regressedValues = new double[models.size()];

        for (int i = 0; i < regressedValues.length; i++) {
            regressedValues[i] = Linear.predictValues(models.get(i), features, scores);
        }

        Regressor regressor = new Regressor(dimensionNames, regressedValues);
        return new Prediction<>(regressor, features.length - 1, example);
    }

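    /**
     * {@inheritDoc}
     * <p>
     * A minimal usage sketch (illustrative only; {@code model} is assumed to be a
     * trained model): querying the ten most heavily weighted features per dimension.
     * <pre>{@code
     * Map<String, List<Pair<String, Double>>> top = model.getTopFeatures(10);
     * for (Map.Entry<String, List<Pair<String, Double>>> e : top.entrySet()) {
     *     System.out.println(e.getKey() + " -> " + e.getValue());
     * }
     * }</pre>
     */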
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        int maxFeatures = n < 0 ? featureIDMap.size() : n;
        double[][] featureWeights = getFeatureWeights();

        Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(p -> Math.abs(p.getB()));
        Map<String, List<Pair<String, Double>>> map = new HashMap<>();
        PriorityQueue<Pair<String, Double>> q = new PriorityQueue<>(maxFeatures, comparator);

        for (int i = 0; i < models.size(); i++) {
            int numFeatures = models.get(i).getNrFeature();
            for (int j = 0; j < numFeatures; j++) {
                Pair<String, Double> cur = new Pair<>(featureIDMap.get(j).getName(), featureWeights[i][j]);
                // The min-heap keeps the maxFeatures features with the largest absolute weights.
                // maxFeatures is always non-negative here, so no extra check is needed.
                if (q.size() < maxFeatures) {
                    q.offer(cur);
                } else if (comparator.compare(cur, q.peek()) > 0) {
                    q.poll();
                    q.offer(cur);
                }
            }
            // Drain the queue (smallest first) and reverse, so the list is sorted by
            // descending absolute weight, leaving the queue empty for the next dimension.
            List<Pair<String, Double>> list = new ArrayList<>();
            while (q.size() > 0) {
                list.add(q.poll());
            }
            Collections.reverse(list);
            map.put(dimensionNames[i], list);
        }

        return map;
    }

    @Override
    protected LibLinearRegressionModel copy(String newName, ModelProvenance newProvenance) {
        List<de.bwaldvogel.liblinear.Model> newModels = new ArrayList<>();
        for (de.bwaldvogel.liblinear.Model m : models) {
            newModels.add(copyModel(m));
        }
        return new LibLinearRegressionModel(newName, newProvenance, featureIDMap, outputIDInfo, newModels);
    }

    @Override
    protected double[][] getFeatureWeights() {
        double[][] featureWeights = new double[models.size()][];

        // Each row is the (copied) weight vector of the liblinear model for one regression dimension.
        for (int i = 0; i < models.size(); i++) {
            featureWeights[i] = models.get(i).getFeatureWeights();
        }

        return featureWeights;
    }

    /**
     * Each call to {@code model.getFeatureWeights} in the public methods copies the
     * weights array, so this inner method exists to avoid repeating that copy in
     * {@code getExcuses}.
     * <p>
     * If it becomes a problem then we could cache the feature weights in the model.
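     * <p>
     * A minimal usage sketch of the public entry point (illustrative only;
     * {@code model} and {@code example} are assumed to exist):
     * <pre>{@code
     * Optional<Excuse<Regressor>> excuse = model.getExcuse(example);
     * excuse.ifPresent(ex -> System.out.println(ex.getScores()));
     * }</pre>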
     * @param e The example.
     * @param allFeatureWeights The feature weights.
     * @return An excuse for this example.
     */
    @Override
    protected Excuse<Regressor> innerGetExcuse(Example<Regressor> e, double[][] allFeatureWeights) {
        Prediction<Regressor> prediction = predict(e);
        Map<String, List<Pair<String, Double>>> weightMap = new HashMap<>();
        for (int i = 0; i < allFeatureWeights.length; i++) {
            List<Pair<String, Double>> scores = new ArrayList<>();
            for (Feature f : e) {
                int id = featureIDMap.getID(f.getName());
                if (id > -1) {
                    double score = allFeatureWeights[i][id] * f.getValue();
                    scores.add(new Pair<>(f.getName(), score));
                }
            }
            // Sort the feature contributions in descending order for this dimension.
            scores.sort((o1, o2) -> o2.getB().compareTo(o1.getB()));
            weightMap.put(dimensionNames[i], scores);
        }

        return new Excuse<>(e, prediction, weightMap);
    }
}