001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020import com.oracle.labs.mlrg.olcut.util.Pair; 021import org.tribuo.ImmutableOutputInfo; 022 023import java.io.IOException; 024import java.util.Collections; 025import java.util.HashMap; 026import java.util.HashSet; 027import java.util.Iterator; 028import java.util.Map; 029import java.util.Set; 030import java.util.logging.Level; 031import java.util.logging.Logger; 032 033/** 034 * An ImmutableOutputInfo object for Labels. 035 * <p> 036 * Gives each unique label an id number. Also counts each label occurrence like {@link MutableLabelInfo} does, 037 * though the counts are frozen in this object. 038 */ 039public class ImmutableLabelInfo extends LabelInfo implements ImmutableOutputInfo<Label> { 040 private static final Logger logger = Logger.getLogger(ImmutableLabelInfo.class.getName()); 041 042 private static final long serialVersionUID = 1L; 043 044 private final Map<Integer,String> idLabelMap; 045 046 private final Map<String,Integer> labelIDMap; 047 048 private transient Set<Label> domain; 049 050 private ImmutableLabelInfo(ImmutableLabelInfo info) { 051 super(info); 052 idLabelMap = new HashMap<>(); 053 idLabelMap.putAll(info.idLabelMap); 054 labelIDMap = new HashMap<>(); 055 labelIDMap.putAll(info.labelIDMap); 056 domain = Collections.unmodifiableSet(new HashSet<>(labels.values())); 057 } 058 059 ImmutableLabelInfo(LabelInfo info) { 060 super(info); 061 idLabelMap = new HashMap<>(); 062 labelIDMap = new HashMap<>(); 063 int counter = 0; 064 for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) { 065 idLabelMap.put(counter,e.getKey()); 066 labelIDMap.put(e.getKey(),counter); 067 counter++; 068 } 069 domain = Collections.unmodifiableSet(new HashSet<>(labels.values())); 070 } 071 072 ImmutableLabelInfo(LabelInfo info, Map<Label,Integer> mapping) { 073 super(info); 074 if (mapping.size() != info.size()) { 075 throw new IllegalStateException("Mapping and info come from different sources, mapping.size() = " + mapping.size() + ", info.size() = " + info.size()); 076 } 077 078 idLabelMap = new HashMap<>(); 079 labelIDMap = new HashMap<>(); 080 for (Map.Entry<Label,Integer> e : mapping.entrySet()) { 081 idLabelMap.put(e.getValue(),e.getKey().label); 082 labelIDMap.put(e.getKey().label,e.getValue()); 083 } 084 domain = Collections.unmodifiableSet(new HashSet<>(labels.values())); 085 } 086 087 /** 088 * Returns the set of possible {@link Label}s that this LabelInfo has seen. 089 * 090 * Each label has the default score of Double.NaN. 091 * @return The set of possible labels. 092 */ 093 @Override 094 public Set<Label> getDomain() { 095 return domain; 096 } 097 098 @Override 099 public int getID(Label output) { 100 return labelIDMap.getOrDefault(output.getLabel(),-1); 101 } 102 103 @Override 104 public Label getOutput(int id) { 105 String label = idLabelMap.get(id); 106 if (label != null) { 107 return labels.get(label); 108 } else { 109 logger.log(Level.INFO,"No entry found for id " + id); 110 return null; 111 } 112 } 113 114 @Override 115 public long getTotalObservations() { 116 long count = 0; 117 for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) { 118 count += e.getValue().longValue(); 119 } 120 return count; 121 } 122 123 /** 124 * Returns the number of times the supplied id was observed before this LabelInfo was frozen. 125 * @param id The id number. 126 * @return The count. 127 */ 128 public long getLabelCount(int id) { 129 String label = idLabelMap.get(id); 130 if (label != null) { 131 MutableLong l = labelCounts.get(label); 132 return l.longValue(); 133 } else { 134 return 0; 135 } 136 } 137 138 @Override 139 public ImmutableLabelInfo copy() { 140 return new ImmutableLabelInfo(this); 141 } 142 143 @Override 144 public String toReadableString() { 145 StringBuilder builder = new StringBuilder(); 146 for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) { 147 if (builder.length() > 0) { 148 builder.append(", "); 149 } 150 builder.append('('); 151 builder.append(labelIDMap.get(e.getKey())); 152 builder.append(','); 153 builder.append(e.getKey()); 154 builder.append(','); 155 builder.append(e.getValue().longValue()); 156 builder.append(')'); 157 } 158 return builder.toString(); 159 } 160 161 @Override 162 public Iterator<Pair<Integer, Label>> iterator() { 163 return new ImmutableInfoIterator(idLabelMap); 164 } 165 166 /** 167 * An iterator that converts {@link Map.Entry} into {@link Pair}s on the way out. 168 */ 169 private static class ImmutableInfoIterator implements Iterator<Pair<Integer,Label>> { 170 171 private final Iterator<Map.Entry<Integer,String>> itr; 172 173 public ImmutableInfoIterator(Map<Integer,String> idLabelMap) { 174 itr = idLabelMap.entrySet().iterator(); 175 } 176 177 @Override 178 public boolean hasNext() { 179 return itr.hasNext(); 180 } 181 182 @Override 183 public Pair<Integer, Label> next() { 184 Map.Entry<Integer,String> e = itr.next(); 185 return new Pair<>(e.getKey(),new Label(e.getValue())); 186 } 187 } 188 189 private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { 190 in.defaultReadObject(); 191 192 domain = Collections.unmodifiableSet(new HashSet<>(labels.values())); 193 } 194}