001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.multilabel; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020import com.oracle.labs.mlrg.olcut.util.MutableNumber; 021import com.oracle.labs.mlrg.olcut.util.Pair; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.MutableOutputInfo; 024import org.tribuo.OutputInfo; 025import org.tribuo.classification.Label; 026 027import java.io.IOException; 028import java.util.HashMap; 029import java.util.HashSet; 030import java.util.Iterator; 031import java.util.Map; 032import java.util.Set; 033 034/** 035 * The base class for information about {@link MultiLabel} outputs. 036 */ 037public abstract class MultiLabelInfo implements OutputInfo<MultiLabel> { 038 private static final long serialVersionUID = 1L; 039 040 protected final Map<String,MutableLong> labelCounts; 041 protected int unknownCount = 0; 042 protected transient Map<String,MultiLabel> labels; 043 044 protected int totalCount = 0; 045 046 /** 047 * Construct a MultiLabelInfo, initializing the various count variables. 048 */ 049 MultiLabelInfo() { 050 labelCounts = new HashMap<>(); 051 labels = new HashMap<>(); 052 } 053 054 /** 055 * Copy the MultiLabelInfo. The copy ignores the unknown count. 056 * @param other The MultiLabelInfo to copy. 057 */ 058 MultiLabelInfo(MultiLabelInfo other) { 059 labelCounts = MutableNumber.copyMap(other.labelCounts); 060 labels = new HashMap<>(other.labels); 061 totalCount = other.totalCount; 062 } 063 064 @Override 065 public int getUnknownCount() { 066 return unknownCount; 067 } 068 069 /** 070 * Returns a set of MultiLabel, where each has a single Label inside it. 071 * The set covers the space of Labels that this MultiLabelInfo has seen. 072 * @return The set of possible labels. 073 */ 074 @Override 075 public Set<MultiLabel> getDomain() { 076 return new HashSet<>(labels.values()); 077 } 078 079 /** 080 * Get the number of times this Label was observed, or 0 if unknown. 081 * @param label The Label to look for. 082 * @return A non-negative long. 083 */ 084 public long getLabelCount(Label label) { 085 MutableLong l = labelCounts.get(label.getLabel()); 086 if (l != null) { 087 return l.longValue(); 088 } else { 089 return 0; 090 } 091 } 092 093 /** 094 * Get the number of times this String was observed, or 0 if unknown. 095 * @param label The String to look for. 096 * @return A non-negative long. 097 */ 098 public long getLabelCount(String label) { 099 MutableLong l = labelCounts.get(label); 100 if (l != null) { 101 return l.longValue(); 102 } else { 103 return 0; 104 } 105 } 106 107 @Override 108 public Iterable<Pair<String,Long>> outputCountsIterable() { 109 return () -> new Iterator<Pair<String, Long>>() { 110 Iterator<Map.Entry<String, MutableLong>> itr = labelCounts.entrySet().iterator(); 111 112 @Override 113 public boolean hasNext() { 114 return itr.hasNext(); 115 } 116 117 @Override 118 public Pair<String, Long> next() { 119 Map.Entry<String, MutableLong> e = itr.next(); 120 return new Pair<>(e.getKey(), e.getValue().longValue()); 121 } 122 }; 123 } 124 125 @Override 126 public int size() { 127 return labelCounts.size(); 128 } 129 130 @Override 131 public ImmutableOutputInfo<MultiLabel> generateImmutableOutputInfo() { 132 return new ImmutableMultiLabelInfo(this); 133 } 134 135 @Override 136 public MutableOutputInfo<MultiLabel> generateMutableOutputInfo() { 137 return new MutableMultiLabelInfo(this); 138 } 139 140 @Override 141 public abstract MultiLabelInfo copy(); 142 143 private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException { 144 in.defaultReadObject(); 145 labels = new HashMap<>(); 146 for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) { 147 labels.put(e.getKey(),new MultiLabel(e.getKey())); 148 } 149 } 150}