001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import org.tribuo.MutableOutputInfo;
021
022import java.util.Map;
023
024/**
025 * A mutable {@link LabelInfo}. Can record new observations of Labels, incrementing the
026 * appropriate counts.
027 */
028public class MutableLabelInfo extends LabelInfo implements MutableOutputInfo<Label> {
029    private static final long serialVersionUID = 1L;
030
031    MutableLabelInfo() {
032        super();
033    }
034
035    /**
036     * Constructs a mutable deep copy of the supplied label info.
037     * @param info The info to copy.
038     */
039    public MutableLabelInfo(LabelInfo info) {
040        super(info);
041    }
042
043    @Override
044    public void observe(Label output) {
045        if (output == LabelFactory.UNKNOWN_LABEL) {
046            unknownCount++;
047        } else {
048            String label = output.getLabel();
049            MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
050            labels.computeIfAbsent(label, Label::new);
051            value.increment();
052        }
053    }
054
055    @Override
056    public void clear() {
057        labelCounts.clear();
058    }
059
060    @Override
061    public MutableLabelInfo copy() {
062        return new MutableLabelInfo(this);
063    }
064
065    @Override
066    public String toReadableString() {
067        StringBuilder builder = new StringBuilder();
068        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
069            if (builder.length() > 0) {
070                builder.append(", ");
071            }
072            builder.append('(');
073            builder.append(e.getKey());
074            builder.append(',');
075            builder.append(e.getValue().longValue());
076            builder.append(')');
077        }
078        return builder.toString();
079    }
080}