001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.clustering.kmeans; 018 019import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; 020import com.oracle.labs.mlrg.olcut.config.Option; 021import com.oracle.labs.mlrg.olcut.config.Options; 022import com.oracle.labs.mlrg.olcut.config.UsageException; 023import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; 024import com.oracle.labs.mlrg.olcut.util.Pair; 025import org.tribuo.Dataset; 026import org.tribuo.Model; 027import org.tribuo.clustering.ClusterID; 028import org.tribuo.clustering.ClusteringFactory; 029import org.tribuo.clustering.evaluation.ClusteringEvaluation; 030import org.tribuo.clustering.kmeans.KMeansTrainer.Distance; 031import org.tribuo.data.DataOptions; 032 033import java.io.IOException; 034import java.util.logging.Logger; 035 036 037/** 038 * Build and run a k-means clustering model for a standard dataset. 039 */ 040public class TrainTest { 041 042 private static final Logger logger = Logger.getLogger(TrainTest.class.getName()); 043 044 /** 045 * Options for the K-Means CLI. 046 */ 047 public static class KMeansOptions implements Options { 048 @Override 049 public String getOptionsDescription() { 050 return "Trains and evaluates a K-Means model on the specified dataset."; 051 } 052 public DataOptions general; 053 054 @Option(charName='n',longName="num-clusters",usage="Number of clusters to infer. Defaults to 5.") 055 public int centroids = 5; 056 @Option(charName='i',longName="iterations",usage="Maximum number of iterations. Defaults to 10.") 057 public int iterations = 10; 058 @Option(charName='d',longName="distance",usage="Distance function to use in the e step. Defaults to EUCLIDEAN.") 059 public Distance distance = Distance.EUCLIDEAN; 060 @Option(charName='t',longName="num-threads",usage="Number of threads to use (default 4, range (1, num hw threads)).") 061 public int numThreads = 4; 062 } 063 064 /** 065 * @param args the command line arguments 066 * @throws IOException if there is any error reading the examples. 067 */ 068 public static void main(String[] args) throws IOException { 069 // 070 // Use the labs format logging. 071 LabsLogFormatter.setAllLogFormatters(); 072 073 KMeansOptions o = new KMeansOptions(); 074 ConfigurationManager cm; 075 try { 076 cm = new ConfigurationManager(args,o); 077 } catch (UsageException e) { 078 logger.info(e.getMessage()); 079 return; 080 } 081 082 if (o.general.trainingPath == null) { 083 logger.info(cm.usage()); 084 return; 085 } 086 087 ClusteringFactory factory = new ClusteringFactory(); 088 089 Pair<Dataset<ClusterID>,Dataset<ClusterID>> data = o.general.load(factory); 090 Dataset<ClusterID> train = data.getA(); 091 092 //public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, int seed) 093 KMeansTrainer trainer = new KMeansTrainer(o.centroids,o.iterations,o.distance,o.numThreads,o.general.seed); 094 Model<ClusterID> model = trainer.train(train); 095 logger.info("Finished training model"); 096 ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model,train); 097 logger.info("Finished evaluating model"); 098 System.out.println("Normalized MI = " + evaluation.normalizedMI()); 099 System.out.println("Adjusted MI = " + evaluation.adjustedMI()); 100 101 if (o.general.outputPath != null) { 102 o.general.saveModel(model); 103 } 104 } 105}