001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.ensemble; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.ListProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.Provenance; 022import org.tribuo.Dataset; 023import org.tribuo.ImmutableFeatureMap; 024import org.tribuo.ImmutableOutputInfo; 025import org.tribuo.Model; 026import org.tribuo.Output; 027import org.tribuo.Trainer; 028import org.tribuo.dataset.DatasetView; 029import org.tribuo.provenance.EnsembleModelProvenance; 030import org.tribuo.provenance.TrainerProvenance; 031import org.tribuo.provenance.impl.TrainerProvenanceImpl; 032 033import java.time.OffsetDateTime; 034import java.util.ArrayList; 035import java.util.Map; 036import java.util.SplittableRandom; 037import java.util.logging.Logger; 038 039/** 040 * A Trainer that wraps another trainer and produces a bagged ensemble. 041 * <p> 042 * A bagged ensemble is a set of models each of which was trained on a bootstrap sample of the 043 * original dataset, combined with an unweighted majority vote. 044 * <p> 045 * See: 046 * <pre> 047 * J. Friedman, T. Hastie, & R. Tibshirani. 048 * "The Elements of Statistical Learning" 049 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 050 * </pre> 051 */ 052public class BaggingTrainer<T extends Output<T>> implements Trainer<T> { 053 054 private static final Logger logger = Logger.getLogger(BaggingTrainer.class.getName()); 055 056 @Config(mandatory=true, description="The trainer to use for each ensemble member.") 057 protected Trainer<T> innerTrainer; 058 059 @Config(mandatory=true, description="The number of ensemble members to train.") 060 protected int numMembers; 061 062 @Config(mandatory=true, description="The seed for the RNG.") 063 protected long seed; 064 065 @Config(mandatory=true, description="The combination function to aggregate each ensemble member's outputs.") 066 protected EnsembleCombiner<T> combiner; 067 068 protected SplittableRandom rng; 069 070 protected int trainInvocationCounter; 071 072 /** 073 * For the configuration system. 074 */ 075 protected BaggingTrainer() { } 076 077 public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers) { 078 this(trainer, combiner, numMembers, Trainer.DEFAULT_SEED); 079 } 080 081 public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> combiner, int numMembers, long seed) { 082 this.innerTrainer = trainer; 083 this.combiner = combiner; 084 this.numMembers = numMembers; 085 this.seed = seed; 086 postConfig(); 087 } 088 089 @Override 090 public synchronized void postConfig() { 091 this.rng = new SplittableRandom(seed); 092 } 093 094 protected String ensembleName() { 095 return "bagging-ensemble"; 096 } 097 098 @Override 099 public String toString() { 100 StringBuilder buffer = new StringBuilder(); 101 102 buffer.append("BaggingTrainer("); 103 buffer.append("innerTrainer="); 104 buffer.append(innerTrainer.toString()); 105 buffer.append(",combiner="); 106 buffer.append(combiner.toString()); 107 buffer.append(",numMembers="); 108 buffer.append(numMembers); 109 buffer.append(",seed="); 110 buffer.append(seed); 111 buffer.append(")"); 112 113 return buffer.toString(); 114 } 115 116 @Override 117 public Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) { 118 // Creates a new RNG, adds one to the invocation count. 119 SplittableRandom localRNG; 120 TrainerProvenance trainerProvenance; 121 synchronized(this) { 122 localRNG = rng.split(); 123 trainerProvenance = getProvenance(); 124 trainInvocationCounter++; 125 } 126 ImmutableFeatureMap featureIDs = examples.getFeatureIDMap(); 127 ImmutableOutputInfo<T> labelIDs = examples.getOutputIDInfo(); 128 ArrayList<Model<T>> models = new ArrayList<>(); 129 for (int i = 0; i < numMembers; i++) { 130 logger.info("Building model " + i); 131 models.add(trainSingleModel(examples,featureIDs,labelIDs,localRNG,runProvenance)); 132 } 133 EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models)); 134 return new WeightedEnsembleModel<>(ensembleName(),provenance,featureIDs,labelIDs,models,combiner); 135 } 136 137 protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, SplittableRandom localRNG, Map<String,Provenance> runProvenance) { 138 DatasetView<T> bag = DatasetView.createBootstrapView(examples,examples.size(),localRNG.nextInt(),featureIDs,labelIDs); 139 Model<T> newModel = innerTrainer.train(bag,runProvenance); 140 return newModel; 141 } 142 143 @Override 144 public int getInvocationCount() { 145 return trainInvocationCounter; 146 } 147 148 @Override 149 public TrainerProvenance getProvenance() { 150 return new TrainerProvenanceImpl(this); 151 } 152}