001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel.baseline;
018
019import com.oracle.labs.mlrg.olcut.config.Config;
020import com.oracle.labs.mlrg.olcut.provenance.Provenance;
021import org.tribuo.Dataset;
022import org.tribuo.Example;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.Model;
025import org.tribuo.MutableDataset;
026import org.tribuo.Trainer;
027import org.tribuo.classification.Label;
028import org.tribuo.classification.LabelFactory;
029import org.tribuo.multilabel.ImmutableMultiLabelInfo;
030import org.tribuo.multilabel.MultiLabel;
031import org.tribuo.provenance.DatasetProvenance;
032import org.tribuo.provenance.ModelProvenance;
033import org.tribuo.provenance.TrainerProvenance;
034import org.tribuo.provenance.impl.TrainerProvenanceImpl;
035
036import java.time.OffsetDateTime;
037import java.util.ArrayList;
038import java.util.Map;
039
040/**
041 * Trains n independent binary {@link Model}s, each of which predicts a single {@link Label}.
042 * <p>
043 * Then wraps it up in an {@link IndependentMultiLabelModel} to provide a {@link MultiLabel}
044 * prediction.
045 * <p>
046 * It trains each model sequentially, and could be optimised to train in parallel.
047 */
048public class IndependentMultiLabelTrainer implements Trainer<MultiLabel> {
049
050    @Config(mandatory = true,description="Trainer to use for each individual label.")
051    private Trainer<Label> innerTrainer;
052
053    private int trainInvocationCounter = 0;
054
055    /**
056     * for olcut.
057     */
058    private IndependentMultiLabelTrainer() {}
059
060    public IndependentMultiLabelTrainer(Trainer<Label> innerTrainer) {
061        this.innerTrainer = innerTrainer;
062    }
063
064    @Override
065    public Model<MultiLabel> train(Dataset<MultiLabel> examples, Map<String, Provenance> runProvenance) {
066        if (examples.getOutputInfo().getUnknownCount() > 0) {
067            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
068        }
069        ImmutableMultiLabelInfo labelInfo = (ImmutableMultiLabelInfo) examples.getOutputIDInfo();
070        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
071        ArrayList<Model<Label>> modelsList = new ArrayList<>();
072        ArrayList<Label> labelList = new ArrayList<>();
073        DatasetProvenance datasetProvenance = examples.getProvenance();
074        //TODO supply more suitable provenance showing it's a single dimension out of many.
075        MutableDataset<Label> trainingData = new MutableDataset<>(datasetProvenance, new LabelFactory());
076        for (MultiLabel l : labelInfo.getDomain()) {
077            Label label = new Label(l.getLabelString());
078            trainingData.clear();
079            labelList.add(label);
080            for (Example<MultiLabel> e : examples) {
081                Label newLabel = e.getOutput().createLabel(label);
082                // This sets the label in the new example to either l or MultiLabel.NEGATIVE_LABEL_STRING.
083                trainingData.add(new BinaryExample(e,newLabel));
084            }
085            modelsList.add(innerTrainer.train(trainingData));
086        }
087        ModelProvenance provenance = new ModelProvenance(IndependentMultiLabelModel.class.getName(), OffsetDateTime.now(), datasetProvenance, getProvenance(), runProvenance);
088        trainInvocationCounter++;
089        return new IndependentMultiLabelModel(labelList,modelsList,provenance,featureMap,labelInfo);
090    }
091
092    @Override
093    public int getInvocationCount() {
094        return trainInvocationCounter;
095    }
096
097    @Override
098    public String toString() {
099        return "IndependentMultiLabelTrainer(innerTrainer="+innerTrainer.toString()+")";
100    }
101
102    @Override
103    public TrainerProvenance getProvenance() {
104        return new TrainerProvenanceImpl(this);
105    }
106}
107