001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel;
018
019import com.oracle.labs.mlrg.olcut.provenance.Provenance;
020import com.oracle.labs.mlrg.olcut.util.Pair;
021import org.tribuo.ImmutableOutputInfo;
022import org.tribuo.MutableOutputInfo;
023import org.tribuo.OutputFactory;
024import org.tribuo.classification.Label;
025import org.tribuo.classification.LabelFactory;
026import org.tribuo.evaluation.Evaluator;
027import org.tribuo.multilabel.evaluation.MultiLabelEvaluation;
028import org.tribuo.multilabel.evaluation.MultiLabelEvaluator;
029import org.tribuo.provenance.OutputFactoryProvenance;
030
031import java.util.ArrayList;
032import java.util.Collection;
033import java.util.List;
034import java.util.Map;
035import java.util.Set;
036
037/**
038 * A factory for generating MultiLabel objects and their associated OutputInfo and Evaluator objects.
039 */
040public final class MultiLabelFactory implements OutputFactory<MultiLabel> {
041    private static final long serialVersionUID = 1L;
042
043    public static final MultiLabel UNKNOWN_MULTILABEL = new MultiLabel(LabelFactory.UNKNOWN_LABEL);
044
045    private static final MultiLabelFactoryProvenance provenance = new MultiLabelFactoryProvenance();
046
047    private static final MultiLabelEvaluator evaluator = new MultiLabelEvaluator();
048
049    /**
050     * Construct a MultiLabelFactory.
051     */
052    public MultiLabelFactory() {}
053
054    /**
055     * Parses the MultiLabel value either by toStringing the input and calling {@link MultiLabel#parseString}
056     * or if it's a {@link Collection} iterating over the elements calling toString on each element in turn and using
057     * {@link MultiLabel#parseElement}.
058     * @param label An input value.
059     * @param <V> The type of the input value.
060     * @return A MultiLabel
061     */
062    @Override
063    public <V> MultiLabel generateOutput(V label) {
064        if (label instanceof Collection) {
065            Collection<?> c = (Collection<?>) label;
066            List<Pair<String,Boolean>> dimensions = new ArrayList<>();
067            for (Object o : c) {
068                dimensions.add(MultiLabel.parseElement(o.toString()));
069            }
070            return MultiLabel.createFromPairList(dimensions);
071        }
072        return MultiLabel.parseString(label.toString());
073    }
074
075    @Override
076    public MultiLabel getUnknownOutput() {
077        return UNKNOWN_MULTILABEL;
078    }
079
080    @Override
081    public MutableOutputInfo<MultiLabel> generateInfo() {
082        return new MutableMultiLabelInfo();
083    }
084
085    @Override
086    public ImmutableOutputInfo<MultiLabel> constructInfoForExternalModel(Map<MultiLabel,Integer> mapping) {
087        // Validate inputs are dense
088        OutputFactory.validateMapping(mapping);
089
090        MutableMultiLabelInfo info = new MutableMultiLabelInfo();
091
092        for (Map.Entry<MultiLabel,Integer> e : mapping.entrySet()) {
093            info.observe(e.getKey());
094        }
095
096        return new ImmutableMultiLabelInfo(info,mapping);
097    }
098
099    @Override
100    public Evaluator<MultiLabel, MultiLabelEvaluation> getEvaluator() {
101        return evaluator;
102    }
103
104    @Override
105    public int hashCode() {
106        return "MultiLabelFactory".hashCode();
107    }
108
109    @Override
110    public boolean equals(Object obj) {
111        return obj instanceof MultiLabelFactory;
112    }
113
114    @Override
115    public OutputFactoryProvenance getProvenance() {
116        return provenance;
117    }
118
119    /**
120     * Generates a comma separated string of labels from a {@link Set} of {@link Label}.
121     * @param input A Set of Label objects.
122     * @return A (possibly empty) comma separated string.
123     */
124    public static String generateLabelString(Set<Label> input) {
125        if (input.isEmpty()) {
126            return "";
127        }
128        List<String> list = new ArrayList<>();
129        for (Label l : input) {
130            list.add(l.getLabel());
131        }
132        list.sort(String::compareTo);
133
134        StringBuilder builder = new StringBuilder();
135        for (String s : list) {
136            if (s.contains(",")) {
137                throw new IllegalStateException("MultiLabel cannot contain a label with a ',', found " + s + ".");
138            }
139            builder.append(s);
140            builder.append(',');
141        }
142        builder.deleteCharAt(builder.length()-1);
143        return builder.toString();
144    }
145
146    /**
147     * Provenance for {@link MultiLabelFactory}.
148     */
149    public final static class MultiLabelFactoryProvenance implements OutputFactoryProvenance {
150        private static final long serialVersionUID = 1L;
151
152        /**
153         * Constructs a multi-label factory provenance.
154         */
155        MultiLabelFactoryProvenance() {}
156
157        /**
158         * Constructs a multi-label factory provenance from the empty marshalled form.
159         * @param map An empty map.
160         */
161        public MultiLabelFactoryProvenance(Map<String, Provenance> map) { }
162
163        @Override
164        public String getClassName() {
165            return MultiLabelFactory.class.getName();
166        }
167
168        @Override
169        public String toString() {
170            return generateString("OutputFactory");
171        }
172
173        @Override
174        public boolean equals(Object other) {
175            return other instanceof MultiLabelFactoryProvenance;
176        }
177
178        @Override
179        public int hashCode() {
180            return 31;
181        }
182    }
183}