/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.clustering;

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.MutableOutputInfo;
import org.tribuo.OutputFactory;
import org.tribuo.clustering.evaluation.ClusteringEvaluation;
import org.tribuo.clustering.evaluation.ClusteringEvaluator;
import org.tribuo.evaluation.Evaluator;
import org.tribuo.provenance.OutputFactoryProvenance;

import java.util.HashMap;
import java.util.Map;

/**
 * The {@link OutputFactory} implementation for {@link ClusterID} outputs.
 * <p>
 * A ClusterID is parsed by converting the input to a String with
 * {@code toString} and then reading that String as an int.
 */
public final class ClusteringFactory implements OutputFactory<ClusterID> {
    private static final long serialVersionUID = 1L;

    /**
     * The ClusterID built from {@link ClusterID#UNASSIGNED}; returned by {@link #getUnknownOutput()}.
     */
    public static final ClusterID UNASSIGNED_CLUSTER_ID = new ClusterID(ClusterID.UNASSIGNED);

    // A single provenance and a single evaluator are shared by every factory instance.
    private static final ClusteringFactoryProvenance provenance = new ClusteringFactoryProvenance();

    private static final ClusteringEvaluator evaluator = new ClusteringEvaluator();

    /**
     * Constructs a clustering factory.
     * <p>
     * The factory is stateless and immutable, but a public no-argument
     * constructor is needed so the config system can construct it.
     */
    public ClusteringFactory() {}

    /**
     * Builds a ClusterID from the supplied value.
     * <p>
     * The value is converted with {@code toString} and the resulting text is
     * passed to {@link Integer#parseInt(String)}.
     * @param label An input value.
     * @param <V> The type of the input.
     * @return A ClusterID representing the data.
     */
    @Override
    public <V> ClusterID generateOutput(V label) {
        String text = label.toString();
        int id = Integer.parseInt(text);
        return new ClusterID(id);
    }

    /**
     * Returns the shared {@link #UNASSIGNED_CLUSTER_ID} constant.
     * @return The unassigned ClusterID.
     */
    @Override
    public ClusterID getUnknownOutput() {
        return UNASSIGNED_CLUSTER_ID;
    }

    /**
     * Creates a fresh {@link MutableClusteringInfo}.
     * @return A new mutable output info for clustering.
     */
    @Override
    public MutableOutputInfo<ClusterID> generateInfo() {
        return new MutableClusteringInfo();
    }

    /**
     * Builds the output info for an external model.
     * <p>
     * Unlike the other info types, clustering directly uses the integer IDs as the stored value,
     * so the ClusterID keys of the mapping are discarded and only the supplied integers are kept.
     * Each integer is recorded with a count of one.
     * @param mapping The mapping to use.
     * @return An {@link ImmutableOutputInfo} for the clustering.
     */
    @Override
    public ImmutableOutputInfo<ClusterID> constructInfoForExternalModel(Map<ClusterID,Integer> mapping) {
        // Validate that the supplied ids are dense before building the info.
        OutputFactory.validateMapping(mapping);

        Map<Integer, MutableLong> idCounts = new HashMap<>();

        // Only the integer ids matter here; the ClusterID keys are not used.
        for (Integer id : mapping.values()) {
            idCounts.put(id, new MutableLong(1));
        }

        return new ImmutableClusteringInfo(idCounts);
    }

    /**
     * Returns the shared {@link ClusteringEvaluator} instance.
     * @return The clustering evaluator.
     */
    @Override
    public Evaluator<ClusterID, ClusteringEvaluation> getEvaluator() {
        return evaluator;
    }

    /**
     * Hashes the fixed class name, so every ClusteringFactory has the same hash code.
     * @return The hash code.
     */
    @Override
    public int hashCode() {
        return "ClusteringFactory".hashCode();
    }

    /**
     * Any two ClusteringFactory instances are equal.
     * @param obj The object to compare against.
     * @return True if {@code obj} is a ClusteringFactory.
     */
    @Override
    public boolean equals(Object obj) {
        return obj instanceof ClusteringFactory;
    }

    /**
     * Returns the shared provenance instance.
     * @return The provenance of this factory.
     */
    @Override
    public OutputFactoryProvenance getProvenance() {
        return provenance;
    }

    /**
     * Provenance for {@link ClusteringFactory}.
     */
    public static final class ClusteringFactoryProvenance implements OutputFactoryProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Creates a clustering factory provenance.
         */
        ClusteringFactoryProvenance() {}

        /**
         * Rebuilds a clustering factory provenance from the marshalled form.
         * <p>
         * The map is not read.
         * @param map The map (which should be empty).
         */
        public ClusteringFactoryProvenance(Map<String, Provenance> map) { }

        /**
         * Returns the class name of {@link ClusteringFactory}.
         * @return The factory class name.
         */
        @Override
        public String getClassName() {
            return ClusteringFactory.class.getName();
        }

        @Override
        public String toString() {
            return generateString("OutputFactory");
        }

        /**
         * Any two ClusteringFactoryProvenance instances are equal.
         * @param other The object to compare against.
         * @return True if {@code other} is a ClusteringFactoryProvenance.
         */
        @Override
        public boolean equals(Object other) {
            return other instanceof ClusteringFactoryProvenance;
        }

        /**
         * Returns a constant, which is consistent with {@link #equals(Object)}.
         * @return The hash code.
         */
        @Override
        public int hashCode() {
            return 31;
        }
    }
}