001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.liblinear;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Dataset;
021import org.tribuo.Example;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Trainer;
025import org.tribuo.common.liblinear.LibLinearModel;
026import org.tribuo.common.liblinear.LibLinearTrainer;
027import org.tribuo.provenance.ModelProvenance;
028import org.tribuo.regression.Regressor;
029import org.tribuo.regression.liblinear.LinearRegressionType.LinearType;
030import de.bwaldvogel.liblinear.FeatureNode;
031import de.bwaldvogel.liblinear.Linear;
032import de.bwaldvogel.liblinear.Model;
033import de.bwaldvogel.liblinear.Parameter;
034import de.bwaldvogel.liblinear.Problem;
035
036import java.util.ArrayList;
037import java.util.List;
038import java.util.logging.Logger;
039
040/**
041 * A {@link Trainer} which wraps a liblinear-java regression trainer.
042 * <p>
043 * This generates an independent liblinear model for each regression dimension.
044 * <p>
045 * See:
046 * <pre>
047 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
048 * "LIBLINEAR: A library for Large Linear Classification"
049 * Journal of Machine Learning Research, 2008.
050 * </pre>
051 * and for the original algorithm:
052 * <pre>
053 * Cortes C, Vapnik V.
054 * "Support-Vector Networks"
055 * Machine Learning, 1995.
056 * </pre>
057 */
058public class LibLinearRegressionTrainer extends LibLinearTrainer<Regressor> {
059
060    private static final Logger logger = Logger.getLogger(LibLinearRegressionTrainer.class.getName());
061
062    /**
063     * Creates a trainer using the default values (L2R_L2LOSS_SVR, 1, 1000, 0.1, 0.1).
064     */
065    public LibLinearRegressionTrainer() {
066        this(new LinearRegressionType(LinearType.L2R_L2LOSS_SVR));
067    }
068
069    public LibLinearRegressionTrainer(LinearRegressionType trainerType) {
070        this(trainerType,1.0,1000,0.1,0.1);
071    }
072
073    /**
074     * Creates a trainer for a LibLinear model
075     * @param trainerType Loss function and optimisation method combination.
076     * @param cost Cost penalty for each incorrectly classified training point.
077     * @param maxIterations The maximum number of dataset iterations.
078     * @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
079     * @param epsilon The insensitivity of the regression loss to small differences.
080     */
081    public LibLinearRegressionTrainer(LinearRegressionType trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon) {
082        super(trainerType,cost,maxIterations,terminationCriterion,epsilon);
083    }
084
085    /**
086     * Used by the OLCUT configuration system, and should not be called by external code.
087     */
088    @Override
089    public void postConfig() {
090        super.postConfig();
091        if (!trainerType.isClassification()) {
092            throw new IllegalArgumentException("Supplied regression parameters to a classification linear model.");
093        }
094    }
095
096    @Override
097    protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureNode[][] features, double[][] outputs) {
098        ArrayList<Model> models = new ArrayList<>();
099
100        for (int i = 0; i < outputs.length; i++) {
101            Problem data = new Problem();
102
103            data.l = features.length;
104            data.y = outputs[i];
105            data.x = features;
106            data.n = numFeatures;
107            data.bias = 1.0;
108
109            models.add(Linear.train(data, curParams));
110        }
111
112        return models;
113    }
114
115    @Override
116    protected LibLinearModel<Regressor> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputIDInfo, List<Model> models) {
117        if (models.size() != outputIDInfo.size()) {
118            throw new IllegalArgumentException("Regression uses one model per dimension. Found " + models.size() + " models, and " + outputIDInfo.size() + " dimensions.");
119        }
120        return new LibLinearRegressionModel("liblinear-regression-model",provenance,featureIDMap,outputIDInfo,models);
121    }
122
123    @Override
124    protected Pair<FeatureNode[][], double[][]> extractData(Dataset<Regressor> data, ImmutableOutputInfo<Regressor> outputInfo, ImmutableFeatureMap featureMap) {
125        int numOutputs = outputInfo.size();
126        ArrayList<FeatureNode> featureCache = new ArrayList<>();
127        FeatureNode[][] features = new FeatureNode[data.size()][];
128        double[][] outputs = new double[numOutputs][data.size()];
129        int i = 0;
130        for (Example<Regressor> e : data) {
131            double[] curOutputs = e.getOutput().getValues();
132            for (int j = 0; j < curOutputs.length; j++) {
133                outputs[j][i] = curOutputs[j];
134            }
135            features[i] = exampleToNodes(e,featureMap,featureCache);
136            i++;
137        }
138        return new Pair<>(features,outputs);
139    }
140
141}