001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020import com.oracle.labs.mlrg.olcut.util.MutableNumber; 021import com.oracle.labs.mlrg.olcut.util.Pair; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.MutableOutputInfo; 024import org.tribuo.OutputInfo; 025 026import java.io.IOException; 027import java.util.HashMap; 028import java.util.HashSet; 029import java.util.Iterator; 030import java.util.Map; 031import java.util.Set; 032 033/** 034 * The base class for information about multi-class classification Labels. 035 */ 036public abstract class LabelInfo implements OutputInfo<Label> { 037 private static final long serialVersionUID = 1L; 038 039 /** 040 * The occurrence counts of each label. 041 */ 042 protected final Map<String,MutableLong> labelCounts; 043 /** 044 * The number of unknown labels this LabelInfo has seen. 045 */ 046 protected int unknownCount = 0; 047 /** 048 * The label domain. 049 */ 050 protected transient Map<String,Label> labels; 051 052 /** 053 * Constructs an empty label info. 054 */ 055 LabelInfo() { 056 labelCounts = new HashMap<>(); 057 labels = new HashMap<>(); 058 } 059 060 /** 061 * Copies the label info apart from the unknown count. 062 * @param other The label info to copy. 063 */ 064 LabelInfo(LabelInfo other) { 065 labelCounts = MutableNumber.copyMap(other.labelCounts); 066 labels = new HashMap<>(); 067 labels.putAll(other.labels); 068 } 069 070 @Override 071 public int getUnknownCount() { 072 return unknownCount; 073 } 074 075 /** 076 * Returns the set of possible {@link Label}s that this LabelInfo has seen. 077 * <p> 078 * Each label has the default score of Double.NaN. 079 * @return The set of possible labels. 080 */ 081 @Override 082 public Set<Label> getDomain() { 083 return new HashSet<>(labels.values()); 084 } 085 086 /** 087 * Gets the count of the supplied label, or 0 if the label is unknown. 088 * @param label A Label. 089 * @return A non-negative long. 090 */ 091 public long getLabelCount(Label label) { 092 MutableLong l = labelCounts.get(label.getLabel()); 093 if (l != null) { 094 return l.longValue(); 095 } else { 096 return 0; 097 } 098 } 099 100 /** 101 * Gets the count of the supplied label, or 0 if the label is unknown. 102 * @param label A String representing a Label. 103 * @return A non-negative long. 104 */ 105 public long getLabelCount(String label) { 106 MutableLong l = labelCounts.get(label); 107 if (l != null) { 108 return l.longValue(); 109 } else { 110 return 0; 111 } 112 } 113 114 @Override 115 public Iterable<Pair<String,Long>> outputCountsIterable() { 116 return () -> new Iterator<Pair<String,Long>>() { 117 Iterator<Map.Entry<String,MutableLong>> itr = labelCounts.entrySet().iterator(); 118 119 @Override 120 public boolean hasNext() { 121 return itr.hasNext(); 122 } 123 124 @Override 125 public Pair<String,Long> next() { 126 Map.Entry<String,MutableLong> e = itr.next(); 127 return new Pair<>(e.getKey(),e.getValue().longValue()); 128 } 129 }; 130 } 131 132 /** 133 * The number of unique {@link Label}s this LabelInfo has seen. 134 * @return The number of unique labels. 135 */ 136 @Override 137 public int size() { 138 return labelCounts.size(); 139 } 140 141 @Override 142 public ImmutableOutputInfo<Label> generateImmutableOutputInfo() { 143 return new ImmutableLabelInfo(this); 144 } 145 146 @Override 147 public MutableOutputInfo<Label> generateMutableOutputInfo() { 148 return new MutableLabelInfo(this); 149 } 150 151 @Override 152 public abstract LabelInfo copy(); 153 154 private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { 155 in.defaultReadObject(); 156 labels = new HashMap<>(); 157 for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) { 158 labels.put(e.getKey(),new Label(e.getKey())); 159 } 160 } 161}