001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import com.oracle.labs.mlrg.olcut.util.MutableNumber;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.MutableOutputInfo;
024import org.tribuo.OutputInfo;
025
026import java.util.HashMap;
027import java.util.HashSet;
028import java.util.Iterator;
029import java.util.Map;
030import java.util.Set;
031
032/**
033 * The base class for a ClusterID OutputInfo.
034 */
035public abstract class ClusteringInfo implements OutputInfo<ClusterID> {
036    private static final long serialVersionUID = 1L;
037
038    protected final Map<Integer,MutableLong> clusterCounts;
039    protected int unknownCount = 0;
040
041    ClusteringInfo() {
042        clusterCounts = new HashMap<>();
043    }
044
045    ClusteringInfo(ClusteringInfo other) {
046        clusterCounts = MutableNumber.copyMap(other.clusterCounts);
047    }
048
049    @Override
050    public int getUnknownCount() {
051        return unknownCount;
052    }
053
054    @Override
055    public Set<ClusterID> getDomain() {
056        Set<ClusterID> outputs = new HashSet<>();
057        for (Map.Entry<Integer,MutableLong> e : clusterCounts.entrySet()) {
058            outputs.add(new ClusterID(e.getKey()));
059        }
060        return outputs;
061    }
062
063    @Override
064    public int size() {
065        return clusterCounts.size();
066    }
067
068    @Override
069    public ImmutableOutputInfo<ClusterID> generateImmutableOutputInfo() {
070        return new ImmutableClusteringInfo(this);
071    }
072
073    @Override
074    public MutableOutputInfo<ClusterID> generateMutableOutputInfo() {
075        return new MutableClusteringInfo(this);
076    }
077
078    @Override
079    public abstract ClusteringInfo copy();
080
081    @Override
082    public Iterable<Pair<String, Long>> outputCountsIterable() {
083        return () -> new Iterator<Pair<String,Long>>() {
084            Iterator<Map.Entry<Integer,MutableLong>> itr = clusterCounts.entrySet().iterator();
085
086            @Override
087            public boolean hasNext() {
088                return itr.hasNext();
089            }
090
091            @Override
092            public Pair<String,Long> next() {
093                Map.Entry<Integer,MutableLong> e = itr.next();
094                return new Pair<>(""+e.getKey(),e.getValue().longValue());
095            }
096        };
097    }
098
099    @Override
100    public String toReadableString() {
101        StringBuilder builder = new StringBuilder();
102        for (Map.Entry<Integer,MutableLong> e : clusterCounts.entrySet()) {
103            if (builder.length() > 0) {
104                builder.append(", ");
105            }
106            builder.append('(');
107            builder.append(e.getKey());
108            builder.append(',');
109            builder.append(e.getValue().longValue());
110            builder.append(')');
111        }
112        return builder.toString();
113    }
114}