001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.anomaly.libsvm;
018
019import com.oracle.labs.mlrg.olcut.provenance.Provenance;
020import com.oracle.labs.mlrg.olcut.util.Pair;
021import org.tribuo.Dataset;
022import org.tribuo.Example;
023import org.tribuo.ImmutableFeatureMap;
024import org.tribuo.ImmutableOutputInfo;
025import org.tribuo.anomaly.Event;
026import org.tribuo.anomaly.Event.EventType;
027import org.tribuo.common.libsvm.LibSVMModel;
028import org.tribuo.common.libsvm.LibSVMTrainer;
029import org.tribuo.common.libsvm.SVMParameters;
030import org.tribuo.provenance.ModelProvenance;
031import libsvm.svm;
032import libsvm.svm_model;
033import libsvm.svm_node;
034import libsvm.svm_parameter;
035import libsvm.svm_problem;
036
037import java.util.ArrayList;
038import java.util.Collections;
039import java.util.List;
040import java.util.Map;
041import java.util.logging.Logger;
042
043/**
044 * A trainer for anomaly models that uses LibSVM.
045 * <p>
046 * See:
047 * <pre>
048 * Chang CC, Lin CJ.
049 * "LIBSVM: a library for Support Vector Machines"
050 * ACM transactions on intelligent systems and technology (TIST), 2011.
051 * </pre>
052 * <p>
053 * and for the anomaly detection algorithm:
054 * <pre>
055 * Schölkopf B, Platt J, Shawe-Taylor J, Smola A J, Williamson R C.
056 * "Estimating the support of a high-dimensional distribution"
057 * Neural Computation, 2001, 1443-1471.
058 * </pre>
059 */
060public class LibSVMAnomalyTrainer extends LibSVMTrainer<Event> {
061    private static final Logger logger = Logger.getLogger(LibSVMAnomalyTrainer.class.getName());
062
063    /**
064     * For OLCUT.
065     */
066    protected LibSVMAnomalyTrainer() {}
067
068    /**
069     * Creates a one-class LibSVM trainer using the supplied parameters.
070     * @param parameters The training parameters.
071     */
072    public LibSVMAnomalyTrainer(SVMParameters<Event> parameters) {
073        super(parameters);
074    }
075
076    /**
077     * Used by the OLCUT configuration system, and should not be called by external code.
078     */
079    @Override
080    public void postConfig() {
081        super.postConfig();
082        if (!svmType.isAnomaly()) {
083            throw new IllegalArgumentException("Supplied classification or regression parameters to an anomaly detection SVM.");
084        }
085    }
086
087    @Override
088    public LibSVMModel<Event> train(Dataset<Event> dataset, Map<String, Provenance> instanceProvenance) {
089        for (Pair<String,Long> p : dataset.getOutputInfo().outputCountsIterable()) {
090            if (p.getA().equals(EventType.ANOMALOUS.toString()) && (p.getB() > 0)) {
091                throw new IllegalArgumentException("LibSVMAnomalyTrainer only supports EXPECTED events at training time.");
092            }
093        }
094        return super.train(dataset,instanceProvenance);
095    }
096
097    @Override
098    protected LibSVMModel<Event> createModel(ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Event> outputIDInfo, List<svm_model> models) {
099        return new LibSVMAnomalyModel("svm-anomaly-detection-model", provenance, featureIDMap, outputIDInfo, models);
100    }
101
102    @Override
103    protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) {
104        svm_problem problem = new svm_problem();
105        problem.l = outputs[0].length;
106        problem.x = features;
107        problem.y = outputs[0];
108        if (curParams.gamma == 0) {
109            curParams.gamma = 1.0 / numFeatures;
110        }
111        String checkString = svm.svm_check_parameter(problem, curParams);
112        if(checkString != null) {
113            throw new IllegalArgumentException("Error checking SVM parameters: " + checkString);
114        }
115        return Collections.singletonList(svm.svm_train(problem, curParams));
116    }
117
118    @Override
119    protected Pair<svm_node[][], double[][]> extractData(Dataset<Event> data, ImmutableOutputInfo<Event> outputInfo, ImmutableFeatureMap featureMap) {
120        double[][] ys = new double[1][data.size()];
121        svm_node[][] xs = new svm_node[data.size()][];
122        List<svm_node> buffer = new ArrayList<>();
123        int i = 0;
124        for (Example<Event> example : data) {
125            ys[0][i] = extractOutput(example.getOutput());
126            xs[i] = exampleToNodes(example, featureMap, buffer);
127            i++;
128        }
129        return new Pair<>(xs,ys);
130    }
131
132    /**
133     * Converts an output into a double for use in training.
134     * <p>
135     * By convention {@link EventType#EXPECTED} is 1.0, other events are -1.0.
136     * @param output The output to convert.
137     * @return The double value.
138     */
139    protected double extractOutput(Event output) {
140        if (output.getType() == Event.EventType.EXPECTED) {
141            return 1.0;
142        } else {
143            return -1.0;
144        }
145    }
146}