001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel;
018
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.classification.Classifiable;
import org.tribuo.classification.Label;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
032
033/**
034 * A class for multi-label classification.
035 * <p>
036 * Multi-label classification is where a (possibly empty) set of labels
037 * is predicted for each example. For example, predicting that a Reuters
038 * article has both the Finance and Sports labels.
039 */
040public class MultiLabel implements Classifiable<MultiLabel> {
041    private static final long serialVersionUID = 1L;
042
043    public static final String NEGATIVE_LABEL_STRING = "ML##NEGATIVE";
044    /**
045     * A Label representing the binary negative label. Used in binary
046     * approaches to multi-label classification to represent the absence
047     * of a Label.
048     */
049    public static final Label NEGATIVE_LABEL = new Label(NEGATIVE_LABEL_STRING);
050
051    private final String label;
052    private final double score;
053    private final Set<Label> labels;
054    private final Set<String> labelStrings;
055
056    /**
057     * Builds a MultiLabel object from a Set of Labels.
058     *
059     * Sets the whole set score to {@link Double#NaN}.
060     * @param labels A set of (possibly scored) labels.
061     */
062    public MultiLabel(Set<Label> labels) {
063        this(labels,Double.NaN);
064    }
065
066    /**
067     * Builds a MultiLabel object from a Set of Labels,
068     * when the whole set has a score as well as (optionally)
069     * the individual labels.
070     * @param labels A set of (possibly scored) labels.
071     * @param score An overall score for the set.
072     */
073    public MultiLabel(Set<Label> labels, double score) {
074        this.label = MultiLabelFactory.generateLabelString(labels);
075        this.score = score;
076        this.labels = Collections.unmodifiableSet(new HashSet<>(labels));
077        Set<String> temp = new HashSet<>();
078        for (Label l : labels) {
079            temp.add(l.getLabel());
080        }
081        this.labelStrings = Collections.unmodifiableSet(temp);
082    }
083
084    /**
085     * Builds a MultiLabel with a single String label.
086     *
087     * The created {@link Label} is unscored and used by MultiLabelInfo.
088     *
089     * Sets the whole set score to {@link Double#NaN}.
090     * @param label The label.
091     */
092    public MultiLabel(String label) {
093        this(new Label(label));
094    }
095
096    /**
097     * Builds a MultiLabel from a single Label.
098     *
099     * Sets the whole set score to {@link Double#NaN}.
100     * @param label The label.
101     */
102    public MultiLabel(Label label) {
103        this(Collections.singleton(label));
104    }
105
106    /**
107     * Creates a binary label from this multilabel.
108     * The returned Label is the input parameter if
109     * this MultiLabel contains that Label, and
110     * {@link MultiLabel#NEGATIVE_LABEL} otherwise.
111     * @param otherLabel The input label.
112     * @return A binarised form of this MultiLabel.
113     */
114    public Label createLabel(Label otherLabel) {
115        if (labelStrings.contains(otherLabel.getLabel())) {
116            return otherLabel;
117        } else {
118            return NEGATIVE_LABEL;
119        }
120    }
121
122    /**
123     * Returns a comma separated string representing
124     * the labels in this multilabel instance.
125     * @return A comma separated string of labels.
126     */
127    public String getLabelString() {
128        return label;
129    }
130
131    /**
132     * The overall score for this set of labels.
133     * @return The score for this MultiLabel.
134     */
135    public double getScore() {
136        return score;
137    }
138
139    /**
140     * The set of labels contained in this multilabel.
141     * @return The set of labels.
142     */
143    public Set<Label> getLabelSet() {
144        return new HashSet<>(labels);
145    }
146
147    /**
148     * The set of strings that represent the labels in this multilabel.
149     * @return The set of strings.
150     */
151    public Set<String> getNameSet() {
152        return new HashSet<>(labelStrings);
153    }
154
155    /**
156     * Does this MultiLabel contain this string?
157     * @param input A string representing a {@link Label}.
158     * @return True if the label string is in this MultiLabel.
159     */
160    public boolean contains(String input) {
161        return labelStrings.contains(input);
162    }
163
164    /**
165     * Does this MultiLabel contain this Label?
166     * @param input A {@link Label}.
167     * @return True if the label is in this MultiLabel.
168     */
169    public boolean contains(Label input) {
170        return labels.contains(input);
171    }
172
173    @Override
174    public boolean equals(Object o) {
175        if (this == o) return true;
176        if (o == null || getClass() != o.getClass()) return false;
177
178        MultiLabel that = (MultiLabel) o;
179
180        return labelStrings != null ? labelStrings.equals(that.labelStrings) : that.labelStrings == null;
181    }
182
183    @Override
184    public boolean fullEquals(MultiLabel o) {
185        if (this == o) return true;
186        if (o == null || getClass() != o.getClass()) return false;
187
188        if (Double.compare(score, o.score) != 0) {
189            return false;
190        }
191        Map<String,Double> thisMap = new HashMap<>();
192        for (Label l : labels) {
193            thisMap.put(l.getLabel(),l.getScore());
194        }
195        Map<String,Double> thatMap = new HashMap<>();
196        for (Label l : o.labels) {
197            thatMap.put(l.getLabel(),l.getScore());
198        }
199        if (thisMap.size() == thatMap.size()) {
200            for (Map.Entry<String,Double> e : thisMap.entrySet()) {
201                Double thisValue = e.getValue();
202                Double thatValue = thatMap.get(e.getKey());
203                if ((thatValue == null) || Double.compare(thisValue,thatValue) != 0) {
204                    return false;
205                }
206            }
207            return true;
208        } else {
209            return false;
210        }
211    }
212
213    @Override
214    public int hashCode() {
215        return Objects.hash(labelStrings);
216    }
217
218    @Override
219    public String toString() {
220        StringBuilder builder = new StringBuilder();
221
222        builder.append("(LabelSet={");
223        for (Label l : labels) {
224            builder.append(l.toString());
225            builder.append(',');
226        }
227        builder.deleteCharAt(builder.length()-1);
228        builder.append('}');
229        if (!Double.isNaN(score)) {
230            builder.append(",OverallScore=");
231            builder.append(score);
232        }
233        builder.append(")");
234
235        return builder.toString();
236    }
237
238    @Override
239    public MultiLabel copy() {
240        return new MultiLabel(labels,score);
241    }
242
243    /**
244     * For a MultiLabel with label set = {a, b, c}, outputs a string of the form:
245     * <pre>
246     * "a=true,b=true,c=true"
247     * </pre>
248     * If includeConfidence is set to true, outputs a string of the form:
249     * <pre>
250     * "a=true,b=true,c=true:0.5"
251     * </pre>
252     * where the last element after the colon is this label's score.
253     *
254     * @param includeConfidence Include whatever confidence score the label contains, if known.
255     * @return a comma-separated, densified string representation of this MultiLabel
256     */
257    @Override
258    public String getSerializableForm(boolean includeConfidence) {
259        /*
260         * Note: Due to the sparse implementation of MultiLabel, all 'key=value' pairs will have value=true. That is,
261         * say 'all possible labels' for a dataset are {R1,R2} but this particular example has label set = {R1}. Then
262         * this method will output only "R1=true",whereas one might expect "R1=true,R2=false". Nevertheless, we generate
263         * the 'serializable form' of this MultiLabel in this way to be consistent with that of other multi-output types
264         * such as MultipleRegressor.
265         */
266        String str = labels.stream()
267                .map(label -> String.format("%s=%b", label, true))
268                .collect(Collectors.joining(","));
269        if (includeConfidence) {
270            return str + ":" + score;
271        }
272        return str;
273    }
274
275    /**
276     * Parses a string of the form:
277     * dimension-name=output,...,dimension-name=output
278     * where output must be readable by {@link Boolean#parseBoolean(String)}.
279     * @param s The string form of a multi-label example.
280     * @return A {@link MultiLabel} parsed from the input string.
281     */
282    public static MultiLabel parseString(String s) {
283        return parseString(s,',');
284    }
285
286    /**
287     * Parses a string of the form:
288     * <pre>
289     * dimension-name=output&lt;splitChar&gt;...&lt;splitChar&gt;dimension-name=output
290     * </pre>
291     * where output must be readable by {@link Boolean#parseBoolean}.
292     * @param s The string form of a multilabel output.
293     * @param splitChar The char to split on.
294     * @return A {@link MultiLabel} output parsed from the input string.
295     */
296    public static MultiLabel parseString(String s, char splitChar) {
297        if (splitChar == '=') {
298            throw new IllegalArgumentException("Can't split on an equals symbol");
299        }
300        String[] tokens = s.split(""+splitChar);
301        List<Pair<String,Boolean>> pairs = new ArrayList<>();
302        for (String token : tokens) {
303            pairs.add(parseElement(token));
304        }
305        return createFromPairList(pairs);
306    }
307
308    /**
309     * Parses a string of the form:
310     *
311     * <pre>
312     *     class1=true
313     * </pre>
314     *
315     * OR of the form:
316     *
317     * <pre>
318     *     class1
319     * </pre>
320     *
321     * In the first case, the value in the "key=value" pair must be parseable by {@link Boolean#parseBoolean(String)}.
322     *
323     * TODO: Boolean.parseBoolean("1") returns false. We may want to think more carefully about this case.
324     *
325     * @param s The string form of a single dimension from a multilabel input.
326     * @return A tuple representing the dimension name and the value.
327     */
328    public static Pair<String,Boolean> parseElement(String s) {
329        if (s.isEmpty()) {
330            return new Pair<>("", false);
331        }
332        String[] split = s.split("=");
333        if (split.length == 2) {
334            //
335            // Case: "Class1=TRUE,Class2=FALSE"
336            return new Pair<>(split[0],Boolean.parseBoolean(split[1]));
337        } else if (split.length == 1) {
338            //
339            // Case: "Class1,Class2"
340            return new Pair<>(split[0], true);
341        } else {
342            throw new IllegalArgumentException("Failed to parse element " + s);
343        }
344    }
345
346    /**
347     * Creates a MultiLabel from a list of dimensions.
348     * @param dimensions The dimensions to use.
349     * @return A MultiLabel representing these dimensions.
350     */
351    public static MultiLabel createFromPairList(List<Pair<String,Boolean>> dimensions) {
352        Set<Label> labels = new HashSet<>();
353        for (int i = 0; i < dimensions.size(); i++) {
354            Pair<String,Boolean> p = dimensions.get(i);
355            String name = p.getA();
356            boolean value = p.getB();
357            if (value) {
358                labels.add(new Label(name));
359            }
360        }
361        return new MultiLabel(labels);
362    }
363}