001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.multilabel; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.classification.Classifiable; 021import org.tribuo.classification.Label; 022 023import java.util.ArrayList; 024import java.util.Collections; 025import java.util.HashMap; 026import java.util.HashSet; 027import java.util.List; 028import java.util.Map; 029import java.util.Objects; 030import java.util.Set; 031import java.util.stream.Collectors; 032 033/** 034 * A class for multi-label classification. 035 * <p> 036 * Multi-label classification is where a (possibly empty) set of labels 037 * is predicted for each example. For example, predicting that a Reuters 038 * article has both the Finance and Sports labels. 039 */ 040public class MultiLabel implements Classifiable<MultiLabel> { 041 private static final long serialVersionUID = 1L; 042 043 public static final String NEGATIVE_LABEL_STRING = "ML##NEGATIVE"; 044 /** 045 * A Label representing the binary negative label. Used in binary 046 * approaches to multi-label classification to represent the absence 047 * of a Label. 048 */ 049 public static final Label NEGATIVE_LABEL = new Label(NEGATIVE_LABEL_STRING); 050 051 private final String label; 052 private final double score; 053 private final Set<Label> labels; 054 private final Set<String> labelStrings; 055 056 /** 057 * Builds a MultiLabel object from a Set of Labels. 058 * 059 * Sets the whole set score to {@link Double#NaN}. 060 * @param labels A set of (possibly scored) labels. 061 */ 062 public MultiLabel(Set<Label> labels) { 063 this(labels,Double.NaN); 064 } 065 066 /** 067 * Builds a MultiLabel object from a Set of Labels, 068 * when the whole set has a score as well as (optionally) 069 * the individual labels. 070 * @param labels A set of (possibly scored) labels. 071 * @param score An overall score for the set. 072 */ 073 public MultiLabel(Set<Label> labels, double score) { 074 this.label = MultiLabelFactory.generateLabelString(labels); 075 this.score = score; 076 this.labels = Collections.unmodifiableSet(new HashSet<>(labels)); 077 Set<String> temp = new HashSet<>(); 078 for (Label l : labels) { 079 temp.add(l.getLabel()); 080 } 081 this.labelStrings = Collections.unmodifiableSet(temp); 082 } 083 084 /** 085 * Builds a MultiLabel with a single String label. 086 * 087 * The created {@link Label} is unscored and used by MultiLabelInfo. 088 * 089 * Sets the whole set score to {@link Double#NaN}. 090 * @param label The label. 091 */ 092 public MultiLabel(String label) { 093 this(new Label(label)); 094 } 095 096 /** 097 * Builds a MultiLabel from a single Label. 098 * 099 * Sets the whole set score to {@link Double#NaN}. 100 * @param label The label. 101 */ 102 public MultiLabel(Label label) { 103 this(Collections.singleton(label)); 104 } 105 106 /** 107 * Creates a binary label from this multilabel. 108 * The returned Label is the input parameter if 109 * this MultiLabel contains that Label, and 110 * {@link MultiLabel#NEGATIVE_LABEL} otherwise. 111 * @param otherLabel The input label. 112 * @return A binarised form of this MultiLabel. 113 */ 114 public Label createLabel(Label otherLabel) { 115 if (labelStrings.contains(otherLabel.getLabel())) { 116 return otherLabel; 117 } else { 118 return NEGATIVE_LABEL; 119 } 120 } 121 122 /** 123 * Returns a comma separated string representing 124 * the labels in this multilabel instance. 125 * @return A comma separated string of labels. 126 */ 127 public String getLabelString() { 128 return label; 129 } 130 131 /** 132 * The overall score for this set of labels. 133 * @return The score for this MultiLabel. 134 */ 135 public double getScore() { 136 return score; 137 } 138 139 /** 140 * The set of labels contained in this multilabel. 141 * @return The set of labels. 142 */ 143 public Set<Label> getLabelSet() { 144 return new HashSet<>(labels); 145 } 146 147 /** 148 * The set of strings that represent the labels in this multilabel. 149 * @return The set of strings. 150 */ 151 public Set<String> getNameSet() { 152 return new HashSet<>(labelStrings); 153 } 154 155 /** 156 * Does this MultiLabel contain this string? 157 * @param input A string representing a {@link Label}. 158 * @return True if the label string is in this MultiLabel. 159 */ 160 public boolean contains(String input) { 161 return labelStrings.contains(input); 162 } 163 164 /** 165 * Does this MultiLabel contain this Label? 166 * @param input A {@link Label}. 167 * @return True if the label is in this MultiLabel. 168 */ 169 public boolean contains(Label input) { 170 return labels.contains(input); 171 } 172 173 @Override 174 public boolean equals(Object o) { 175 if (this == o) return true; 176 if (o == null || getClass() != o.getClass()) return false; 177 178 MultiLabel that = (MultiLabel) o; 179 180 return labelStrings != null ? labelStrings.equals(that.labelStrings) : that.labelStrings == null; 181 } 182 183 @Override 184 public boolean fullEquals(MultiLabel o) { 185 if (this == o) return true; 186 if (o == null || getClass() != o.getClass()) return false; 187 188 if (Double.compare(score, o.score) != 0) { 189 return false; 190 } 191 Map<String,Double> thisMap = new HashMap<>(); 192 for (Label l : labels) { 193 thisMap.put(l.getLabel(),l.getScore()); 194 } 195 Map<String,Double> thatMap = new HashMap<>(); 196 for (Label l : o.labels) { 197 thatMap.put(l.getLabel(),l.getScore()); 198 } 199 if (thisMap.size() == thatMap.size()) { 200 for (Map.Entry<String,Double> e : thisMap.entrySet()) { 201 Double thisValue = e.getValue(); 202 Double thatValue = thatMap.get(e.getKey()); 203 if ((thatValue == null) || Double.compare(thisValue,thatValue) != 0) { 204 return false; 205 } 206 } 207 return true; 208 } else { 209 return false; 210 } 211 } 212 213 @Override 214 public int hashCode() { 215 return Objects.hash(labelStrings); 216 } 217 218 @Override 219 public String toString() { 220 StringBuilder builder = new StringBuilder(); 221 222 builder.append("(LabelSet={"); 223 for (Label l : labels) { 224 builder.append(l.toString()); 225 builder.append(','); 226 } 227 builder.deleteCharAt(builder.length()-1); 228 builder.append('}'); 229 if (!Double.isNaN(score)) { 230 builder.append(",OverallScore="); 231 builder.append(score); 232 } 233 builder.append(")"); 234 235 return builder.toString(); 236 } 237 238 @Override 239 public MultiLabel copy() { 240 return new MultiLabel(labels,score); 241 } 242 243 /** 244 * For a MultiLabel with label set = {a, b, c}, outputs a string of the form: 245 * <pre> 246 * "a=true,b=true,c=true" 247 * </pre> 248 * If includeConfidence is set to true, outputs a string of the form: 249 * <pre> 250 * "a=true,b=true,c=true:0.5" 251 * </pre> 252 * where the last element after the colon is this label's score. 253 * 254 * @param includeConfidence Include whatever confidence score the label contains, if known. 255 * @return a comma-separated, densified string representation of this MultiLabel 256 */ 257 @Override 258 public String getSerializableForm(boolean includeConfidence) { 259 /* 260 * Note: Due to the sparse implementation of MultiLabel, all 'key=value' pairs will have value=true. That is, 261 * say 'all possible labels' for a dataset are {R1,R2} but this particular example has label set = {R1}. Then 262 * this method will output only "R1=true",whereas one might expect "R1=true,R2=false". Nevertheless, we generate 263 * the 'serializable form' of this MultiLabel in this way to be consistent with that of other multi-output types 264 * such as MultipleRegressor. 265 */ 266 String str = labels.stream() 267 .map(label -> String.format("%s=%b", label, true)) 268 .collect(Collectors.joining(",")); 269 if (includeConfidence) { 270 return str + ":" + score; 271 } 272 return str; 273 } 274 275 /** 276 * Parses a string of the form: 277 * dimension-name=output,...,dimension-name=output 278 * where output must be readable by {@link Boolean#parseBoolean(String)}. 279 * @param s The string form of a multi-label example. 280 * @return A {@link MultiLabel} parsed from the input string. 281 */ 282 public static MultiLabel parseString(String s) { 283 return parseString(s,','); 284 } 285 286 /** 287 * Parses a string of the form: 288 * <pre> 289 * dimension-name=output<splitChar>...<splitChar>dimension-name=output 290 * </pre> 291 * where output must be readable by {@link Boolean#parseBoolean}. 292 * @param s The string form of a multilabel output. 293 * @param splitChar The char to split on. 294 * @return A {@link MultiLabel} output parsed from the input string. 295 */ 296 public static MultiLabel parseString(String s, char splitChar) { 297 if (splitChar == '=') { 298 throw new IllegalArgumentException("Can't split on an equals symbol"); 299 } 300 String[] tokens = s.split(""+splitChar); 301 List<Pair<String,Boolean>> pairs = new ArrayList<>(); 302 for (String token : tokens) { 303 pairs.add(parseElement(token)); 304 } 305 return createFromPairList(pairs); 306 } 307 308 /** 309 * Parses a string of the form: 310 * 311 * <pre> 312 * class1=true 313 * </pre> 314 * 315 * OR of the form: 316 * 317 * <pre> 318 * class1 319 * </pre> 320 * 321 * In the first case, the value in the "key=value" pair must be parseable by {@link Boolean#parseBoolean(String)}. 322 * 323 * TODO: Boolean.parseBoolean("1") returns false. We may want to think more carefully about this case. 324 * 325 * @param s The string form of a single dimension from a multilabel input. 326 * @return A tuple representing the dimension name and the value. 327 */ 328 public static Pair<String,Boolean> parseElement(String s) { 329 if (s.isEmpty()) { 330 return new Pair<>("", false); 331 } 332 String[] split = s.split("="); 333 if (split.length == 2) { 334 // 335 // Case: "Class1=TRUE,Class2=FALSE" 336 return new Pair<>(split[0],Boolean.parseBoolean(split[1])); 337 } else if (split.length == 1) { 338 // 339 // Case: "Class1,Class2" 340 return new Pair<>(split[0], true); 341 } else { 342 throw new IllegalArgumentException("Failed to parse element " + s); 343 } 344 } 345 346 /** 347 * Creates a MultiLabel from a list of dimensions. 348 * @param dimensions The dimensions to use. 349 * @return A MultiLabel representing these dimensions. 350 */ 351 public static MultiLabel createFromPairList(List<Pair<String,Boolean>> dimensions) { 352 Set<Label> labels = new HashSet<>(); 353 for (int i = 0; i < dimensions.size(); i++) { 354 Pair<String,Boolean> p = dimensions.get(i); 355 String name = p.getA(); 356 boolean value = p.getB(); 357 if (value) { 358 labels.add(new Label(name)); 359 } 360 } 361 return new MultiLabel(labels); 362 } 363}