001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering;
018
019import com.oracle.labs.mlrg.olcut.util.MutableLong;
020import com.oracle.labs.mlrg.olcut.util.MutableNumber;
021import com.oracle.labs.mlrg.olcut.util.Pair;
022import org.tribuo.ImmutableOutputInfo;
023
024import java.util.Collections;
025import java.util.HashSet;
026import java.util.Iterator;
027import java.util.Map;
028import java.util.Set;
029
030/**
031 * An {@link ImmutableOutputInfo} object for ClusterIDs.
032 * <p>
033 * Gives each unique cluster an id number. Also counts each id occurrence like {@link MutableClusteringInfo} does,
034 * though the counts are frozen in this object.
035 */
036public class ImmutableClusteringInfo extends ClusteringInfo implements ImmutableOutputInfo<ClusterID> {
037    private static final long serialVersionUID = 1L;
038
039    private final Set<ClusterID> domain;
040
041    public ImmutableClusteringInfo(Map<Integer,MutableLong> counts) {
042        super();
043        clusterCounts.putAll(MutableNumber.copyMap(counts));
044
045        Set<ClusterID> outputs = new HashSet<>();
046        for (Map.Entry<Integer,MutableLong> e : clusterCounts.entrySet()) {
047            outputs.add(new ClusterID(e.getKey()));
048        }
049        domain = Collections.unmodifiableSet(outputs);
050    }
051
052    public ImmutableClusteringInfo(ClusteringInfo other) {
053        super(other);
054        Set<ClusterID> outputs = new HashSet<>();
055        for (Map.Entry<Integer,MutableLong> e : clusterCounts.entrySet()) {
056            outputs.add(new ClusterID(e.getKey()));
057        }
058        domain = Collections.unmodifiableSet(outputs);
059    }
060
061    @Override
062    public Set<ClusterID> getDomain() {
063        return domain;
064    }
065
066    @Override
067    public int getID(ClusterID output) {
068        return output.getID();
069    }
070
071    @Override
072    public ClusterID getOutput(int id) {
073        return new ClusterID(id);
074    }
075
076    @Override
077    public long getTotalObservations() {
078        long count = 0;
079        for (Map.Entry<Integer,MutableLong> e : clusterCounts.entrySet()) {
080            count += e.getValue().longValue();
081        }
082        return count;
083    }
084
085    @Override
086    public ClusteringInfo copy() {
087        return new ImmutableClusteringInfo(this);
088    }
089
090    @Override
091    public Iterator<Pair<Integer, ClusterID>> iterator() {
092        return new ImmutableInfoIterator(clusterCounts.keySet());
093    }
094
095    private static class ImmutableInfoIterator implements Iterator<Pair<Integer,ClusterID>> {
096
097        private final Iterator<Integer> itr;
098
099        public ImmutableInfoIterator(Set<Integer> idLabelMap) {
100            itr = idLabelMap.iterator();
101        }
102
103        @Override
104        public boolean hasNext() {
105            return itr.hasNext();
106        }
107
108        @Override
109        public Pair<Integer, ClusterID> next() {
110            int id = itr.next();
111            return new Pair<>(id, new ClusterID(id));
112        }
113    }
114}