/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.multilabel;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.MutableOutputInfo;

import java.util.Map;

/**
 * A MutableOutputInfo for working with multi-label tasks.
 * <p>
 * Each call to {@link #observe(MultiLabel)} updates the per-label counts,
 * the set of known labels, and either the total or the unknown counter.
 */
public class MutableMultiLabelInfo extends MultiLabelInfo implements MutableOutputInfo<MultiLabel> {
    private static final long serialVersionUID = 1L;

    /**
     * Package private constructor for building MutableMultiLabelInfo, used by {@link MultiLabelFactory}.
     */
    MutableMultiLabelInfo() {
        super();
    }

    /**
     * Construct a MutableMultiLabelInfo with its state copied from another
     * MultiLabelInfo.
     * @param info The info to copy.
     */
    public MutableMultiLabelInfo(MultiLabelInfo info) {
        super(info);
    }

    /**
     * Records a single observed output.
     * <p>
     * If the output is the {@link MultiLabelFactory#UNKNOWN_MULTILABEL} sentinel
     * only the unknown counter is incremented. Otherwise the count of every
     * label name in the output is incremented by one, previously unseen labels
     * are registered, and the total count is incremented once.
     * <p>
     * Throws IllegalStateException if the MultiLabel contains a Label which has a "," in it.
     * Such labels are disallowed. There should be an exception thrown when one is constructed
     * too. All labels are validated before any count is changed, so a rejected
     * output leaves this info unmodified.
     * @param output The observed output.
     * @throws IllegalStateException If any label name in the output contains ','.
     */
    @Override
    public void observe(MultiLabel output) {
        // Identity comparison against the shared sentinel instance.
        if (output == MultiLabelFactory.UNKNOWN_MULTILABEL) {
            unknownCount++;
            return;
        }

        // Validate first so a bad label cannot leave the counts half updated.
        for (String label : output.getNameSet()) {
            if (label.contains(",")) {
                throw new IllegalStateException("MultiLabel cannot use a Label which contains ','. The supplied label was " + label + ".");
            }
        }

        for (String label : output.getNameSet()) {
            MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
            labels.computeIfAbsent(label, MultiLabel::new);
            value.increment();
        }
        totalCount++;
    }

    /**
     * Resets this info to the empty state.
     * <p>
     * Everything that {@link #observe(MultiLabel)} accumulates is reset: the
     * per-label counts, the known labels, the total count and the unknown count.
     */
    @Override
    public void clear() {
        labelCounts.clear();
        // observe() registers every label here as well, so it must be emptied
        // too or labels from before the clear would still be visible.
        labels.clear();
        totalCount = 0;
        unknownCount = 0;
    }

    /**
     * Creates a new MutableMultiLabelInfo using the copy constructor.
     * @return A copy of this info.
     */
    @Override
    public MutableMultiLabelInfo copy() {
        return new MutableMultiLabelInfo(this);
    }

    /**
     * Formats the label counts as {@code (label,count)} pairs separated by
     * {@code ", "}, in the iteration order of the underlying count map.
     * @return The formatted counts, or an empty string if there are none.
     */
    @Override
    public String toReadableString() {
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
            if (builder.length() > 0) {
                builder.append(", ");
            }
            builder.append('(');
            builder.append(e.getKey());
            builder.append(',');
            builder.append(e.getValue().longValue());
            builder.append(')');
        }
        return builder.toString();
    }
}