/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.liblinear;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.WeightedLabels;
import org.tribuo.classification.liblinear.LinearClassificationType.LinearType;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.provenance.ModelProvenance;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * A {@link Trainer} which wraps a liblinear-java classifier trainer.
 *
 * See:
 * <pre>
 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
 * "LIBLINEAR: A Library for Large Linear Classification"
 * Journal of Machine Learning Research, 2008.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
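 * <p>
 * A minimal usage sketch (dataset construction is elided; the variable names
 * {@code trainData} and {@code testExample} are illustrative):
 * <pre>{@code
 * // Train an L2-regularised L2-loss dual SVM: cost 1.0, 1000 iterations, termination criterion 0.1.
 * LibLinearClassificationTrainer trainer = new LibLinearClassificationTrainer(
 *     new LinearClassificationType(LinearType.L2R_L2LOSS_SVC_DUAL), 1.0, 1000, 0.1);
 * LibLinearModel<Label> model = trainer.train(trainData);
 * Prediction<Label> prediction = model.predict(testExample);
 * }</pre>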
 */
public class LibLinearClassificationTrainer extends LibLinearTrainer<Label> implements WeightedLabels {

    private static final Logger logger = Logger.getLogger(LibLinearClassificationTrainer.class.getName());

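    /**
     * Optional per-label weights, keyed by label name. An empty map (the default)
     * means every label is weighted equally.
     */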
    @Config(description="Use Label specific weights.")
    private Map<String,Float> labelWeights = Collections.emptyMap();

    /**
     * Creates a trainer using the default values (L2R_L2LOSS_SVC_DUAL, cost 1, 1000 iterations, termination criterion 0.1).
     */
    public LibLinearClassificationTrainer() {
        this(new LinearClassificationType(LinearType.L2R_L2LOSS_SVC_DUAL),1,1000,0.1);
    }

    /**
     * Creates a trainer for a LibLinearClassificationModel. Sets maxIterations to 1000.
     * @param trainerType Loss function and optimisation method combination.
     * @param cost Cost penalty for each incorrectly classified training point.
     * @param terminationCriterion How close to the optimum the solver must get before terminating a subproblem (usually set to 0.1).
     */
    public LibLinearClassificationTrainer(LinearClassificationType trainerType, double cost, double terminationCriterion) {
        this(trainerType,cost,1000,terminationCriterion);
    }

    /**
     * Creates a trainer for a LibLinear classification model.
     * @param trainerType Loss function and optimisation method combination.
     * @param cost Cost penalty for each incorrectly classified training point.
     * @param maxIterations The maximum number of dataset iterations.
     * @param terminationCriterion How close to the optimum the solver must get before terminating a subproblem (usually set to 0.1).
     */
    public LibLinearClassificationTrainer(LinearClassificationType trainerType, double cost, int maxIterations, double terminationCriterion) {
        super(trainerType,cost,maxIterations,terminationCriterion);
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        if (!trainerType.isClassification()) {
            throw new IllegalArgumentException("Supplied regression parameters to a classification linear model.");
        }
    }

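    /**
     * Trains a single liblinear model on the extracted feature arrays.
     * Classification uses one multiclass liblinear model, so the returned list
     * always contains a single element; the bias term is fixed at 1.0.
     */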
    @Override
    protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureNode[][] features, double[][] outputs) {
        Problem data = new Problem();

        data.l = features.length;
        data.y = outputs[0];
        data.x = features;
        data.n = numFeatures;
        data.bias = 1.0;

        return Collections.singletonList(Linear.train(data,curParams));
    }

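    /**
     * Wraps the trained liblinear model in a {@link LibLinearClassificationModel},
     * throwing if more than one model was supplied.
     */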
    @Override
    protected LibLinearModel<Label> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, List<Model> models) {
        if (models.size() != 1) {
            throw new IllegalArgumentException("Classification uses a single model. Found " + models.size() + " models.");
        }
        return new LibLinearClassificationModel("liblinear-classification-model",provenance,featureIDMap,outputIDInfo,models);
    }

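    /**
     * Converts each {@link Example} into liblinear's sparse {@link FeatureNode}
     * representation, and each output {@link Label} into its integer id.
     * The single output dimension is stored in {@code outputs[0]}.
     */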
    @Override
    protected Pair<FeatureNode[][], double[][]> extractData(Dataset<Label> data, ImmutableOutputInfo<Label> outputInfo, ImmutableFeatureMap featureMap) {
        ArrayList<FeatureNode> featureCache = new ArrayList<>();
        FeatureNode[][] features = new FeatureNode[data.size()][];
        double[][] outputs = new double[1][data.size()];
        int i = 0;
        for (Example<Label> e : data) {
            outputs[0][i] = outputInfo.getID(e.getOutput());
            features[i] = exampleToNodes(e,featureMap,featureCache);
            i++;
        }
        return new Pair<>(features,outputs);
    }

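    /**
     * Builds the liblinear {@link Parameter} for this training run. If label
     * weights have been set, the shared parameters are copied before the
     * per-label weights are applied, so the trainer's stored parameters are
     * never mutated; labels without an explicit weight default to 1.0.
     */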
    @Override
    protected Parameter setupParameters(ImmutableOutputInfo<Label> labelIDMap) {
        Parameter curParams;
        if (!labelWeights.isEmpty()) {
            // Copy the parameters (including maxIters) so the weights don't leak into later training runs.
            curParams = new Parameter(libLinearParams.getSolverType(),libLinearParams.getC(),libLinearParams.getEps(),libLinearParams.getMaxIters());
            double[] weights = new double[labelIDMap.size()];
            int[] indices = new int[labelIDMap.size()];
            int i = 0;
            for (Pair<Integer,Label> label : labelIDMap) {
                String labelName = label.getB().getLabel();
                Float weight = labelWeights.get(labelName);
                indices[i] = label.getA();
                if (weight != null) {
                    weights[i] = weight;
                } else {
                    weights[i] = 1.0;
                }
                i++;
            }
            curParams.setWeights(weights,indices);
        } else {
            curParams = libLinearParams;
        }
        return curParams;
    }

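    /**
     * Sets the label weights used by this trainer, copying them into a map keyed
     * by label name, so later changes to the supplied map have no effect on this
     * trainer. For example, to upweight a rare label (the label name "fraud" is
     * illustrative):
     * <pre>{@code
     * trainer.setLabelWeights(Collections.singletonMap(new Label("fraud"), 5.0f));
     * }</pre>
     */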
    @Override
    public void setLabelWeights(Map<Label,Float> weights) {
        labelWeights = new HashMap<>();
        for (Map.Entry<Label,Float> e : weights.entrySet()) {
            labelWeights.put(e.getKey().getLabel(),e.getValue());
        }
    }

}