001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering;
018
019import com.oracle.labs.mlrg.olcut.provenance.Provenance;
020import com.oracle.labs.mlrg.olcut.util.MutableLong;
021import org.tribuo.ImmutableOutputInfo;
022import org.tribuo.MutableOutputInfo;
023import org.tribuo.OutputFactory;
024import org.tribuo.clustering.evaluation.ClusteringEvaluation;
025import org.tribuo.clustering.evaluation.ClusteringEvaluator;
026import org.tribuo.evaluation.Evaluator;
027import org.tribuo.provenance.OutputFactoryProvenance;
028
029import java.util.HashMap;
030import java.util.Map;
031
032/**
033 * A factory for making ClusterID related classes.
034 * <p>
035 * Parses the ClusterID by calling toString on the input then parsing it as an int.
036 */
037public final class ClusteringFactory implements OutputFactory<ClusterID> {
038    private static final long serialVersionUID = 1L;
039
040    public static final ClusterID UNASSIGNED_CLUSTER_ID = new ClusterID(ClusterID.UNASSIGNED);
041
042    private static final ClusteringFactoryProvenance provenance = new ClusteringFactoryProvenance();
043
044    private static final ClusteringEvaluator evaluator = new ClusteringEvaluator();
045
046    /**
047     * ClusteringFactory is stateless and immutable, but we need to be able to construct them via the config system.
048     */
049    public ClusteringFactory() {}
050
051    /**
052     * Generates a ClusterID by calling toString on the input, then calling Integer.parseInt.
053     * @param label An input value.
054     * @param <V> The type of the input.
055     * @return A ClusterID representing the data.
056     */
057    @Override
058    public <V> ClusterID generateOutput(V label) {
059        return new ClusterID(Integer.parseInt(label.toString()));
060    }
061
062    @Override
063    public ClusterID getUnknownOutput() {
064        return UNASSIGNED_CLUSTER_ID;
065    }
066
067    @Override
068    public MutableOutputInfo<ClusterID> generateInfo() {
069        return new MutableClusteringInfo();
070    }
071
072    /**
073     * Unlike the other info types, clustering directly uses the integer IDs as the stored value,
074     * so this mapping discards the cluster IDs and just uses the supplied integers.
075     * @param mapping The mapping to use.
076     * @return An {@link ImmutableOutputInfo} for the clustering.
077     */
078    @Override
079    public ImmutableOutputInfo<ClusterID> constructInfoForExternalModel(Map<ClusterID,Integer> mapping) {
080        // Validate inputs are dense
081        OutputFactory.validateMapping(mapping);
082
083        Map<Integer, MutableLong> countsMap = new HashMap<>();
084
085        for (Map.Entry<ClusterID,Integer> e : mapping.entrySet()) {
086            countsMap.put(e.getValue(),new MutableLong(1));
087        }
088
089        return new ImmutableClusteringInfo(countsMap);
090    }
091
092    @Override
093    public Evaluator<ClusterID, ClusteringEvaluation> getEvaluator() {
094        return evaluator;
095    }
096
097    @Override
098    public int hashCode() {
099        return "ClusteringFactory".hashCode();
100    }
101
102    @Override
103    public boolean equals(Object obj) {
104        return obj instanceof ClusteringFactory;
105    }
106
107    @Override
108    public OutputFactoryProvenance getProvenance() {
109        return provenance;
110    }
111
112    /**
113     * Provenance for {@link ClusteringFactory}.
114     */
115    public final static class ClusteringFactoryProvenance implements OutputFactoryProvenance {
116        private static final long serialVersionUID = 1L;
117
118        /**
119         * Creates a clustering factory provenance.
120         */
121        ClusteringFactoryProvenance() {}
122
123        /**
124         * Rebuilds a clustering factory provenance from the marshalled form.
125         * @param map The map (which should be empty).
126         */
127        public ClusteringFactoryProvenance(Map<String, Provenance> map) { }
128
129        @Override
130        public String getClassName() {
131            return ClusteringFactory.class.getName();
132        }
133
134        @Override
135        public String toString() {
136            return generateString("OutputFactory");
137        }
138
139        @Override
140        public boolean equals(Object other) {
141            return other instanceof ClusteringFactoryProvenance;
142        }
143
144        @Override
145        public int hashCode() {
146            return 31;
147        }
148    }
149}