001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import com.oracle.labs.mlrg.olcut.util.Pair;
021import org.tribuo.ImmutableOutputInfo;
022
023import java.io.IOException;
024import java.util.Collections;
025import java.util.HashMap;
026import java.util.HashSet;
027import java.util.Iterator;
028import java.util.Map;
029import java.util.Set;
030import java.util.logging.Level;
031import java.util.logging.Logger;
032
033/**
034 * An ImmutableOutputInfo for working with multi-label tasks.
035 */
036public class ImmutableMultiLabelInfo extends MultiLabelInfo implements ImmutableOutputInfo<MultiLabel> {
037    private static final Logger logger = Logger.getLogger(ImmutableMultiLabelInfo.class.getName());
038
039    private static final long serialVersionUID = 1L;
040
041    private final Map<Integer,String> idLabelMap;
042
043    private final Map<String,Integer> labelIDMap;
044
045    private transient Set<MultiLabel> domain;
046
047    private ImmutableMultiLabelInfo(ImmutableMultiLabelInfo info) {
048        super(info);
049        idLabelMap = new HashMap<>();
050        idLabelMap.putAll(info.idLabelMap);
051        labelIDMap = new HashMap<>();
052        labelIDMap.putAll(info.labelIDMap);
053
054        domain = Collections.unmodifiableSet(new HashSet<>(labels.values()));
055    }
056
057    ImmutableMultiLabelInfo(MultiLabelInfo info) {
058        super(info);
059        idLabelMap = new HashMap<>();
060        labelIDMap = new HashMap<>();
061        int counter = 0;
062        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
063            idLabelMap.put(counter,e.getKey());
064            labelIDMap.put(e.getKey(),counter);
065            counter++;
066        }
067
068        domain = Collections.unmodifiableSet(new HashSet<>(labels.values()));
069    }
070
071    ImmutableMultiLabelInfo(MutableMultiLabelInfo info, Map<MultiLabel, Integer> mapping) {
072        super(info);
073        if (mapping.size() != info.size()) {
074            throw new IllegalStateException("Mapping and info come from different sources, mapping.size() = " + mapping.size() + ", info.size() = " + info.size());
075        }
076
077        idLabelMap = new HashMap<>();
078        labelIDMap = new HashMap<>();
079        for (Map.Entry<MultiLabel,Integer> e : mapping.entrySet()) {
080            MultiLabel ml = e.getKey();
081            Set<String> names = ml.getNameSet();
082            if (names.size() == 1) {
083                String name = names.iterator().next();
084                idLabelMap.put(e.getValue(), name);
085                labelIDMap.put(name, e.getValue());
086            } else {
087                throw new IllegalArgumentException("Mapping must contain a single label per id, but contains " + names + " -> " + e.getValue());
088            }
089        }
090
091        domain = Collections.unmodifiableSet(new HashSet<>(labels.values()));
092    }
093
094    @Override
095    public Set<MultiLabel> getDomain() {
096        return domain;
097    }
098
099    @Override
100    public int getID(MultiLabel output) {
101        return labelIDMap.getOrDefault(output.getLabelString(), -1);
102    }
103
104    @Override
105    public MultiLabel getOutput(int id) {
106        String label = idLabelMap.get(id);
107        if (label != null) {
108            return labels.get(label);
109        } else {
110            logger.log(Level.INFO, "No entry found for id " + id);
111            return null;
112        }
113    }
114
115    @Override
116    public long getTotalObservations() {
117        return totalCount;
118    }
119
120    /**
121     * Gets the count of the label occurrence for the specified id number, or 0 if it's unknown.
122     * @param id The label id.
123     * @return The label count.
124     */
125    public long getLabelCount(int id) {
126        String label = idLabelMap.get(id);
127        if (label != null) {
128            MutableLong l = labelCounts.get(label);
129            return l.longValue();
130        } else {
131            return 0;
132        }
133    }
134
135    @Override
136    public ImmutableMultiLabelInfo copy() {
137        return new ImmutableMultiLabelInfo(this);
138    }
139
140    @Override
141    public String toReadableString() {
142        StringBuilder builder = new StringBuilder();
143        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
144            if (builder.length() > 0) {
145                builder.append(", ");
146            }
147            builder.append('(');
148            builder.append(labelIDMap.get(e.getKey()));
149            builder.append(',');
150            builder.append(e.getKey());
151            builder.append(',');
152            builder.append(e.getValue().longValue());
153            builder.append(')');
154        }
155        return builder.toString();
156    }
157
158    @Override
159    public Iterator<Pair<Integer, MultiLabel>> iterator() {
160        return new ImmutableInfoIterator(idLabelMap);
161    }
162
163    private static class ImmutableInfoIterator implements Iterator<Pair<Integer,MultiLabel>> {
164
165        private final Iterator<Map.Entry<Integer,String>> itr;
166
167        public ImmutableInfoIterator(Map<Integer,String> idLabelMap) {
168            itr = idLabelMap.entrySet().iterator();
169        }
170
171        @Override
172        public boolean hasNext() {
173            return itr.hasNext();
174        }
175
176        @Override
177        public Pair<Integer, MultiLabel> next() {
178            Map.Entry<Integer,String> e = itr.next();
179            return new Pair<>(e.getKey(),new MultiLabel(e.getValue()));
180        }
181    }
182
183    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
184        in.defaultReadObject();
185        domain = Collections.unmodifiableSet(new HashSet<>(labels.values()));
186    }
187}