001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import com.oracle.labs.mlrg.olcut.util.MutableNumber;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.MutableOutputInfo;
024import org.tribuo.OutputInfo;
025
026import java.io.IOException;
027import java.util.HashMap;
028import java.util.HashSet;
029import java.util.Iterator;
030import java.util.Map;
031import java.util.Set;
032
033/**
034 * The base class for information about multi-class classification Labels.
035 */
036public abstract class LabelInfo implements OutputInfo<Label> {
037    private static final long serialVersionUID = 1L;
038
039    /**
040     * The occurrence counts of each label.
041     */
042    protected final Map<String,MutableLong> labelCounts;
043    /**
044     * The number of unknown labels this LabelInfo has seen.
045     */
046    protected int unknownCount = 0;
047    /**
048     * The label domain.
049     */
050    protected transient Map<String,Label> labels;
051
052    /**
053     * Constructs an empty label info.
054     */
055    LabelInfo() {
056        labelCounts = new HashMap<>();
057        labels = new HashMap<>();
058    }
059
060    /**
061     * Copies the label info apart from the unknown count.
062     * @param other The label info to copy.
063     */
064    LabelInfo(LabelInfo other) {
065        labelCounts = MutableNumber.copyMap(other.labelCounts);
066        labels = new HashMap<>();
067        labels.putAll(other.labels);
068    }
069
070    @Override
071    public int getUnknownCount() {
072        return unknownCount;
073    }
074
075    /**
076     * Returns the set of possible {@link Label}s that this LabelInfo has seen.
077     * <p>
078     * Each label has the default score of Double.NaN.
079     * @return The set of possible labels.
080     */
081    @Override
082    public Set<Label> getDomain() {
083        return new HashSet<>(labels.values());
084    }
085
086    /**
087     * Gets the count of the supplied label, or 0 if the label is unknown.
088     * @param label A Label.
089     * @return A non-negative long.
090     */
091    public long getLabelCount(Label label) {
092        MutableLong l = labelCounts.get(label.getLabel());
093        if (l != null) {
094            return l.longValue();
095        } else {
096            return 0;
097        }
098    }
099
100    /**
101     * Gets the count of the supplied label, or 0 if the label is unknown.
102     * @param label A String representing a Label.
103     * @return A non-negative long.
104     */
105    public long getLabelCount(String label) {
106        MutableLong l = labelCounts.get(label);
107        if (l != null) {
108            return l.longValue();
109        } else {
110            return 0;
111        }
112    }
113
114    @Override
115    public Iterable<Pair<String,Long>> outputCountsIterable() {
116        return () -> new Iterator<Pair<String,Long>>() {
117            Iterator<Map.Entry<String,MutableLong>> itr = labelCounts.entrySet().iterator();
118
119            @Override
120            public boolean hasNext() {
121                return itr.hasNext();
122            }
123
124            @Override
125            public Pair<String,Long> next() {
126                Map.Entry<String,MutableLong> e = itr.next();
127                return new Pair<>(e.getKey(),e.getValue().longValue());
128            }
129        };
130    }
131
132    /**
133     * The number of unique {@link Label}s this LabelInfo has seen.
134     * @return The number of unique labels.
135     */
136    @Override
137    public int size() {
138        return labelCounts.size();
139    }
140
141    @Override
142    public ImmutableOutputInfo<Label> generateImmutableOutputInfo() {
143        return new ImmutableLabelInfo(this);
144    }
145
146    @Override
147    public MutableOutputInfo<Label> generateMutableOutputInfo() {
148        return new MutableLabelInfo(this);
149    }
150
151    @Override
152    public abstract LabelInfo copy();
153
154    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
155        in.defaultReadObject();
156        labels = new HashMap<>();
157        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
158            labels.put(e.getKey(),new Label(e.getKey()));
159        }
160    }
161}