001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.ensemble;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
021import com.oracle.labs.mlrg.olcut.provenance.Provenance;
022import org.tribuo.Dataset;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.Model;
026import org.tribuo.Output;
027import org.tribuo.Trainer;
028import org.tribuo.dataset.DatasetView;
029import org.tribuo.provenance.EnsembleModelProvenance;
030import org.tribuo.provenance.TrainerProvenance;
031import org.tribuo.provenance.impl.TrainerProvenanceImpl;
032
033import java.time.OffsetDateTime;
034import java.util.ArrayList;
035import java.util.Map;
036import java.util.SplittableRandom;
037import java.util.logging.Logger;
038
039/**
040 * A Trainer that wraps another trainer and produces a bagged ensemble.
041 * <p>
 * A bagged ensemble is a set of models, each of which was trained on a bootstrap sample of the
 * original dataset; their outputs are aggregated by the supplied {@link EnsembleCombiner}
 * (for example, an unweighted majority vote).
044 * <p>
045 * See:
046 * <pre>
047 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
048 * "The Elements of Statistical Learning"
049 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
050 * </pre>
051 */
052public class BaggingTrainer<T extends Output<T>> implements Trainer<T> {
053    
054    private static final Logger logger = Logger.getLogger(BaggingTrainer.class.getName());
055
056    @Config(mandatory=true, description="The trainer to use for each ensemble member.")
057    protected Trainer<T> innerTrainer;
058
059    @Config(mandatory=true, description="The number of ensemble members to train.")
060    protected int numMembers;
061
062    @Config(mandatory=true, description="The seed for the RNG.")
063    protected long seed;
064
065    @Config(mandatory=true, description="The combination function to aggregate each ensemble member's outputs.")
066    protected EnsembleCombiner<T> combiner;
067
068    protected SplittableRandom rng;
069
070    protected int trainInvocationCounter;
071
072    /**
073     * For the configuration system.
074     */
075    protected BaggingTrainer() { }
076
077    public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers) {
078        this(trainer, combiner, numMembers, Trainer.DEFAULT_SEED);
079    }
080
081    public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed) {
082        this.innerTrainer = trainer;
083        this.combiner = combiner;
084        this.numMembers = numMembers;
085        this.seed = seed;
086        postConfig();
087    }
088
089    @Override
090    public synchronized void postConfig() {
091        this.rng = new SplittableRandom(seed);
092    }
093
094    protected String ensembleName() {
095        return "bagging-ensemble";
096    }
097
098    @Override
099    public String toString() {
100        StringBuilder buffer = new StringBuilder();
101
102        buffer.append("BaggingTrainer(");
103        buffer.append("innerTrainer=");
104        buffer.append(innerTrainer.toString());
105        buffer.append(",combiner=");
106        buffer.append(combiner.toString());
107        buffer.append(",numMembers=");
108        buffer.append(numMembers);
109        buffer.append(",seed=");
110        buffer.append(seed);
111        buffer.append(")");
112
113        return buffer.toString();
114    }
115    
116    @Override
117    public Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
118        // Creates a new RNG, adds one to the invocation count.
119        SplittableRandom localRNG;
120        TrainerProvenance trainerProvenance;
121        synchronized(this) {
122            localRNG = rng.split();
123            trainerProvenance = getProvenance();
124            trainInvocationCounter++;
125        }
126        ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
127        ImmutableOutputInfo<T> labelIDs = examples.getOutputIDInfo();
128        ArrayList<Model<T>> models = new ArrayList<>();
129        for (int i = 0; i < numMembers; i++) {
130            logger.info("Building model " + i);
131            models.add(trainSingleModel(examples,featureIDs,labelIDs,localRNG,runProvenance));
132        }
133        EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
134        return new WeightedEnsembleModel<>(ensembleName(),provenance,featureIDs,labelIDs,models,combiner);
135    }
136
137    protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String,Provenance> runProvenance) {
138        DatasetView<T> bag = DatasetView.createBootstrapView(examples,examples.size(),localRNG.nextInt(),featureIDs,labelIDs);
139        Model<T> newModel = innerTrainer.train(bag,runProvenance);
140        return newModel;
141    }
142
143    @Override
144    public int getInvocationCount() {
145        return trainInvocationCounter;
146    }
147
148    @Override
149    public TrainerProvenance getProvenance() {
150        return new TrainerProvenanceImpl(this);
151    }
152}