/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.multilabel.baseline;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.multilabel.ImmutableMultiLabelInfo;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Map;

/**
 * Trains n independent binary {@link Model}s, each of which predicts a single {@link Label}.
 * <p>
 * The trained models are then wrapped in an {@link IndependentMultiLabelModel} to provide a
 * {@link MultiLabel} prediction.
 * <p>
 * The per-label models are trained sequentially, one after another; training them in parallel
 * is a possible optimisation that is not implemented here.
 */
public class IndependentMultiLabelTrainer implements Trainer<MultiLabel> {

    @Config(mandatory = true,description="Trainer to use for each individual label.")
    private Trainer<Label> innerTrainer;

    private int trainInvocationCounter = 0;

    /**
     * for olcut.
     */
    private IndependentMultiLabelTrainer() {}

    /**
     * Creates a trainer which fits one binary model per label using the supplied trainer.
     *
     * @param innerTrainer The trainer to use for each individual label.
     */
    public IndependentMultiLabelTrainer(Trainer<Label> innerTrainer) {
        this.innerTrainer = innerTrainer;
    }

    /**
     * Trains one binary model for every label in the dataset's output domain and bundles
     * them into an {@link IndependentMultiLabelModel}.
     *
     * @param examples The multi-label training data.
     * @param runProvenance Provenance supplied by the caller, passed through to the model provenance.
     * @return A model wrapping the per-label binary models.
     * @throws IllegalArgumentException If the dataset's output info reports any unknown outputs.
     */
    @Override
    public Model<MultiLabel> train(Dataset<MultiLabel> examples, Map<String, Provenance> runProvenance) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }

        ImmutableMultiLabelInfo outputIDInfo = (ImmutableMultiLabelInfo) examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        DatasetProvenance dataProvenance = examples.getProvenance();

        ArrayList<Model<Label>> binaryModels = new ArrayList<>();
        ArrayList<Label> labels = new ArrayList<>();

        //TODO supply more suitable provenance showing it's a single dimension out of many.
        // A single dataset instance is reused for every label; it is cleared and refilled each time.
        MutableDataset<Label> binaryData = new MutableDataset<>(dataProvenance, new LabelFactory());

        for (MultiLabel domainElement : outputIDInfo.getDomain()) {
            Label target = new Label(domainElement.getLabelString());
            labels.add(target);
            refillBinaryData(binaryData, examples, target);
            Model<Label> binaryModel = innerTrainer.train(binaryData);
            binaryModels.add(binaryModel);
        }

        // The trainer provenance is captured before the invocation counter is incremented.
        String modelClassName = IndependentMultiLabelModel.class.getName();
        OffsetDateTime trainedAt = OffsetDateTime.now();
        TrainerProvenance trainerProvenance = getProvenance();
        ModelProvenance modelProvenance = new ModelProvenance(modelClassName, trainedAt, dataProvenance, trainerProvenance, runProvenance);
        trainInvocationCounter++;

        return new IndependentMultiLabelModel(labels, binaryModels, modelProvenance, featureIDMap, outputIDInfo);
    }

    /**
     * Empties the binary dataset and repopulates it with one {@link BinaryExample} per
     * multi-label example, labelled with respect to the target label.
     *
     * @param binaryData The dataset to clear and refill.
     * @param examples The multi-label examples to convert.
     * @param target The label the binary problem is built around.
     */
    private static void refillBinaryData(MutableDataset<Label> binaryData, Dataset<MultiLabel> examples, Label target) {
        binaryData.clear();
        for (Example<MultiLabel> example : examples) {
            // This sets the label in the new example to either target or MultiLabel.NEGATIVE_LABEL_STRING.
            Label binaryLabel = example.getOutput().createLabel(target);
            binaryData.add(new BinaryExample(example, binaryLabel));
        }
    }

    /**
     * Returns the value of the counter that is incremented once per completed call to train.
     *
     * @return The train invocation count.
     */
    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        String innerDescription = innerTrainer.toString();
        return "IndependentMultiLabelTrainer(innerTrainer=" + innerDescription + ")";
    }

    /**
     * Builds a fresh {@link TrainerProvenanceImpl} describing this trainer.
     *
     * @return The trainer provenance.
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}