001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.liblinear;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.Dataset;
023import org.tribuo.Example;
024import org.tribuo.Feature;
025import org.tribuo.ImmutableFeatureMap;
026import org.tribuo.ImmutableOutputInfo;
027import org.tribuo.Output;
028import org.tribuo.Trainer;
029import org.tribuo.provenance.ModelProvenance;
030import org.tribuo.provenance.TrainerProvenance;
031import org.tribuo.provenance.impl.TrainerProvenanceImpl;
032import org.tribuo.util.Util;
033import de.bwaldvogel.liblinear.FeatureNode;
034import de.bwaldvogel.liblinear.Linear;
035import de.bwaldvogel.liblinear.Parameter;
036
037import java.time.OffsetDateTime;
038import java.util.ArrayList;
039import java.util.List;
040import java.util.Map;
041import java.util.logging.Logger;
042
043/**
044 * A {@link Trainer} which wraps a liblinear-java trainer.
045 * <p>
046 * See:
047 * <pre>
048 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
049 * "LIBLINEAR: A library for Large Linear Classification"
050 * Journal of Machine Learning Research, 2008.
051 * </pre>
052 * and for the original algorithm:
053 * <pre>
054 * Cortes C, Vapnik V.
055 * "Support-Vector Networks"
056 * Machine Learning, 1995.
057 * </pre>
058 */
059public abstract class LibLinearTrainer<T extends Output<T>> implements Trainer<T> {
060
061    private static final Logger logger = Logger.getLogger(LibLinearTrainer.class.getName());
062
063    protected Parameter libLinearParams;
064
065    @Config(description="Algorithm to use.")
066    protected LibLinearType<T> trainerType;
067
068    @Config(description="Cost penalty for misclassifications.")
069    protected double cost = 1;
070
071    @Config(description="Maximum number of iterations before terminating.")
072    protected int maxIterations = 1000;
073
074    @Config(description="Stop iterating when the loss score decreases by less than this value.")
075    protected double terminationCriterion = 0.1;
076
077    @Config(description="Epsilon insensitivity in the regression cost function.")
078    protected double epsilon = 0.1;
079
080    private int trainInvocationCount = 0;
081
082    protected LibLinearTrainer() {}
083
084    /**
085     * Creates a trainer for a LibLinear model
086     * @param trainerType Loss function and optimisation method combination.
087     * @param cost Cost penalty for each incorrectly classified training point.
088     * @param maxIterations The maximum number of dataset iterations.
089     * @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
090     */
091    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion) {
092        this(trainerType,cost,maxIterations,terminationCriterion,0.1);
093    }
094
095    /**
096     * Creates a trainer for a LibLinear model
097     * @param trainerType Loss function and optimisation method combination.
098     * @param cost Cost penalty for each incorrectly classified training point.
099     * @param maxIterations The maximum number of dataset iterations.
100     * @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
101     * @param epsilon The insensitivity of the regression loss to small differences.
102     */
103    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon) {
104        this.trainerType = trainerType;
105        this.cost = cost;
106        this.maxIterations = maxIterations;
107        this.terminationCriterion = terminationCriterion;
108        this.epsilon = epsilon;
109        postConfig();
110    }
111
112    /**
113     * Used by the OLCUT configuration system, and should not be called by external code.
114     */
115    @Override
116    public void postConfig() {
117        libLinearParams = new Parameter(trainerType.getSolverType(),cost,terminationCriterion,maxIterations,epsilon);
118        Linear.disableDebugOutput();
119    }
120
121    @Override
122    public LibLinearModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
123        if (examples.getOutputInfo().getUnknownCount() > 0) {
124            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
125        }
126        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
127        ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();
128        TrainerProvenance trainerProvenance = getProvenance();
129        ModelProvenance provenance = new ModelProvenance(LibLinearModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
130        trainInvocationCount++;
131
132        Parameter curParams = setupParameters(outputIDInfo);
133
134        Pair<FeatureNode[][],double[][]> data = extractData(examples,outputIDInfo,featureIDMap);
135
136        List<de.bwaldvogel.liblinear.Model> models = trainModels(curParams,featureIDMap.size()+1,data.getA(),data.getB());
137
138        return createModel(provenance,featureIDMap,outputIDInfo,models);
139    }
140
141    @Override
142    public int getInvocationCount() {
143        return trainInvocationCount;
144    }
145
146    @Override
147    public String toString() {
148        StringBuilder buffer = new StringBuilder();
149
150        buffer.append("LibLinearTrainer(");
151        buffer.append("solver=");
152        buffer.append(libLinearParams.getSolverType());
153        buffer.append(",cost=");
154        buffer.append(libLinearParams.getC());
155        buffer.append(",terminationCriterion=");
156        buffer.append(libLinearParams.getEps());
157        buffer.append(",maxIterations=");
158        buffer.append(libLinearParams.getMaxIters());
159        buffer.append(",regression-epsilon=");
160        buffer.append(libLinearParams.getP());
161        buffer.append(')');
162
163        return buffer.toString();
164    }
165
166    /**
167     * Train all the liblinear instances necessary for this dataset.
168     * @param curParams The LibLinear parameters.
169     * @param numFeatures The number of features in this dataset.
170     * @param features The features themselves.
171     * @param outputs The outputs.
172     * @return A list of liblinear models.
173     */
174    protected abstract List<de.bwaldvogel.liblinear.Model> trainModels(Parameter curParams, int numFeatures, FeatureNode[][] features, double[][] outputs);
175
176    /**
177     * Construct the appropriate subtype of LibLinearModel for the prediction task.
178     * @param provenance The model provenance.
179     * @param featureIDMap The feature id map.
180     * @param outputIDInfo The output id info.
181     * @param models The list of linear models.
182     * @return An implementation of LibLinearModel.
183     */
184    protected abstract LibLinearModel<T> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<de.bwaldvogel.liblinear.Model> models);
185
186    /**
187     * Extracts the features and {@link Output}s in LibLinear's format.
188     * @param data The input data.
189     * @param outputInfo The output info.
190     * @param featureMap The feature info.
191     * @return The features and outputs.
192     */
193    protected abstract Pair<FeatureNode[][],double[][]> extractData(Dataset<T> data, ImmutableOutputInfo<T> outputInfo, ImmutableFeatureMap featureMap);
194
195    /**
196     * Constructs the parameters. Most of the time this is a no-op, but
197     * classification overrides it to incorporate label weights if they exist.
198     * @param info The output info.
199     * @return The Parameters to use for training.
200     */
201    protected Parameter setupParameters(ImmutableOutputInfo<T> info) {
202        return libLinearParams;
203    }
204
205    /**
206     * Converts a Tribuo {@link Example} into a liblinear {@code FeatureNode} array, including a bias feature.
207     * <p>
208     * If there is a collision between feature ids (i.e., if there is feature hashing or some other mechanism changing
209     * the feature ids) then the feature values are summed.
210     * @param example The input example.
211     * @param featureIDMap The feature id map which contains the example's indices.
212     * @param features A buffer. If null then an array list is created and used internally.
213     * @param <T> The output type.
214     * @return The features suitable for use in liblinear.
215     */
216    public static <T extends Output<T>> FeatureNode[] exampleToNodes(Example<T> example, ImmutableFeatureMap featureIDMap, List<FeatureNode> features) {
217        int biasIndex = featureIDMap.size()+1;
218
219        if (features == null) {
220            features = new ArrayList<>();
221        }
222        features.clear();
223
224        int prevIdx = -1;
225        for (Feature f : example) {
226            int id = featureIDMap.getID(f.getName());
227            if (id > prevIdx){
228                prevIdx = id;
229                features.add(new FeatureNode(id + 1, f.getValue()));
230            } else if (id > -1) {
231                //
232                // Collision, deal with it.
233                int collisionIdx = Util.binarySearch(features,id+1, FeatureNode::getIndex);
234                if (collisionIdx < 0) {
235                    //
236                    // Collision but not present in features
237                    // move data and bump i
238                    collisionIdx = - (collisionIdx + 1);
239                    features.add(collisionIdx,new FeatureNode(id + 1, f.getValue()));
240                } else {
241                    //
242                    // Collision present in features
243                    // add the values.
244                    FeatureNode n = features.get(collisionIdx);
245                    n.setValue(n.getValue() + f.getValue());
246                }
247            }
248        }
249
250        features.add(new FeatureNode(biasIndex,1.0));
251
252        return features.toArray(new FeatureNode[0]);
253    }
254
255    @Override
256    public TrainerProvenance getProvenance() {
257        return new TrainerProvenanceImpl(this);
258    }
259}