001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.multilabel.baseline; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.Excuse; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Model; 025import org.tribuo.Prediction; 026import org.tribuo.classification.Label; 027import org.tribuo.multilabel.MultiLabel; 028import org.tribuo.provenance.ModelProvenance; 029 030import java.util.ArrayList; 031import java.util.HashMap; 032import java.util.HashSet; 033import java.util.List; 034import java.util.Map; 035import java.util.Optional; 036import java.util.Set; 037 038/** 039 * A {@link Model} which wraps n binary models, where n is the 040 * size of the MultiLabel domain. Each model independently predicts 041 * a single binary label. 042 * <p> 043 * It is possible for the prediction to produce an empty MultiLabel 044 * when none of the binary Labels were predicted. 045 */ 046public class IndependentMultiLabelModel extends Model<MultiLabel> { 047 private static final long serialVersionUID = 1L; 048 049 private final List<Model<Label>> models; 050 private final List<Label> labels; 051 052 /** 053 * The list of Label and list of Model must be in the same order, and have a bijection. 054 * @param labels The list of labels this model was trained on. 055 * @param models The list of individual binary models. 056 * @param description A description of the trainer. 057 * @param featureMap The feature domain used in training. 058 * @param labelInfo The label domain used in training. 059 */ 060 IndependentMultiLabelModel(List<Label> labels, List<Model<Label>> models, ModelProvenance description, ImmutableFeatureMap featureMap, ImmutableOutputInfo<MultiLabel> labelInfo) { 061 super(null,description,featureMap,labelInfo,models.get(0).generatesProbabilities()); 062 this.labels = labels; 063 this.models = models; 064 } 065 066 @Override 067 public Prediction<MultiLabel> predict(Example<MultiLabel> example) { 068 Set<Label> predictedLabels = new HashSet<>(); 069 BinaryExample e = new BinaryExample(example,null); 070 int numUsed = 0; 071 for (Model<Label> m : models) { 072 Prediction<Label> p = m.predict(e); 073 if (numUsed < p.getNumActiveFeatures()) { 074 numUsed = p.getNumActiveFeatures(); 075 } 076 if (!p.getOutput().getLabel().equals(MultiLabel.NEGATIVE_LABEL_STRING)) { 077 predictedLabels.add(p.getOutput()); 078 } 079 } 080 return new Prediction<>(new MultiLabel(predictedLabels),numUsed,example); 081 } 082 083 /** 084 * This aggregates the top features from each of the models. 085 * <p> 086 * If the individual models support per label features, then only the features 087 * for the positive label are aggregated. 088 * <p> 089 * @param n the number of features to return. If this value is less than 0, 090 * all features should be returned for each class. 091 * @return The top n features. 092 */ 093 @Override 094 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 095 Map<String,List<Pair<String,Double>>> map = new HashMap<>(); 096 for (int i = 0; i < models.size(); i++) { 097 Model<Label> m = models.get(i); 098 String label = labels.get(i).getLabel(); 099 Map<String,List<Pair<String,Double>>> modelMap = m.getTopFeatures(n); 100 if (modelMap != null) { 101 if (modelMap.size() == 1) { 102 map.put(label,modelMap.get(Model.ALL_OUTPUTS)); 103 } else { 104 map.merge(label,modelMap.get(label),(List<Pair<String,Double>> l, List<Pair<String,Double>> r) -> {l.addAll(r); return l;}); 105 } 106 } 107 } 108 return map; 109 } 110 111 @Override 112 public Optional<Excuse<MultiLabel>> getExcuse(Example<MultiLabel> example) { 113 //TODO implement this to return the per label excuses. 114 return Optional.empty(); 115 } 116 117 @Override 118 protected IndependentMultiLabelModel copy(String newName, ModelProvenance newProvenance) { 119 List<Model<Label>> newModels = new ArrayList<>(); 120 for (Model<Label> e : models) { 121 newModels.add(e.copy()); 122 } 123 return new IndependentMultiLabelModel(labels,newModels,newProvenance,featureIDMap,outputIDInfo); 124 } 125}