001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.clustering; 018 019import com.oracle.labs.mlrg.olcut.util.MutableLong; 020import com.oracle.labs.mlrg.olcut.util.MutableNumber; 021import com.oracle.labs.mlrg.olcut.util.Pair; 022import org.tribuo.ImmutableOutputInfo; 023 024import java.util.Collections; 025import java.util.HashSet; 026import java.util.Iterator; 027import java.util.Map; 028import java.util.Set; 029 030/** 031 * An {@link ImmutableOutputInfo} object for ClusterIDs. 032 * <p> 033 * Gives each unique cluster an id number. Also counts each id occurrence like {@link MutableClusteringInfo} does, 034 * though the counts are frozen in this object. 035 */ 036public class ImmutableClusteringInfo extends ClusteringInfo implements ImmutableOutputInfo<ClusterID> { 037 private static final long serialVersionUID = 1L; 038 039 private final Set<ClusterID> domain; 040 041 public ImmutableClusteringInfo(Map<Integer,MutableLong> counts) { 042 super(); 043 clusterCounts.putAll(MutableNumber.copyMap(counts)); 044 045 Set<ClusterID> outputs = new HashSet<>(); 046 for (Map.Entry<Integer,MutableLong> e : clusterCounts.entrySet()) { 047 outputs.add(new ClusterID(e.getKey())); 048 } 049 domain = Collections.unmodifiableSet(outputs); 050 } 051 052 public ImmutableClusteringInfo(ClusteringInfo other) { 053 super(other); 054 Set<ClusterID> outputs = new HashSet<>(); 055 for (Map.Entry<Integer,MutableLong> e : clusterCounts.entrySet()) { 056 outputs.add(new ClusterID(e.getKey())); 057 } 058 domain = Collections.unmodifiableSet(outputs); 059 } 060 061 @Override 062 public Set<ClusterID> getDomain() { 063 return domain; 064 } 065 066 @Override 067 public int getID(ClusterID output) { 068 return output.getID(); 069 } 070 071 @Override 072 public ClusterID getOutput(int id) { 073 return new ClusterID(id); 074 } 075 076 @Override 077 public long getTotalObservations() { 078 long count = 0; 079 for (Map.Entry<Integer,MutableLong> e : clusterCounts.entrySet()) { 080 count += e.getValue().longValue(); 081 } 082 return count; 083 } 084 085 @Override 086 public ClusteringInfo copy() { 087 return new ImmutableClusteringInfo(this); 088 } 089 090 @Override 091 public Iterator<Pair<Integer, ClusterID>> iterator() { 092 return new ImmutableInfoIterator(clusterCounts.keySet()); 093 } 094 095 private static class ImmutableInfoIterator implements Iterator<Pair<Integer,ClusterID>> { 096 097 private final Iterator<Integer> itr; 098 099 public ImmutableInfoIterator(Set<Integer> idLabelMap) { 100 itr = idLabelMap.iterator(); 101 } 102 103 @Override 104 public boolean hasNext() { 105 return itr.hasNext(); 106 } 107 108 @Override 109 public Pair<Integer, ClusterID> next() { 110 int id = itr.next(); 111 return new Pair<>(id, new ClusterID(id)); 112 } 113 } 114}