001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.ensemble; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.ListProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import org.tribuo.Dataset; 023import org.tribuo.Example; 024import org.tribuo.ImmutableDataset; 025import org.tribuo.ImmutableFeatureMap; 026import org.tribuo.ImmutableOutputInfo; 027import org.tribuo.Model; 028import org.tribuo.Prediction; 029import org.tribuo.Trainer; 030import org.tribuo.WeightedExamples; 031import org.tribuo.classification.Label; 032import org.tribuo.dataset.DatasetView; 033import org.tribuo.ensemble.WeightedEnsembleModel; 034import org.tribuo.provenance.EnsembleModelProvenance; 035import org.tribuo.provenance.TrainerProvenance; 036import org.tribuo.provenance.impl.TrainerProvenanceImpl; 037import org.tribuo.util.Util; 038 039import java.time.OffsetDateTime; 040import java.util.ArrayList; 041import java.util.Arrays; 042import java.util.List; 043import java.util.Map; 044import java.util.SplittableRandom; 045import java.util.logging.Level; 046import java.util.logging.Logger; 047 048/** 049 * Implements Adaboost.SAMME one of the more popular algorithms for multiclass boosting. 050 * Based on <a href="https://web.stanford.edu/~hastie/Papers/samme.pdf">this paper</a>. 051 * <p> 052 * If the trainer implements {@link WeightedExamples} then it performs boosting by weighting, 053 * otherwise it uses a weighted bootstrap sample. 054 * <p> 055 * See: 056 * <pre> 057 * J. Zhu, S. Rosset, H. Zou, T. Hastie. 058 * "Multi-class Adaboost" 059 * 2006. 060 * </pre> 061 */ 062public class AdaBoostTrainer implements Trainer<Label> { 063 064 private static final Logger logger = Logger.getLogger(AdaBoostTrainer.class.getName()); 065 066 @Config(mandatory=true, description="The trainer to use to build each weak learner.") 067 protected Trainer<Label> innerTrainer; 068 069 @Config(mandatory=true, description="The number of ensemble members to train.") 070 protected int numMembers; 071 072 @Config(mandatory=true, description="The seed for the RNG.") 073 protected long seed; 074 075 protected SplittableRandom rng; 076 077 protected int trainInvocationCounter; 078 079 /** 080 * For the OLCUT configuration system. 081 */ 082 private AdaBoostTrainer() { } 083 084 /** 085 * Constructs an adaboost trainer using the supplied weak learner trainer and the specified number of 086 * boosting rounds. Uses the default seed. 087 * @param trainer The weak learner trainer. 088 * @param numMembers The maximum number of boosting rounds. 089 */ 090 public AdaBoostTrainer(Trainer<Label> trainer, int numMembers) { 091 this(trainer, numMembers, Trainer.DEFAULT_SEED); 092 } 093 094 /** 095 * Constructs an adaboost trainer using the supplied weak learner trainer, the specified number of 096 * boosting rounds and the supplied seed. 097 * @param trainer The weak learner trainer. 098 * @param numMembers The maximum number of boosting rounds. 099 * @param seed The RNG seed. 100 */ 101 public AdaBoostTrainer(Trainer<Label> trainer, int numMembers, long seed) { 102 this.innerTrainer = trainer; 103 this.numMembers = numMembers; 104 this.seed = seed; 105 postConfig(); 106 } 107 108 @Override 109 public synchronized void postConfig() { 110 this.rng = new SplittableRandom(seed); 111 } 112 113 @Override 114 public String toString() { 115 StringBuilder buffer = new StringBuilder(); 116 117 buffer.append("AdaBoostTrainer("); 118 buffer.append("innerTrainer="); 119 buffer.append(innerTrainer.toString()); 120 buffer.append(",numMembers="); 121 buffer.append(numMembers); 122 buffer.append(",seed="); 123 buffer.append(seed); 124 buffer.append(")"); 125 126 return buffer.toString(); 127 } 128 129 /** 130 * If the trainer implements {@link WeightedExamples} then do boosting by weighting, 131 * otherwise do boosting by sampling. 132 * @param examples the data set containing the examples. 133 * @return A {@link WeightedEnsembleModel}. 134 */ 135 @Override 136 public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) { 137 if (examples.getOutputInfo().getUnknownCount() > 0) { 138 throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised."); 139 } 140 // Creates a new RNG, adds one to the invocation count. 141 SplittableRandom localRNG; 142 TrainerProvenance trainerProvenance; 143 synchronized(this) { 144 localRNG = rng.split(); 145 trainerProvenance = getProvenance(); 146 trainInvocationCounter++; 147 } 148 boolean weighted = innerTrainer instanceof WeightedExamples; 149 ImmutableFeatureMap featureIDs = examples.getFeatureIDMap(); 150 ImmutableOutputInfo<Label> labelIDs = examples.getOutputIDInfo(); 151 int numClasses = labelIDs.size(); 152 logger.log(Level.INFO,"NumClasses = " + numClasses); 153 ArrayList<Model<Label>> models = new ArrayList<>(); 154 float[] modelWeights = new float[numMembers]; 155 float[] exampleWeights = Util.generateUniformFloatVector(examples.size(), 1.0f/examples.size()); 156 if (weighted) { 157 logger.info("Using weighted Adaboost."); 158 examples = ImmutableDataset.copyDataset(examples); 159 for (int i = 0; i < examples.size(); i++) { 160 Example<Label> e = examples.getExample(i); 161 e.setWeight(exampleWeights[i]); 162 } 163 } else { 164 logger.info("Using sampling Adaboost."); 165 } 166 for (int i = 0; i < numMembers; i++) { 167 logger.info("Building model " + i); 168 Model<Label> newModel; 169 if (weighted) { 170 newModel = innerTrainer.train(examples); 171 } else { 172 DatasetView<Label> bag = DatasetView.createWeightedBootstrapView(examples,examples.size(),localRNG.nextLong(),exampleWeights,featureIDs,labelIDs); 173 newModel = innerTrainer.train(bag); 174 } 175 176 // 177 // Score this model 178 List<Prediction<Label>> predictions = newModel.predict(examples); 179 float accuracy = accuracy(predictions,examples,exampleWeights); 180 float error = 1.0f - accuracy; 181 float alpha = (float) (Math.log(accuracy/error) + Math.log(numClasses - 1)); 182 models.add(newModel); 183 modelWeights[i] = alpha; 184 if ((accuracy + 1e-10) > 1.0) { 185 // 186 // Perfect accuracy, can no longer boost. 187 float[] newModelWeights = Arrays.copyOf(modelWeights, models.size()); 188 newModelWeights[models.size()-1] = 1.0f; //Set the last weight to 1, as it's infinity. 189 logger.log(Level.FINE, "Perfect accuracy reached on iteration " + i + ", returning current model."); 190 logger.log(Level.FINE, "Model weights:"); 191 Util.logVector(logger, Level.FINE, newModelWeights); 192 EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models)); 193 return new WeightedEnsembleModel<>("boosted-ensemble",provenance,featureIDs,labelIDs,models,new VotingCombiner(),newModelWeights); 194 } 195 196 // 197 // Update the weights 198 for (int j = 0; j < predictions.size(); j++) { 199 if (!predictions.get(j).getOutput().equals(examples.getExample(j).getOutput())) { 200 exampleWeights[j] *= Math.exp(alpha); 201 } 202 } 203 Util.inplaceNormalizeToDistribution(exampleWeights); 204 if (weighted) { 205 for (int j = 0; j < examples.size(); j++) { 206 examples.getExample(j).setWeight(exampleWeights[j]); 207 } 208 } 209 } 210 logger.log(Level.FINE, "Model weights:"); 211 Util.logVector(logger, Level.FINE, modelWeights); 212 EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models)); 213 return new WeightedEnsembleModel<>("boosted-ensemble",provenance,featureIDs,labelIDs,models,new VotingCombiner(),modelWeights); 214 } 215 216 @Override 217 public int getInvocationCount() { 218 return trainInvocationCounter; 219 } 220 221 private float accuracy(List<Prediction<Label>> predictions, Dataset<Label> examples, float[] weights) { 222 float correctSum = 0; 223 float total = 0; 224 for (int i = 0; i < predictions.size(); i++) { 225 if (predictions.get(i).getOutput().equals(examples.getExample(i).getOutput())) { 226 correctSum += weights[i]; 227 } 228 total += weights[i]; 229 } 230 231 logger.log(Level.FINEST, "Correct count = " + correctSum + " size = " + examples.size()); 232 233 return correctSum / total; 234 } 235 236 @Override 237 public TrainerProvenance getProvenance() { 238 return new TrainerProvenanceImpl(this); 239 } 240}