/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.liblinear;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.WeightedLabels;
import org.tribuo.classification.liblinear.LinearClassificationType.LinearType;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.provenance.ModelProvenance;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * A {@link Trainer} which wraps a liblinear-java classifier trainer.
 *
 * See:
 * <pre>
 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
 * "LIBLINEAR: A library for Large Linear Classification"
 * Journal of Machine Learning Research, 2008.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
 */
public class LibLinearClassificationTrainer extends LibLinearTrainer<Label> implements WeightedLabels {

    private static final Logger logger = Logger.getLogger(LibLinearClassificationTrainer.class.getName());

    // Optional per-label weights, keyed by label name. Empty map means uniform weighting.
    @Config(description="Use Label specific weights.")
    private Map<String,Float> labelWeights = Collections.emptyMap();

    /**
     * Creates a trainer using the default values
     * (L2R_L2LOSS_SVC_DUAL, cost 1, 1000 max iterations, termination criterion 0.1).
     */
    public LibLinearClassificationTrainer() {
        this(new LinearClassificationType(LinearType.L2R_L2LOSS_SVC_DUAL),1,1000,0.1);
    }

    /**
     * Creates a trainer for a LibLinearClassificationModel. Sets maxIterations to 1000.
     * @param trainerType Loss function and optimisation method combination.
     * @param cost Cost penalty for each incorrectly classified training point.
     * @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
     */
    public LibLinearClassificationTrainer(LinearClassificationType trainerType, double cost, double terminationCriterion) {
        this(trainerType,cost,1000,terminationCriterion);
    }

    /**
     * Creates a trainer for a LibLinear model
     * @param trainerType Loss function and optimisation method combination.
     * @param cost Cost penalty for each incorrectly classified training point.
     * @param maxIterations The maximum number of dataset iterations.
     * @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
     */
    public LibLinearClassificationTrainer(LinearClassificationType trainerType, double cost, int maxIterations, double terminationCriterion) {
        super(trainerType,cost,maxIterations,terminationCriterion);
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     * @throws IllegalArgumentException If a regression trainer type was supplied.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        // Guard against a misconfigured regression solver being injected via configuration.
        if (!trainerType.isClassification()) {
            throw new IllegalArgumentException("Supplied regression parameters to a classification linear model.");
        }
    }

    /**
     * Trains a single liblinear model on the supplied data.
     * <p>
     * Classification uses a single multi-class model, so only outputs[0] is consumed
     * and a singleton list is returned.
     * @param curParams The liblinear training parameters.
     * @param numFeatures The feature space dimension.
     * @param features The training features, one FeatureNode array per example.
     * @param outputs The label ids; outputs[0] holds one id per example.
     * @return A singleton list containing the trained liblinear model.
     */
    @Override
    protected List<Model> trainModels(Parameter curParams, int numFeatures, FeatureNode[][] features, double[][] outputs) {
        Problem data = new Problem();

        data.l = features.length;
        data.y = outputs[0];
        data.x = features;
        data.n = numFeatures;
        data.bias = 1.0; // A bias feature is appended by exampleToNodes in the parent class.

        return Collections.singletonList(Linear.train(data,curParams));
    }

    /**
     * Wraps the trained liblinear model in a Tribuo model.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The label domain.
     * @param models The trained liblinear models; must contain exactly one element.
     * @return A LibLinearClassificationModel.
     * @throws IllegalArgumentException If more than one model was supplied.
     */
    @Override
    protected LibLinearModel<Label> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo, List<Model> models) {
        if (models.size() != 1) {
            throw new IllegalArgumentException("Classification uses a single model. Found " + models.size() + " models.");
        }
        return new LibLinearClassificationModel("liblinear-classification-model",provenance,featureIDMap,outputIDInfo,models);
    }

    /**
     * Converts a Tribuo dataset into the liblinear feature/label representation.
     * @param data The dataset to convert.
     * @param outputInfo The label domain, used to map labels to ids.
     * @param featureMap The feature domain, used to map features to indices.
     * @return A pair of (features, outputs) where outputs[0][i] is the label id of example i.
     */
    @Override
    protected Pair<FeatureNode[][], double[][]> extractData(Dataset<Label> data, ImmutableOutputInfo<Label> outputInfo, ImmutableFeatureMap featureMap) {
        ArrayList<FeatureNode> featureCache = new ArrayList<>();
        FeatureNode[][] features = new FeatureNode[data.size()][];
        double[][] outputs = new double[1][data.size()];
        int i = 0;
        for (Example<Label> e : data) {
            outputs[0][i] = outputInfo.getID(e.getOutput());
            features[i] = exampleToNodes(e,featureMap,featureCache);
            i++;
        }
        return new Pair<>(features,outputs);
    }

    /**
     * Constructs the liblinear parameters, applying any configured label weights.
     * <p>
     * If no label weights are set, the shared parameter object is returned unchanged;
     * otherwise a fresh Parameter is built (so the shared one is never mutated) and the
     * per-label weights are applied, defaulting to 1.0 for labels without an entry.
     * @param labelIDMap The label domain, used to map label names to liblinear ids.
     * @return The parameters to use for this training run.
     */
    @Override
    protected Parameter setupParameters(ImmutableOutputInfo<Label> labelIDMap) {
        Parameter curParams;
        if (!labelWeights.isEmpty()) {
            // Bug fix: the 3-arg constructor silently dropped the configured maxIterations,
            // so weighted training ignored that setting. Carry it through explicitly.
            curParams = new Parameter(libLinearParams.getSolverType(),libLinearParams.getC(),libLinearParams.getEps(),libLinearParams.getMaxIters());
            double[] weights = new double[labelIDMap.size()];
            int[] indices = new int[labelIDMap.size()];
            int i = 0;
            for (Pair<Integer,Label> label : labelIDMap) {
                String labelName = label.getB().getLabel();
                Float weight = labelWeights.get(labelName);
                indices[i] = label.getA();
                if (weight != null) {
                    weights[i] = weight;
                } else {
                    weights[i] = 1.0f; // Unweighted labels keep the default weight.
                }
                i++;
            }
            curParams.setWeights(weights,indices);
            //logger.info("Weights = " + Arrays.toString(weights) + ", labels = " + Arrays.toString(indices) + ", outputIDInfo = " + outputIDInfo);
        } else {
            curParams = libLinearParams;
        }
        return curParams;
    }

    /**
     * Sets the per-label weights used during training.
     * <p>
     * Copies the supplied map, keyed by label name; labels absent from the map use weight 1.0.
     * @param weights A map from labels to their training weights.
     */
    @Override
    public void setLabelWeights(Map<Label,Float> weights) {
        labelWeights = new HashMap<>();
        for (Map.Entry<Label,Float> e : weights.entrySet()) {
            labelWeights.put(e.getKey().getLabel(),e.getValue());
        }
    }

}