001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.anomaly.libsvm; 018 019import org.tribuo.Example; 020import org.tribuo.ImmutableFeatureMap; 021import org.tribuo.ImmutableOutputInfo; 022import org.tribuo.Prediction; 023import org.tribuo.anomaly.Event; 024import org.tribuo.common.libsvm.LibSVMModel; 025import org.tribuo.common.libsvm.LibSVMTrainer; 026import org.tribuo.provenance.ModelProvenance; 027import libsvm.svm; 028import libsvm.svm_model; 029import libsvm.svm_node; 030 031import java.util.Collections; 032import java.util.List; 033 034/** 035 * An anomaly detection model that uses an underlying libSVM model to make the 036 * predictions. 037 * <p> 038 * See: 039 * <pre> 040 * Chang CC, Lin CJ. 041 * "LIBSVM: a library for Support Vector Machines" 042 * ACM transactions on intelligent systems and technology (TIST), 2011. 043 * </pre> 044 * <p> 045 * and for the anomaly detection algorithm: 046 * <pre> 047 * Schölkopf B, Platt J, Shawe-Taylor J, Smola A J, Williamson R C. 048 * "Estimating the support of a high-dimensional distribution" 049 * Neural Computation, 2001, 1443-1471. 
050 * </pre> 051 */ 052public class LibSVMAnomalyModel extends LibSVMModel<Event> { 053 private static final long serialVersionUID = 1L; 054 055 LibSVMAnomalyModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Event> labelIDMap, List<svm_model> models) { 056 super(name, description, featureIDMap, labelIDMap, models.get(0).param.probability == 1, models); 057 } 058 059 /** 060 * Returns the number of support vectors. 061 * @return The number of support vectors. 062 */ 063 public int getNumberOfSupportVectors() { 064 return models.get(0).SV.length; 065 } 066 067 @Override 068 public Prediction<Event> predict(Example<Event> example) { 069 svm_node[] features = LibSVMTrainer.exampleToNodes(example, featureIDMap, null); 070 // Bias feature is always set by the library. 071 if (features.length == 0) { 072 throw new IllegalArgumentException("No features found in Example " + example.toString()); 073 } 074 double[] score = new double[1]; 075 double prediction = svm.svm_predict_values(models.get(0), features, score); 076 if (prediction < 0.0) { 077 return new Prediction<>(new Event(Event.EventType.ANOMALOUS,score[0]),features.length,example); 078 } else { 079 return new Prediction<>(new Event(Event.EventType.EXPECTED,score[0]),features.length,example); 080 } 081 } 082 083 @Override 084 protected LibSVMAnomalyModel copy(String newName, ModelProvenance newProvenance) { 085 return new LibSVMAnomalyModel(newName,newProvenance,featureIDMap,outputIDInfo, Collections.singletonList(LibSVMModel.copyModel(models.get(0)))); 086 } 087 088}