/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.common.liblinear;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Parameter;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * A {@link Trainer} which wraps a liblinear-java trainer.
 * <p>
 * See:
 * <pre>
 * Fan RE, Chang KW, Hsieh CJ, Wang XR, Lin CJ.
 * "LIBLINEAR: A library for Large Linear Classification"
 * Journal of Machine Learning Research, 2008.
 * </pre>
 * and for the original algorithm:
 * <pre>
 * Cortes C, Vapnik V.
 * "Support-Vector Networks"
 * Machine Learning, 1995.
 * </pre>
 */
public abstract class LibLinearTrainer<T extends Output<T>> implements Trainer<T> {

    private static final Logger logger = Logger.getLogger(LibLinearTrainer.class.getName());

    protected Parameter libLinearParams;

    @Config(description="Algorithm to use.")
    protected LibLinearType<T> trainerType;

    @Config(description="Cost penalty for misclassifications.")
    protected double cost = 1;

    @Config(description="Maximum number of iterations before terminating.")
    protected int maxIterations = 1000;

    @Config(description="Stop iterating when the loss score decreases by less than this value.")
    protected double terminationCriterion = 0.1;

    @Config(description="Epsilon insensitivity in the regression cost function.")
    protected double epsilon = 0.1;

    private int trainInvocationCount = 0;

    protected LibLinearTrainer() {}

    /**
     * Creates a trainer for a LibLinear model.
     * @param trainerType Loss function and optimisation method combination.
     * @param cost Cost penalty for each incorrectly classified training point.
     * @param maxIterations The maximum number of dataset iterations.
     * @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
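     * <p>
     * This constructor defaults the regression epsilon to 0.1.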
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion) {
        this(trainerType,cost,maxIterations,terminationCriterion,0.1);
    }

    /**
     * Creates a trainer for a LibLinear model.
     * @param trainerType Loss function and optimisation method combination.
     * @param cost Cost penalty for each incorrectly classified training point.
     * @param maxIterations The maximum number of dataset iterations.
     * @param terminationCriterion How close does the optimisation function need to be before terminating that subproblem (usually set to 0.1).
     * @param epsilon The insensitivity of the regression loss to small differences.
     */
    protected LibLinearTrainer(LibLinearType<T> trainerType, double cost, int maxIterations, double terminationCriterion, double epsilon) {
        this.trainerType = trainerType;
        this.cost = cost;
        this.maxIterations = maxIterations;
        this.terminationCriterion = terminationCriterion;
        this.epsilon = epsilon;
        postConfig();
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */
    @Override
    public void postConfig() {
        libLinearParams = new Parameter(trainerType.getSolverType(),cost,terminationCriterion,maxIterations,epsilon);
        Linear.disableDebugOutput();
    }

    @Override
    public LibLinearModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();
        TrainerProvenance trainerProvenance = getProvenance();
        ModelProvenance provenance = new ModelProvenance(LibLinearModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        trainInvocationCount++;

        Parameter curParams = setupParameters(outputIDInfo);

        Pair<FeatureNode[][],double[][]> data = extractData(examples,outputIDInfo,featureIDMap);

        List<de.bwaldvogel.liblinear.Model> models = trainModels(curParams,featureIDMap.size()+1,data.getA(),data.getB());

        return createModel(provenance,featureIDMap,outputIDInfo,models);
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCount;
    }

    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();

        buffer.append("LibLinearTrainer(");
        buffer.append("solver=");
        buffer.append(libLinearParams.getSolverType());
        buffer.append(",cost=");
        buffer.append(libLinearParams.getC());
        buffer.append(",terminationCriterion=");
        buffer.append(libLinearParams.getEps());
        buffer.append(",maxIterations=");
        buffer.append(libLinearParams.getMaxIters());
        buffer.append(",regression-epsilon=");
        buffer.append(libLinearParams.getP());
        buffer.append(')');

        return buffer.toString();
    }

    /**
     * Train all the liblinear instances necessary for this dataset.
     * @param curParams The LibLinear parameters.
     * @param numFeatures The number of features in this dataset (including the bias feature).
     * @param features The features themselves.
     * @param outputs The outputs.
     * @return A list of liblinear models.
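     * <p>
     * Implementations may return a single liblinear model (e.g., classification,
     * where liblinear handles multiclass problems natively) or one model per output
     * dimension (e.g., multidimensional regression).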
     */
    protected abstract List<de.bwaldvogel.liblinear.Model> trainModels(Parameter curParams, int numFeatures, FeatureNode[][] features, double[][] outputs);

    /**
     * Construct the appropriate subtype of LibLinearModel for the prediction task.
     * @param provenance The model provenance.
     * @param featureIDMap The feature id map.
     * @param outputIDInfo The output id info.
     * @param models The list of linear models.
     * @return An implementation of LibLinearModel.
     */
    protected abstract LibLinearModel<T> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<de.bwaldvogel.liblinear.Model> models);

    /**
     * Extracts the features and {@link Output}s in LibLinear's format.
     * @param data The input data.
     * @param outputInfo The output info.
     * @param featureMap The feature info.
     * @return The features and outputs.
     */
    protected abstract Pair<FeatureNode[][],double[][]> extractData(Dataset<T> data, ImmutableOutputInfo<T> outputInfo, ImmutableFeatureMap featureMap);

    /**
     * Constructs the parameters. Most of the time this is a no-op, but
     * classification overrides it to incorporate label weights if they exist.
     * @param info The output info.
     * @return The Parameters to use for training.
     */
    protected Parameter setupParameters(ImmutableOutputInfo<T> info) {
        return libLinearParams;
    }

    /**
     * Converts a Tribuo {@link Example} into a liblinear {@code FeatureNode} array, including a bias feature.
     * <p>
     * If there is a collision between feature ids (i.e., if there is feature hashing or some other mechanism changing
     * the feature ids) then the feature values are summed.
     * @param example The input example.
     * @param featureIDMap The feature id map which contains the example's indices.
     * @param features A buffer. If null then an array list is created and used internally.
     * @param <T> The output type.
     * @return The features suitable for use in liblinear.
     */
    public static <T extends Output<T>> FeatureNode[] exampleToNodes(Example<T> example, ImmutableFeatureMap featureIDMap, List<FeatureNode> features) {
        int biasIndex = featureIDMap.size()+1;

        if (features == null) {
            features = new ArrayList<>();
        }
        features.clear();

        int prevIdx = -1;
        for (Feature f : example) {
            int id = featureIDMap.getID(f.getName());
            if (id > prevIdx) {
                prevIdx = id;
                features.add(new FeatureNode(id + 1, f.getValue()));
            } else if (id > -1) {
                //
                // Collision, deal with it.
                int collisionIdx = Util.binarySearch(features,id+1,FeatureNode::getIndex);
                if (collisionIdx < 0) {
                    //
                    // Collision, but the id is not yet present in features:
                    // insert the new node at its sorted position.
                    collisionIdx = - (collisionIdx + 1);
                    features.add(collisionIdx,new FeatureNode(id + 1, f.getValue()));
                } else {
                    //
                    // Collision, and the id is already present in features:
                    // sum the values.
                    FeatureNode n = features.get(collisionIdx);
                    n.setValue(n.getValue() + f.getValue());
                }
            }
        }

        features.add(new FeatureNode(biasIndex,1.0));

        return features.toArray(new FeatureNode[0]);
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}
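
// A minimal usage sketch, assuming the concrete classification subclass
// org.tribuo.classification.liblinear.LibLinearClassificationTrainer from the
// tribuo-classification-liblinear module; the dataset variable is hypothetical:
//
//   LibLinearClassificationTrainer trainer = new LibLinearClassificationTrainer(
//       new LinearClassificationType(LinearType.L2R_L2LOSS_SVC_DUAL), 1.0, 1000, 0.1);
//   Model<Label> model = trainer.train(trainingDataset); // trainingDataset is a Dataset<Label>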