/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.clustering;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.MutableNumber;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.MutableOutputInfo;
import org.tribuo.OutputInfo;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/**
 * The base class for a ClusterID OutputInfo.
 * <p>
 * Stores a count per cluster, keyed by the integer cluster id, plus a count of
 * unknown outputs. Concrete subclasses supply {@link #copy()}.
 * {@link #generateImmutableOutputInfo()} and {@link #generateMutableOutputInfo()}
 * build {@code ImmutableClusteringInfo} / {@code MutableClusteringInfo} instances
 * from this object.
 */
public abstract class ClusteringInfo implements OutputInfo<ClusterID> {
    private static final long serialVersionUID = 1L;

    /**
     * Per-cluster counts, keyed by the integer cluster id.
     * Backed by a {@link HashMap}, so iteration order is unspecified.
     */
    protected final Map<Integer,MutableLong> clusterCounts;

    /**
     * The value reported by {@link #getUnknownCount()}.
     * NOTE(review): nothing in this class increments it; presumably mutable
     * subclasses update it when they observe an unknown output.
     */
    protected int unknownCount = 0;

    /**
     * Creates an empty info with no cluster counts and an unknown count of zero.
     */
    ClusteringInfo() {
        clusterCounts = new HashMap<>();
    }

    /**
     * Copy constructor. Deep-copies the per-cluster counts (via
     * {@link MutableNumber#copyMap}) and copies the unknown count.
     *
     * @param other The info to copy.
     */
    ClusteringInfo(ClusteringInfo other) {
        clusterCounts = MutableNumber.copyMap(other.clusterCounts);
        // Without this, every copy (and every generated immutable/mutable info)
        // silently reported zero unknown outputs.
        unknownCount = other.unknownCount;
    }

    /**
     * Returns the number of unknown outputs recorded in this info.
     *
     * @return The unknown count.
     */
    @Override
    public int getUnknownCount() {
        return unknownCount;
    }

    /**
     * Returns a freshly allocated set holding one {@link ClusterID} per cluster id
     * that has a count entry. Mutating the returned set does not affect this info.
     *
     * @return The cluster ids observed so far.
     */
    @Override
    public Set<ClusterID> getDomain() {
        Set<ClusterID> outputs = new HashSet<>();
        for (Integer id : clusterCounts.keySet()) {
            outputs.add(new ClusterID(id));
        }
        return outputs;
    }

    /**
     * Returns the number of distinct cluster ids with a count entry.
     *
     * @return The number of clusters.
     */
    @Override
    public int size() {
        return clusterCounts.size();
    }

    /**
     * Builds an {@code ImmutableClusteringInfo} from this info.
     *
     * @return A new immutable clustering info.
     */
    @Override
    public ImmutableOutputInfo<ClusterID> generateImmutableOutputInfo() {
        return new ImmutableClusteringInfo(this);
    }

    /**
     * Builds a {@code MutableClusteringInfo} from this info.
     *
     * @return A new mutable clustering info.
     */
    @Override
    public MutableOutputInfo<ClusterID> generateMutableOutputInfo() {
        return new MutableClusteringInfo(this);
    }

    /**
     * Returns a copy of this info; the concrete subclass decides its exact type.
     *
     * @return A copy of this info.
     */
    @Override
    public abstract ClusteringInfo copy();

    /**
     * Returns an iterable over (cluster id, count) pairs. The cluster id is rendered
     * as a decimal string. Each call to {@code iterator()} starts a fresh pass over
     * the live count map. The iterator does not support {@code remove()}.
     *
     * @return An iterable view of the cluster counts.
     */
    @Override
    public Iterable<Pair<String, Long>> outputCountsIterable() {
        return () -> new Iterator<Pair<String,Long>>() {
            private final Iterator<Map.Entry<Integer,MutableLong>> itr = clusterCounts.entrySet().iterator();

            @Override
            public boolean hasNext() {
                return itr.hasNext();
            }

            @Override
            public Pair<String,Long> next() {
                Map.Entry<Integer,MutableLong> e = itr.next();
                return new Pair<>(String.valueOf(e.getKey()), e.getValue().longValue());
            }
        };
    }

    /**
     * Renders the counts as {@code (id,count)} tuples separated by {@code ", "},
     * in the map's (unspecified) iteration order. Returns an empty string when there
     * are no clusters.
     *
     * @return A human-readable summary of the cluster counts.
     */
    @Override
    public String toReadableString() {
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<Integer,MutableLong> e : clusterCounts.entrySet()) {
            if (builder.length() > 0) {
                builder.append(", ");
            }
            builder.append('(');
            builder.append(e.getKey());
            builder.append(',');
            builder.append(e.getValue().longValue());
            builder.append(')');
        }
        return builder.toString();
    }
}