001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel.baseline;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Prediction;
026import org.tribuo.classification.Label;
027import org.tribuo.multilabel.MultiLabel;
028import org.tribuo.provenance.ModelProvenance;
029
030import java.util.ArrayList;
031import java.util.HashMap;
032import java.util.HashSet;
033import java.util.List;
034import java.util.Map;
035import java.util.Optional;
036import java.util.Set;
037
038/**
039 * A {@link Model} which wraps n binary models, where n is the
040 * size of the MultiLabel domain. Each model independently predicts
041 * a single binary label.
042 * <p>
043 * It is possible for the prediction to produce an empty MultiLabel
044 * when none of the binary Labels were predicted.
045 */
046public class IndependentMultiLabelModel extends Model<MultiLabel> {
047    private static final long serialVersionUID = 1L;
048
049    private final List<Model<Label>> models;
050    private final List<Label> labels;
051
052    /**
053     * The list of Label and list of Model must be in the same order, and have a bijection.
054     * @param labels The list of labels this model was trained on.
055     * @param models The list of individual binary models.
056     * @param description A description of the trainer.
057     * @param featureMap The feature domain used in training.
058     * @param labelInfo The label domain used in training.
059     */
060    IndependentMultiLabelModel(List<Label> labels, List<Model<Label>> models, ModelProvenance description, ImmutableFeatureMap featureMap, ImmutableOutputInfo<MultiLabel> labelInfo) {
061        super(null,description,featureMap,labelInfo,models.get(0).generatesProbabilities());
062        this.labels = labels;
063        this.models = models;
064    }
065
066    @Override
067    public Prediction<MultiLabel> predict(Example<MultiLabel> example) {
068        Set<Label> predictedLabels = new HashSet<>();
069        BinaryExample e = new BinaryExample(example,null);
070        int numUsed = 0;
071        for (Model<Label> m : models) {
072            Prediction<Label> p = m.predict(e);
073            if (numUsed < p.getNumActiveFeatures()) {
074                numUsed = p.getNumActiveFeatures();
075            }
076            if (!p.getOutput().getLabel().equals(MultiLabel.NEGATIVE_LABEL_STRING)) {
077                predictedLabels.add(p.getOutput());
078            }
079        }
080        return new Prediction<>(new MultiLabel(predictedLabels),numUsed,example);
081    }
082
083    /**
084     * This aggregates the top features from each of the models.
085     * <p>
086     * If the individual models support per label features, then only the features
087     * for the positive label are aggregated.
088     * <p>
089     * @param n the number of features to return. If this value is less than 0,
090     * all features should be returned for each class.
091     * @return The top n features.
092     */
093    @Override
094    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
095        Map<String,List<Pair<String,Double>>> map = new HashMap<>();
096        for (int i = 0; i < models.size(); i++) {
097            Model<Label> m = models.get(i);
098            String label = labels.get(i).getLabel();
099            Map<String,List<Pair<String,Double>>> modelMap = m.getTopFeatures(n);
100            if (modelMap != null) {
101                if (modelMap.size() == 1) {
102                    map.put(label,modelMap.get(Model.ALL_OUTPUTS));
103                } else {
104                    map.merge(label,modelMap.get(label),(List<Pair<String,Double>> l, List<Pair<String,Double>> r) -> {l.addAll(r); return l;});
105                }
106            }
107        }
108        return map;
109    }
110
111    @Override
112    public Optional<Excuse<MultiLabel>> getExcuse(Example<MultiLabel> example) {
113        //TODO implement this to return the per label excuses.
114        return Optional.empty();
115    }
116
117    @Override
118    protected IndependentMultiLabelModel copy(String newName, ModelProvenance newProvenance) {
119        List<Model<Label>> newModels = new ArrayList<>();
120        for (Model<Label> e : models) {
121            newModels.add(e.copy());
122        }
123        return new IndependentMultiLabelModel(labels,newModels,newProvenance,featureIDMap,outputIDInfo);
124    }
125}