/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.ensemble;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableDataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.Trainer;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.dataset.DatasetView;
import org.tribuo.ensemble.WeightedEnsembleModel;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Implements AdaBoost.SAMME, one of the more popular algorithms for multiclass boosting.
 * Based on <a href="https://web.stanford.edu/~hastie/Papers/samme.pdf">this paper</a>.
 * <p>
 * If the inner trainer implements {@link WeightedExamples} then it performs boosting by weighting,
 * otherwise it uses a weighted bootstrap sample.
 * <p>
 * See:
 * <pre>
 * J. Zhu, S. Rosset, H. Zou, T. Hastie.
 * "Multi-class AdaBoost"
 * 2006.
 * </pre>
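 * <p>
 * A minimal usage sketch ({@code trainingData} is assumed to be a {@code Dataset<Label>}
 * built elsewhere; the depth-1 CART weak learner is purely illustrative, and any
 * {@code Trainer<Label>} may be substituted):
 * <pre>{@code
 * Trainer<Label> stump = new CARTClassificationTrainer(1); // decision stumps
 * AdaBoostTrainer booster = new AdaBoostTrainer(stump, 10); // 10 boosting rounds
 * Model<Label> ensemble = booster.train(trainingData);
 * }</pre>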
 */
public class AdaBoostTrainer implements Trainer<Label> {

    private static final Logger logger = Logger.getLogger(AdaBoostTrainer.class.getName());

    @Config(mandatory=true, description="The trainer to use to build each weak learner.")
    protected Trainer<Label> innerTrainer;

    @Config(mandatory=true, description="The number of ensemble members to train.")
    protected int numMembers;

    @Config(mandatory=true, description="The seed for the RNG.")
    protected long seed;

    protected SplittableRandom rng;

    protected int trainInvocationCounter;

    /**
     * For the OLCUT configuration system.
     */
    private AdaBoostTrainer() { }

    /**
     * Constructs an AdaBoost trainer using the supplied weak learner trainer and the specified number of
     * boosting rounds. Uses the default seed.
     * @param trainer The weak learner trainer.
     * @param numMembers The maximum number of boosting rounds.
     */
    public AdaBoostTrainer(Trainer<Label> trainer, int numMembers) {
        this(trainer, numMembers, Trainer.DEFAULT_SEED);
    }

    /**
     * Constructs an AdaBoost trainer using the supplied weak learner trainer, the specified number of
     * boosting rounds and the supplied seed.
     * @param trainer The weak learner trainer.
     * @param numMembers The maximum number of boosting rounds.
     * @param seed The RNG seed.
     */
    public AdaBoostTrainer(Trainer<Label> trainer, int numMembers, long seed) {
        this.innerTrainer = trainer;
        this.numMembers = numMembers;
        this.seed = seed;
        postConfig();
    }

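    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     */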
    @Override
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(seed);
    }

    @Override
    public String toString() {
        StringBuilder buffer = new StringBuilder();

        buffer.append("AdaBoostTrainer(");
        buffer.append("innerTrainer=");
        buffer.append(innerTrainer.toString());
        buffer.append(",numMembers=");
        buffer.append(numMembers);
        buffer.append(",seed=");
        buffer.append(seed);
        buffer.append(")");

        return buffer.toString();
    }

    /**
     * If the inner trainer implements {@link WeightedExamples} then boosting is performed by
     * reweighting the examples, otherwise it is performed by weighted bootstrap sampling.
     * @param examples The dataset containing the examples.
     * @param runProvenance Provenance information recorded for this training run.
     * @return A {@link WeightedEnsembleModel}.
     */
    @Override
    public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Creates a new RNG, adds one to the invocation count.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized(this) {
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        boolean weighted = innerTrainer instanceof WeightedExamples;
        ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
        ImmutableOutputInfo<Label> labelIDs = examples.getOutputIDInfo();
        int numClasses = labelIDs.size();
        logger.log(Level.INFO,"NumClasses = " + numClasses);
        ArrayList<Model<Label>> models = new ArrayList<>();
        float[] modelWeights = new float[numMembers];
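        // Initialise the example weight distribution uniformly to 1/n (SAMME step 1).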
        float[] exampleWeights = Util.generateUniformFloatVector(examples.size(), 1.0f/examples.size());
        if (weighted) {
            logger.info("Using weighted Adaboost.");
            examples = ImmutableDataset.copyDataset(examples);
            for (int i = 0; i < examples.size(); i++) {
                Example<Label> e = examples.getExample(i);
                e.setWeight(exampleWeights[i]);
            }
        } else {
            logger.info("Using sampling Adaboost.");
        }
        for (int i = 0; i < numMembers; i++) {
            logger.info("Building model " + i);
            Model<Label> newModel;
            if (weighted) {
                newModel = innerTrainer.train(examples);
            } else {
                DatasetView<Label> bag = DatasetView.createWeightedBootstrapView(examples,examples.size(),localRNG.nextLong(),exampleWeights,featureIDs,labelIDs);
                newModel = innerTrainer.train(bag);
            }

            //
            // Score this model
            List<Prediction<Label>> predictions = newModel.predict(examples);
            float accuracy = accuracy(predictions,examples,exampleWeights);
            float error = 1.0f - accuracy;
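            // SAMME model weight: alpha = ln((1 - err) / err) + ln(numClasses - 1).
            // The extra ln(numClasses - 1) term (absent in binary AdaBoost) keeps alpha
            // positive whenever the weak learner beats random guessing (accuracy > 1/numClasses).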
            float alpha = (float) (Math.log(accuracy/error) + Math.log(numClasses - 1));
            models.add(newModel);
            modelWeights[i] = alpha;
            if ((accuracy + 1e-10) > 1.0) {
                //
                // Perfect accuracy, can no longer boost.
                float[] newModelWeights = Arrays.copyOf(modelWeights, models.size());
                newModelWeights[models.size()-1] = 1.0f; //Set the last weight to 1, as the computed alpha is infinite.
                logger.log(Level.FINE, "Perfect accuracy reached on iteration " + i + ", returning current model.");
                logger.log(Level.FINE, "Model weights:");
                Util.logVector(logger, Level.FINE, newModelWeights);
                EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
                return new WeightedEnsembleModel<>("boosted-ensemble",provenance,featureIDs,labelIDs,models,new VotingCombiner(),newModelWeights);
            }

            //
            // Update the weights
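            // Upweight each misclassified example by exp(alpha), then renormalize so
            // the example weights again form a distribution (SAMME steps 2(c) and 2(d)).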
            for (int j = 0; j < predictions.size(); j++) {
                if (!predictions.get(j).getOutput().equals(examples.getExample(j).getOutput())) {
                    exampleWeights[j] *= Math.exp(alpha);
                }
            }
            Util.inplaceNormalizeToDistribution(exampleWeights);
            if (weighted) {
                for (int j = 0; j < examples.size(); j++) {
                    examples.getExample(j).setWeight(exampleWeights[j]);
                }
            }
        }
        logger.log(Level.FINE, "Model weights:");
        Util.logVector(logger, Level.FINE, modelWeights);
        EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
        return new WeightedEnsembleModel<>("boosted-ensemble",provenance,featureIDs,labelIDs,models,new VotingCombiner(),modelWeights);
    }

    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

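    /**
     * Computes the weighted accuracy: the fraction of the total example weight
     * assigned to correctly predicted examples.
     * @param predictions The model predictions, in the same order as the examples.
     * @param examples The ground truth examples.
     * @param weights The per-example weights.
     * @return The weight-normalized accuracy, in the range [0,1].
     */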
    private float accuracy(List<Prediction<Label>> predictions, Dataset<Label> examples, float[] weights) {
        float correctSum = 0;
        float total = 0;
        for (int i = 0; i < predictions.size(); i++) {
            if (predictions.get(i).getOutput().equals(examples.getExample(i).getOutput())) {
                correctSum += weights[i];
            }
            total += weights[i];
        }

        logger.log(Level.FINEST, "Weighted correct sum = " + correctSum + ", size = " + examples.size());

        return correctSum / total;
    }

    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}