001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.clustering.kmeans;
018
019import com.oracle.labs.mlrg.olcut.config.Option;
020import com.oracle.labs.mlrg.olcut.config.Options;
021import org.tribuo.Trainer;
022import org.tribuo.clustering.kmeans.KMeansTrainer.Distance;
023
024import java.util.logging.Logger;
025
026/**
027 * OLCUT {@link Options} for the K-Means implementation.
028 */
029public class KMeansOptions implements Options {
030    private static final Logger logger = Logger.getLogger(KMeansOptions.class.getName());
031
032    @Option(longName="kmeans-interations",usage="Iterations of the k-means algorithm. Defaults to 10.")
033    public int iterations = 10;
034    @Option(longName="kmeans-num-centroids",usage="Number of centroids in K-Means. Defaults to 10.")
035    public int centroids = 10;
036    @Option(longName="kmeans-distance",usage="Distance function in K-Means. Defaults to EUCLIDEAN.")
037    public Distance distance = Distance.EUCLIDEAN;
038    @Option(longName="kmeans-num-threads",usage="Number of computation threads in K-Means. Defaults to 4.")
039    public int numThreads = 4;
040    @Option(longName="kmeans-seed", usage = "Sets the random seed for K-Means.")
041    private long seed = Trainer.DEFAULT_SEED;
042
043    public KMeansTrainer getTrainer() {
044        logger.info("Configuring K-Means Trainer");
045        //public KMeansTrainer(int centroids, int iterations, Distance distanceType, int numThreads, int seed) {
046        return new KMeansTrainer(centroids,iterations,distance,numThreads,seed);
047    }
048}