001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.common.nearest;
018
019import com.oracle.labs.mlrg.olcut.config.ArgumentException;
020import com.oracle.labs.mlrg.olcut.config.Option;
021import org.tribuo.classification.ClassificationOptions;
022import org.tribuo.classification.Label;
023import org.tribuo.classification.ensemble.FullyWeightedVotingCombiner;
024import org.tribuo.classification.ensemble.VotingCombiner;
025import org.tribuo.common.nearest.KNNModel.Backend;
026import org.tribuo.common.nearest.KNNTrainer.Distance;
027import org.tribuo.ensemble.EnsembleCombiner;
028
029/**
030 * CLI Options for training a k-nearest neighbour predictor.
031 */
032public class KNNClassifierOptions implements ClassificationOptions<KNNTrainer<Label>> {
033
034    /**
035     * The type of combination function.
036     */
037    public enum EnsembleCombinerType { VOTING, FULLY_WEIGHTED_VOTING}
038
039    @Option(longName="knn-k",usage="K nearest neighbours to use. Defaults to 1.")
040    public int knnK = 1;
041    @Option(longName="knn-num-threads",usage="Number of threads to use. Defaults to 1.")
042    public int knnNumThreads = 1;
043    @Option(longName="knn-distance",usage="Distance metric to use. Defaults to L2.")
044    public Distance knnDistance = Distance.L2;
045    @Option(longName="knn-backend",usage="Parallel backend to use.")
046    public Backend knnBackend = Backend.STREAMS;
047    @Option(longName="knn-voting",usage="Parallel backend to use.")
048    public EnsembleCombinerType knnEnsembleCombiner = EnsembleCombinerType.VOTING;
049
050    @Override
051    public String getOptionsDescription() {
052        return "Options for parameterising a LibLinear classification trainer.";
053    }
054
055    private EnsembleCombiner<Label> getEnsembleCombiner() {
056        switch (knnEnsembleCombiner) {
057            case VOTING:
058                return new VotingCombiner();
059            case FULLY_WEIGHTED_VOTING:
060                return new FullyWeightedVotingCombiner();
061            default:
062                throw new ArgumentException("ensemble combiner", "Unknown ensemble combiner " + knnEnsembleCombiner);
063        }
064    }
065
066    @Override
067    public KNNTrainer<Label> getTrainer() {
068        return new KNNTrainer<>(knnK, knnDistance, knnNumThreads, getEnsembleCombiner(), knnBackend);
069    }
070}