001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.multilabel; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020import com.oracle.labs.mlrg.olcut.util.Pair; 021import org.tribuo.ImmutableOutputInfo; 022 023import java.io.IOException; 024import java.util.Collections; 025import java.util.HashMap; 026import java.util.HashSet; 027import java.util.Iterator; 028import java.util.Map; 029import java.util.Set; 030import java.util.logging.Level; 031import java.util.logging.Logger; 032 033/** 034 * An ImmutableOutputInfo for working with multi-label tasks. 035 */ 036public class ImmutableMultiLabelInfo extends MultiLabelInfo implements ImmutableOutputInfo<MultiLabel> { 037 private static final Logger logger = Logger.getLogger(ImmutableMultiLabelInfo.class.getName()); 038 039 private static final long serialVersionUID = 1L; 040 041 private final Map<Integer,String> idLabelMap; 042 043 private final Map<String,Integer> labelIDMap; 044 045 private transient Set<MultiLabel> domain; 046 047 private ImmutableMultiLabelInfo(ImmutableMultiLabelInfo info) { 048 super(info); 049 idLabelMap = new HashMap<>(); 050 idLabelMap.putAll(info.idLabelMap); 051 labelIDMap = new HashMap<>(); 052 labelIDMap.putAll(info.labelIDMap); 053 054 domain = Collections.unmodifiableSet(new HashSet<>(labels.values())); 055 } 056 057 ImmutableMultiLabelInfo(MultiLabelInfo info) { 058 super(info); 059 idLabelMap = new HashMap<>(); 060 labelIDMap = new HashMap<>(); 061 int counter = 0; 062 for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) { 063 idLabelMap.put(counter,e.getKey()); 064 labelIDMap.put(e.getKey(),counter); 065 counter++; 066 } 067 068 domain = Collections.unmodifiableSet(new HashSet<>(labels.values())); 069 } 070 071 ImmutableMultiLabelInfo(MutableMultiLabelInfo info, Map<MultiLabel, Integer> mapping) { 072 super(info); 073 if (mapping.size() != info.size()) { 074 throw new IllegalStateException("Mapping and info come from different sources, mapping.size() = " + mapping.size() + ", info.size() = " + info.size()); 075 } 076 077 idLabelMap = new HashMap<>(); 078 labelIDMap = new HashMap<>(); 079 for (Map.Entry<MultiLabel,Integer> e : mapping.entrySet()) { 080 MultiLabel ml = e.getKey(); 081 Set<String> names = ml.getNameSet(); 082 if (names.size() == 1) { 083 String name = names.iterator().next(); 084 idLabelMap.put(e.getValue(), name); 085 labelIDMap.put(name, e.getValue()); 086 } else { 087 throw new IllegalArgumentException("Mapping must contain a single label per id, but contains " + names + " -> " + e.getValue()); 088 } 089 } 090 091 domain = Collections.unmodifiableSet(new HashSet<>(labels.values())); 092 } 093 094 @Override 095 public Set<MultiLabel> getDomain() { 096 return domain; 097 } 098 099 @Override 100 public int getID(MultiLabel output) { 101 return labelIDMap.getOrDefault(output.getLabelString(), -1); 102 } 103 104 @Override 105 public MultiLabel getOutput(int id) { 106 String label = idLabelMap.get(id); 107 if (label != null) { 108 return labels.get(label); 109 } else { 110 logger.log(Level.INFO, "No entry found for id " + id); 111 return null; 112 } 113 } 114 115 @Override 116 public long getTotalObservations() { 117 return totalCount; 118 } 119 120 /** 121 * Gets the count of the label occurrence for the specified id number, or 0 if it's unknown. 122 * @param id The label id. 123 * @return The label count. 124 */ 125 public long getLabelCount(int id) { 126 String label = idLabelMap.get(id); 127 if (label != null) { 128 MutableLong l = labelCounts.get(label); 129 return l.longValue(); 130 } else { 131 return 0; 132 } 133 } 134 135 @Override 136 public ImmutableMultiLabelInfo copy() { 137 return new ImmutableMultiLabelInfo(this); 138 } 139 140 @Override 141 public String toReadableString() { 142 StringBuilder builder = new StringBuilder(); 143 for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) { 144 if (builder.length() > 0) { 145 builder.append(", "); 146 } 147 builder.append('('); 148 builder.append(labelIDMap.get(e.getKey())); 149 builder.append(','); 150 builder.append(e.getKey()); 151 builder.append(','); 152 builder.append(e.getValue().longValue()); 153 builder.append(')'); 154 } 155 return builder.toString(); 156 } 157 158 @Override 159 public Iterator<Pair<Integer, MultiLabel>> iterator() { 160 return new ImmutableInfoIterator(idLabelMap); 161 } 162 163 private static class ImmutableInfoIterator implements Iterator<Pair<Integer,MultiLabel>> { 164 165 private final Iterator<Map.Entry<Integer,String>> itr; 166 167 public ImmutableInfoIterator(Map<Integer,String> idLabelMap) { 168 itr = idLabelMap.entrySet().iterator(); 169 } 170 171 @Override 172 public boolean hasNext() { 173 return itr.hasNext(); 174 } 175 176 @Override 177 public Pair<Integer, MultiLabel> next() { 178 Map.Entry<Integer,String> e = itr.next(); 179 return new Pair<>(e.getKey(),new MultiLabel(e.getValue())); 180 } 181 } 182 183 private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { 184 in.defaultReadObject(); 185 domain = Collections.unmodifiableSet(new HashSet<>(labels.values())); 186 } 187}