001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.anomaly.libsvm;
018
019import org.tribuo.Example;
020import org.tribuo.ImmutableFeatureMap;
021import org.tribuo.ImmutableOutputInfo;
022import org.tribuo.Prediction;
023import org.tribuo.anomaly.Event;
024import org.tribuo.common.libsvm.LibSVMModel;
025import org.tribuo.common.libsvm.LibSVMTrainer;
026import org.tribuo.provenance.ModelProvenance;
027import libsvm.svm;
028import libsvm.svm_model;
029import libsvm.svm_node;
030
031import java.util.Collections;
032import java.util.List;
033
034/**
 * An anomaly detection model that uses an underlying libSVM model to make the
036 * predictions.
037 * <p>
038 * See:
039 * <pre>
040 * Chang CC, Lin CJ.
041 * "LIBSVM: a library for Support Vector Machines"
042 * ACM transactions on intelligent systems and technology (TIST), 2011.
043 * </pre>
044 * <p>
045 * and for the anomaly detection algorithm:
046 * <pre>
047 * Schölkopf B, Platt J, Shawe-Taylor J, Smola A J, Williamson R C.
048 * "Estimating the support of a high-dimensional distribution"
049 * Neural Computation, 2001, 1443-1471.
050 * </pre>
051 */
052public class LibSVMAnomalyModel extends LibSVMModel<Event> {
053    private static final long serialVersionUID = 1L;
054
055    LibSVMAnomalyModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Event> labelIDMap, List<svm_model> models) {
056        super(name, description, featureIDMap, labelIDMap, models.get(0).param.probability == 1, models);
057    }
058
059    /**
060     * Returns the number of support vectors.
061     * @return The number of support vectors.
062     */
063    public int getNumberOfSupportVectors() {
064        return models.get(0).SV.length;
065    }
066
067    @Override
068    public Prediction<Event> predict(Example<Event> example) {
069        svm_node[] features = LibSVMTrainer.exampleToNodes(example, featureIDMap, null);
070        // Bias feature is always set by the library.
071        if (features.length == 0) {
072            throw new IllegalArgumentException("No features found in Example " + example.toString());
073        }
074        double[] score = new double[1];
075        double prediction = svm.svm_predict_values(models.get(0), features, score);
076        if (prediction < 0.0) {
077            return new Prediction<>(new Event(Event.EventType.ANOMALOUS,score[0]),features.length,example);
078        } else {
079            return new Prediction<>(new Event(Event.EventType.EXPECTED,score[0]),features.length,example);
080        }
081    }
082
083    @Override
084    protected LibSVMAnomalyModel copy(String newName, ModelProvenance newProvenance) {
085        return new LibSVMAnomalyModel(newName,newProvenance,featureIDMap,outputIDInfo, Collections.singletonList(LibSVMModel.copyModel(models.get(0))));
086    }
087
088}