001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering.kmeans;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.Example;
021import org.tribuo.Excuse;
022import org.tribuo.ImmutableFeatureMap;
023import org.tribuo.ImmutableOutputInfo;
024import org.tribuo.Model;
025import org.tribuo.Prediction;
026import org.tribuo.clustering.ClusterID;
027import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;
028import org.tribuo.math.la.DenseVector;
029import org.tribuo.math.la.SparseVector;
030import org.tribuo.provenance.ModelProvenance;
031
032import java.util.Collections;
033import java.util.List;
034import java.util.Map;
035import java.util.Optional;
036
037/**
038 * A K-Means model with a selectable distance function.
039 * <p>
040 * The predict method of this model assigns centres to the provided input,
041 * but it does not update the model's centroids.
042 * <p>
043 * The predict method is single threaded.
044 * <p>
045 * See:
046 * <pre>
047 * J. Friedman, T. Hastie, &amp; R. Tibshirani.
048 * "The Elements of Statistical Learning"
049 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
050 * </pre>
051 */
052public class KMeansModel extends Model<ClusterID> {
053    private static final long serialVersionUID = 1L;
054
055    private final DenseVector[] centroidVectors;
056
057    private final Distance distanceType;
058
059    KMeansModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<ClusterID> outputIDInfo, DenseVector[] centroidVectors, Distance distanceType) {
060        super(name,description,featureIDMap,outputIDInfo,false);
061        this.centroidVectors = centroidVectors;
062        this.distanceType = distanceType;
063    }
064
065    /**
066     * Returns a copy of the centroids.
067     * @return The centroids.
068     */
069    public DenseVector[] getCentroidVectors() {
070        DenseVector[] copies = new DenseVector[centroidVectors.length];
071
072        for (int i = 0; i < copies.length; i++) {
073            copies[i] = centroidVectors[i].copy();
074        }
075
076        return copies;
077    }
078
079    @Override
080    public Prediction<ClusterID> predict(Example<ClusterID> example) {
081        SparseVector vector = SparseVector.createSparseVector(example,featureIDMap,false);
082        if (vector.numActiveElements() == 0) {
083            throw new IllegalArgumentException("No features found in Example " + example.toString());
084        }
085        double minDistance = Double.POSITIVE_INFINITY;
086        int id = -1;
087        for (int i = 0; i < centroidVectors.length; i++) {
088            double distance;
089            switch (distanceType) {
090                case EUCLIDEAN:
091                    distance = centroidVectors[i].euclideanDistance(vector);
092                    break;
093                case COSINE:
094                    distance = centroidVectors[i].cosineDistance(vector);
095                    break;
096                case L1:
097                    distance = centroidVectors[i].l1Distance(vector);
098                    break;
099                default:
100                    throw new IllegalStateException("Unknown distance " + distanceType);
101            }
102            if (distance < minDistance) {
103                minDistance = distance;
104                id = i;
105            }
106        }
107        return new Prediction<>(new ClusterID(id),vector.size(),example);
108    }
109
110    @Override
111    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
112        return Collections.emptyMap();
113    }
114
115    @Override
116    public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) {
117        return Optional.empty();
118    }
119
120    @Override
121    protected KMeansModel copy(String newName, ModelProvenance newProvenance) {
122        DenseVector[] newCentroids = new DenseVector[centroidVectors.length];
123        for (int i = 0; i < centroidVectors.length; i++) {
124            newCentroids[i] = centroidVectors[i].copy();
125        }
126        return new KMeansModel(newName,newProvenance,featureIDMap,outputIDInfo,newCentroids,distanceType);
127    }
128}