001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import com.oracle.labs.mlrg.olcut.util.MutableNumber;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.MutableOutputInfo;
024import org.tribuo.OutputInfo;
025import org.tribuo.classification.Label;
026
027import java.io.IOException;
028import java.util.HashMap;
029import java.util.HashSet;
030import java.util.Iterator;
031import java.util.Map;
032import java.util.Set;
033
034/**
035 * The base class for information about {@link MultiLabel} outputs.
036 */
037public abstract class MultiLabelInfo implements OutputInfo<MultiLabel> {
038    private static final long serialVersionUID = 1L;
039
040    protected final Map<String,MutableLong> labelCounts;
041    protected int unknownCount = 0;
042    protected transient Map<String,MultiLabel> labels;
043
044    protected int totalCount = 0;
045
046    /**
047     * Construct a MultiLabelInfo, initializing the various count variables.
048     */
049    MultiLabelInfo() {
050        labelCounts = new HashMap<>();
051        labels = new HashMap<>();
052    }
053
054    /**
055     * Copy the MultiLabelInfo. The copy ignores the unknown count.
056     * @param other The MultiLabelInfo to copy.
057     */
058    MultiLabelInfo(MultiLabelInfo other) {
059        labelCounts = MutableNumber.copyMap(other.labelCounts);
060        labels = new HashMap<>(other.labels);
061        totalCount = other.totalCount;
062    }
063
064    @Override
065    public int getUnknownCount() {
066        return unknownCount;
067    }
068
069    /**
070     * Returns a set of MultiLabel, where each has a single Label inside it.
071     * The set covers the space of Labels that this MultiLabelInfo has seen.
072     * @return The set of possible labels.
073     */
074    @Override
075    public Set<MultiLabel> getDomain() {
076        return new HashSet<>(labels.values());
077    }
078
079    /**
080     * Get the number of times this Label was observed, or 0 if unknown.
081     * @param label The Label to look for.
082     * @return A non-negative long.
083     */
084    public long getLabelCount(Label label) {
085        MutableLong l = labelCounts.get(label.getLabel());
086        if (l != null) {
087            return l.longValue();
088        } else {
089            return 0;
090        }
091    }
092
093    /**
094     * Get the number of times this String was observed, or 0 if unknown.
095     * @param label The String to look for.
096     * @return A non-negative long.
097     */
098    public long getLabelCount(String label) {
099        MutableLong l = labelCounts.get(label);
100        if (l != null) {
101            return l.longValue();
102        } else {
103            return 0;
104        }
105    }
106
107    @Override
108    public Iterable<Pair<String,Long>> outputCountsIterable() {
109        return () -> new Iterator<Pair<String, Long>>() {
110            Iterator<Map.Entry<String, MutableLong>> itr = labelCounts.entrySet().iterator();
111
112            @Override
113            public boolean hasNext() {
114                return itr.hasNext();
115            }
116
117            @Override
118            public Pair<String, Long> next() {
119                Map.Entry<String, MutableLong> e = itr.next();
120                return new Pair<>(e.getKey(), e.getValue().longValue());
121            }
122        };
123    }
124
125    @Override
126    public int size() {
127        return labelCounts.size();
128    }
129
130    @Override
131    public ImmutableOutputInfo<MultiLabel> generateImmutableOutputInfo() {
132        return new ImmutableMultiLabelInfo(this);
133    }
134
135    @Override
136    public MutableOutputInfo<MultiLabel> generateMutableOutputInfo() {
137        return new MutableMultiLabelInfo(this);
138    }
139
140    @Override
141    public abstract MultiLabelInfo copy();
142
143    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
144        in.defaultReadObject();
145        labels = new HashMap<>();
146        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
147            labels.put(e.getKey(),new MultiLabel(e.getKey()));
148        }
149    }
150}