001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification;
018
019import com.oracle.labs.mlrg.olcut.provenance.Provenance;
020import org.tribuo.ImmutableOutputInfo;
021import org.tribuo.MutableOutputInfo;
022import org.tribuo.OutputFactory;
023import org.tribuo.classification.evaluation.LabelEvaluation;
024import org.tribuo.classification.evaluation.LabelEvaluator;
025import org.tribuo.evaluation.Evaluator;
026import org.tribuo.provenance.OutputFactoryProvenance;
027
028import java.util.Map;
029
030/**
031 * A factory for making Label related classes.
032 * <p>
033 * Parses the Label by calling toString on the input.
034 * <p>
035 * Label factories have no state, and are all equal to each other.
036 */
037public final class LabelFactory implements OutputFactory<Label> {
038    private static final long serialVersionUID = 1L;
039
040    /**
041     * The singleton unknown label, used for unlablled examples.
042     */
043    public static final Label UNKNOWN_LABEL = new Label(Label.UNKNOWN);
044
045    private static final OutputFactoryProvenance provenance = new LabelFactoryProvenance();
046
047    private static final LabelEvaluator evaluator = new LabelEvaluator();
048
049    /**
050     * Constructs a label factory.
051     */
052    public LabelFactory() {}
053
054    /**
055     * Generates the Label string by calling toString
056     * on the input.
057     * @param label An input value.
058     * @param <V> The type of the input.
059     * @return A Label object.
060     */
061    @Override
062    public <V> Label generateOutput(V label) {
063        return new Label(label.toString());
064    }
065
066    @Override
067    public Label getUnknownOutput() {
068        return UNKNOWN_LABEL;
069    }
070
071    /**
072     * Generates an empty MutableLabelInfo.
073     * @return An empty MutableLabelInfo.
074     */
075    @Override
076    public MutableOutputInfo<Label> generateInfo() {
077        return new MutableLabelInfo();
078    }
079
080    @Override
081    public ImmutableOutputInfo<Label> constructInfoForExternalModel(Map<Label,Integer> mapping) {
082        // Validate inputs are dense
083        OutputFactory.validateMapping(mapping);
084
085        MutableLabelInfo info = new MutableLabelInfo();
086
087        for (Map.Entry<Label,Integer> e : mapping.entrySet()) {
088            info.observe(e.getKey());
089        }
090
091        return new ImmutableLabelInfo(info,mapping);
092    }
093
094    @Override
095    public Evaluator<Label,LabelEvaluation> getEvaluator() {
096        return evaluator;
097    }
098
099    @Override
100    public int hashCode() {
101        return "LabelFactory".hashCode();
102    }
103
104    @Override
105    public boolean equals(Object obj) {
106        return obj instanceof LabelFactory;
107    }
108
109    @Override
110    public OutputFactoryProvenance getProvenance() {
111        return provenance;
112    }
113
114    /**
115     * Provenance for {@link LabelFactory}.
116     */
117    public final static class LabelFactoryProvenance implements OutputFactoryProvenance {
118        private static final long serialVersionUID = 1L;
119
120        LabelFactoryProvenance() {}
121
122        /**
123         * Constructor used by the provenance serialization system.
124         * <p>
125         * As the label factory has no state, the argument is expected to be empty, and it's contents are ignored.
126         * @param map The provenance map to use.
127         */
128        public LabelFactoryProvenance(Map<String, Provenance> map) { }
129
130        @Override
131        public String getClassName() {
132            return LabelFactory.class.getName();
133        }
134
135        @Override
136        public String toString() {
137            return generateString("OutputFactory");
138        }
139
140        @Override
141        public boolean equals(Object other) {
142            return other instanceof LabelFactoryProvenance;
143        }
144
145        @Override
146        public int hashCode() {
147            return 31;
148        }
149    }
150}