001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import com.oracle.labs.mlrg.olcut.util.Pair;
021import org.tribuo.ImmutableOutputInfo;
022
023import java.io.IOException;
024import java.util.Collections;
025import java.util.HashMap;
026import java.util.HashSet;
027import java.util.Iterator;
028import java.util.Map;
029import java.util.Set;
030import java.util.logging.Level;
031import java.util.logging.Logger;
032
033/**
034 * An ImmutableOutputInfo object for Labels.
035 * <p>
036 * Gives each unique label an id number. Also counts each label occurrence like {@link MutableLabelInfo} does,
037 * though the counts are frozen in this object.
038 */
039public class ImmutableLabelInfo extends LabelInfo implements ImmutableOutputInfo<Label> {
040    private static final Logger logger = Logger.getLogger(ImmutableLabelInfo.class.getName());
041
042    private static final long serialVersionUID = 1L;
043
044    private final Map<Integer,String> idLabelMap;
045
046    private final Map<String,Integer> labelIDMap;
047
048    private transient Set<Label> domain;
049
050    private ImmutableLabelInfo(ImmutableLabelInfo info) {
051        super(info);
052        idLabelMap = new HashMap<>();
053        idLabelMap.putAll(info.idLabelMap);
054        labelIDMap = new HashMap<>();
055        labelIDMap.putAll(info.labelIDMap);
056        domain = Collections.unmodifiableSet(new HashSet<>(labels.values()));
057    }
058
059    ImmutableLabelInfo(LabelInfo info) {
060        super(info);
061        idLabelMap = new HashMap<>();
062        labelIDMap = new HashMap<>();
063        int counter = 0;
064        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
065            idLabelMap.put(counter,e.getKey());
066            labelIDMap.put(e.getKey(),counter);
067            counter++;
068        }
069        domain = Collections.unmodifiableSet(new HashSet<>(labels.values()));
070    }
071
072    ImmutableLabelInfo(LabelInfo info, Map<Label,Integer> mapping) {
073        super(info);
074        if (mapping.size() != info.size()) {
075            throw new IllegalStateException("Mapping and info come from different sources, mapping.size() = " + mapping.size() + ", info.size() = " + info.size());
076        }
077
078        idLabelMap = new HashMap<>();
079        labelIDMap = new HashMap<>();
080        for (Map.Entry<Label,Integer> e : mapping.entrySet()) {
081            idLabelMap.put(e.getValue(),e.getKey().label);
082            labelIDMap.put(e.getKey().label,e.getValue());
083        }
084        domain = Collections.unmodifiableSet(new HashSet<>(labels.values()));
085    }
086
087    /**
088     * Returns the set of possible {@link Label}s that this LabelInfo has seen.
089     *
090     * Each label has the default score of Double.NaN.
091     * @return The set of possible labels.
092     */
093    @Override
094    public Set<Label> getDomain() {
095        return domain;
096    }
097
098    @Override
099    public int getID(Label output) {
100        return labelIDMap.getOrDefault(output.getLabel(),-1);
101    }
102
103    @Override
104    public Label getOutput(int id) {
105        String label = idLabelMap.get(id);
106        if (label != null) {
107            return labels.get(label);
108        } else {
109            logger.log(Level.INFO,"No entry found for id " + id);
110            return null;
111        }
112    }
113
114    @Override
115    public long getTotalObservations() {
116        long count = 0;
117        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
118            count += e.getValue().longValue();
119        }
120        return count;
121    }
122
123    /**
124     * Returns the number of times the supplied id was observed before this LabelInfo was frozen.
125     * @param id The id number.
126     * @return The count.
127     */
128    public long getLabelCount(int id) {
129        String label = idLabelMap.get(id);
130        if (label != null) {
131            MutableLong l = labelCounts.get(label);
132            return l.longValue();
133        } else {
134            return 0;
135        }
136    }
137
138    @Override
139    public ImmutableLabelInfo copy() {
140        return new ImmutableLabelInfo(this);
141    }
142
143    @Override
144    public String toReadableString() {
145        StringBuilder builder = new StringBuilder();
146        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
147            if (builder.length() > 0) {
148                builder.append(", ");
149            }
150            builder.append('(');
151            builder.append(labelIDMap.get(e.getKey()));
152            builder.append(',');
153            builder.append(e.getKey());
154            builder.append(',');
155            builder.append(e.getValue().longValue());
156            builder.append(')');
157        }
158        return builder.toString();
159    }
160
161    @Override
162    public Iterator<Pair<Integer, Label>> iterator() {
163        return new ImmutableInfoIterator(idLabelMap);
164    }
165
166    /**
167     * An iterator that converts {@link Map.Entry} into {@link Pair}s on the way out.
168     */
169    private static class ImmutableInfoIterator implements Iterator<Pair<Integer,Label>> {
170
171        private final Iterator<Map.Entry<Integer,String>> itr;
172
173        public ImmutableInfoIterator(Map<Integer,String> idLabelMap) {
174            itr = idLabelMap.entrySet().iterator();
175        }
176
177        @Override
178        public boolean hasNext() {
179            return itr.hasNext();
180        }
181
182        @Override
183        public Pair<Integer, Label> next() {
184            Map.Entry<Integer,String> e = itr.next();
185            return new Pair<>(e.getKey(),new Label(e.getValue()));
186        }
187    }
188
189    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
190        in.defaultReadObject();
191
192        domain = Collections.unmodifiableSet(new HashSet<>(labels.values()));
193    }
194}