001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.clustering.kmeans; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.Example; 021import org.tribuo.Excuse; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.ImmutableOutputInfo; 024import org.tribuo.Model; 025import org.tribuo.Prediction; 026import org.tribuo.clustering.ClusterID; 027import org.tribuo.clustering.kmeans.KMeansTrainer.Distance; 028import org.tribuo.math.la.DenseVector; 029import org.tribuo.math.la.SparseVector; 030import org.tribuo.provenance.ModelProvenance; 031 032import java.util.Collections; 033import java.util.List; 034import java.util.Map; 035import java.util.Optional; 036 037/** 038 * A K-Means model with a selectable distance function. 039 * <p> 040 * The predict method of this model assigns centres to the provided input, 041 * but it does not update the model's centroids. 042 * <p> 043 * The predict method is single threaded. 044 * <p> 045 * See: 046 * <pre> 047 * J. Friedman, T. Hastie, & R. Tibshirani. 048 * "The Elements of Statistical Learning" 049 * Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a> 050 * </pre> 051 */ 052public class KMeansModel extends Model<ClusterID> { 053 private static final long serialVersionUID = 1L; 054 055 private final DenseVector[] centroidVectors; 056 057 private final Distance distanceType; 058 059 KMeansModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<ClusterID> outputIDInfo, DenseVector[] centroidVectors, Distance distanceType) { 060 super(name,description,featureIDMap,outputIDInfo,false); 061 this.centroidVectors = centroidVectors; 062 this.distanceType = distanceType; 063 } 064 065 /** 066 * Returns a copy of the centroids. 067 * @return The centroids. 068 */ 069 public DenseVector[] getCentroidVectors() { 070 DenseVector[] copies = new DenseVector[centroidVectors.length]; 071 072 for (int i = 0; i < copies.length; i++) { 073 copies[i] = centroidVectors[i].copy(); 074 } 075 076 return copies; 077 } 078 079 @Override 080 public Prediction<ClusterID> predict(Example<ClusterID> example) { 081 SparseVector vector = SparseVector.createSparseVector(example,featureIDMap,false); 082 if (vector.numActiveElements() == 0) { 083 throw new IllegalArgumentException("No features found in Example " + example.toString()); 084 } 085 double minDistance = Double.POSITIVE_INFINITY; 086 int id = -1; 087 for (int i = 0; i < centroidVectors.length; i++) { 088 double distance; 089 switch (distanceType) { 090 case EUCLIDEAN: 091 distance = centroidVectors[i].euclideanDistance(vector); 092 break; 093 case COSINE: 094 distance = centroidVectors[i].cosineDistance(vector); 095 break; 096 case L1: 097 distance = centroidVectors[i].l1Distance(vector); 098 break; 099 default: 100 throw new IllegalStateException("Unknown distance " + distanceType); 101 } 102 if (distance < minDistance) { 103 minDistance = distance; 104 id = i; 105 } 106 } 107 return new Prediction<>(new ClusterID(id),vector.size(),example); 108 } 109 110 @Override 111 public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) { 112 return Collections.emptyMap(); 113 } 114 115 @Override 116 public Optional<Excuse<ClusterID>> getExcuse(Example<ClusterID> example) { 117 return Optional.empty(); 118 } 119 120 @Override 121 protected KMeansModel copy(String newName, ModelProvenance newProvenance) { 122 DenseVector[] newCentroids = new DenseVector[centroidVectors.length]; 123 for (int i = 0; i < centroidVectors.length; i++) { 124 newCentroids[i] = centroidVectors[i].copy(); 125 } 126 return new KMeansModel(newName,newProvenance,featureIDMap,outputIDInfo,newCentroids,distanceType); 127 } 128}