001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.multilabel;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import org.tribuo.MutableOutputInfo;
021
022import java.util.Map;
023
024/**
025 * A MutableOutputInfo for working with multi-label tasks.
026 */
027public class MutableMultiLabelInfo extends MultiLabelInfo implements MutableOutputInfo<MultiLabel> {
028    private static final long serialVersionUID = 1L;
029
030    /**
031     * Package private constructor for building MutableMultiLabelInfo, used by {@link MultiLabelFactory}.
032     */
033    MutableMultiLabelInfo() {
034        super();
035    }
036
037    /**
038     * Construct a MutableMultiLabelInfo with it's state copied from another
039     * MultiLabelInfo.
040     * @param info The info to copy.
041     */
042    public MutableMultiLabelInfo(MultiLabelInfo info) {
043        super(info);
044    }
045
046    /**
047     * Throws IllegalStateException if the MultiLabel contains a Label which has a "," in it.
048     * <p>
049     * Such labels are disallowed. There should be an exception thrown when one is constructed
050     * too.
051     * @param output The observed output.
052     */
053    @Override
054    public void observe(MultiLabel output) {
055        if (output == MultiLabelFactory.UNKNOWN_MULTILABEL) {
056            unknownCount++;
057        } else {
058            for (String label : output.getNameSet()) {
059                if (label.contains(",")) {
060                    throw new IllegalStateException("MultiLabel cannot use a Label which contains ','. The supplied label was " + label + ".");
061                }
062                MutableLong value = labelCounts.computeIfAbsent(label, k -> new MutableLong());
063                labels.computeIfAbsent(label, MultiLabel::new);
064                value.increment();
065            }
066            totalCount++;
067        }
068    }
069
070    @Override
071    public void clear() {
072        labelCounts.clear();
073        totalCount = 0;
074    }
075
076    @Override
077    public MutableMultiLabelInfo copy() {
078        return new MutableMultiLabelInfo(this);
079    }
080
081    @Override
082    public String toReadableString() {
083        StringBuilder builder = new StringBuilder();
084        for (Map.Entry<String,MutableLong> e : labelCounts.entrySet()) {
085            if (builder.length() > 0) {
086                builder.append(", ");
087            }
088            builder.append('(');
089            builder.append(e.getKey());
090            builder.append(',');
091            builder.append(e.getValue().longValue());
092            builder.append(')');
093        }
094        return builder.toString();
095    }
096}