/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.MutableOutputInfo;

import java.util.Map;

/**
 * A mutable {@link LabelInfo}. Can record new observations of Labels, incrementing the
 * appropriate counts.
 */
public class MutableLabelInfo extends LabelInfo implements MutableOutputInfo<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * Package-private constructor which delegates to the no-argument
     * {@link LabelInfo} constructor.
     */
    MutableLabelInfo() {
        super();
    }

    /**
     * Constructs a mutable deep copy of the supplied label info.
     * @param info The info to copy.
     */
    public MutableLabelInfo(LabelInfo info) {
        super(info);
    }

    /**
     * Records a single observation of the supplied label.
     * <p>
     * If the output is the {@link LabelFactory#UNKNOWN_LABEL} instance the unknown
     * count is incremented. Otherwise the count for the label's name is incremented,
     * creating the count entry and the stored {@link Label} on first sight of that name.
     * @param output The label to observe.
     */
    @Override
    public void observe(Label output) {
        // NOTE(review): this is a reference comparison, so only the exact
        // UNKNOWN_LABEL instance is treated as unknown - confirm that is intended.
        if (output == LabelFactory.UNKNOWN_LABEL) {
            unknownCount++;
        } else {
            String label = output.getLabel();
            MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
            labels.computeIfAbsent(label, Label::new);
            value.increment();
        }
    }

    /**
     * Removes everything recorded by {@link #observe(Label)}: the per-label counts,
     * the stored labels, and the unknown count.
     */
    @Override
    public void clear() {
        labelCounts.clear();
        // observe() also writes to these two fields, so they must be reset as well,
        // otherwise a cleared info still reports stale labels and unknowns.
        labels.clear();
        unknownCount = 0;
    }

    /**
     * Creates a new {@link MutableLabelInfo} from this one using the copy constructor.
     * @return A copy of this info.
     */
    @Override
    public MutableLabelInfo copy() {
        return new MutableLabelInfo(this);
    }

    /**
     * Formats the label counts as a comma separated list of {@code (label,count)}
     * pairs, in the iteration order of the underlying count map.
     * @return The formatted counts, or the empty string if no labels have been counted.
     */
    @Override
    public String toReadableString() {
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
            // Separator goes before every pair except the first.
            if (builder.length() > 0) {
                builder.append(", ");
            }
            builder.append('(');
            builder.append(e.getKey());
            builder.append(',');
            builder.append(e.getValue().longValue());
            builder.append(')');
        }
        return builder.toString();
    }
}