/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.anomaly.libsvm;

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.anomaly.Event;
import org.tribuo.anomaly.Event.EventType;
import org.tribuo.common.libsvm.LibSVMModel;
import org.tribuo.common.libsvm.LibSVMTrainer;
import org.tribuo.common.libsvm.SVMParameters;
import org.tribuo.provenance.ModelProvenance;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * A trainer for anomaly models that uses LibSVM.
 * <p>
 * See:
 * <pre>
 * Chang CC, Lin CJ.
 * "LIBSVM: a library for Support Vector Machines"
 * ACM transactions on intelligent systems and technology (TIST), 2011.
 * </pre>
 * <p>
 * and for the anomaly detection algorithm:
 * <pre>
 * Schölkopf B, Platt J, Shawe-Taylor J, Smola A J, Williamson R C.
 * "Estimating the support of a high-dimensional distribution"
 * Neural Computation, 2001, 1443-1471.
 * </pre>
 */
public class LibSVMAnomalyTrainer extends LibSVMTrainer<Event> {
    private static final Logger logger = Logger.getLogger(LibSVMAnomalyTrainer.class.getName());

    /**
     * For OLCUT.
     */
    protected LibSVMAnomalyTrainer() {}

    /**
     * Creates a one-class LibSVM trainer using the supplied parameters.
     * @param parameters The training parameters.
     */
    public LibSVMAnomalyTrainer(SVMParameters<Event> parameters) {
        super(parameters);
    }

    /**
     * Used by the OLCUT configuration system, and should not be called by external code.
     * <p>
     * Runs the superclass configuration step, then rejects any SVM type
     * which does not report itself as an anomaly detection type.
     * @throws IllegalArgumentException If the configured SVM type is not an anomaly type.
     */
    @Override
    public void postConfig() {
        super.postConfig();
        if (svmType.isAnomaly()) {
            return;
        }
        throw new IllegalArgumentException("Supplied classification or regression parameters to an anomaly detection SVM.");
    }

    /**
     * Trains an anomaly detection model after checking that the dataset's
     * output counts contain no {@link EventType#ANOMALOUS} events.
     * @param dataset The training dataset.
     * @param instanceProvenance Provenance supplied for this training run, passed through to the superclass.
     * @return The model produced by the superclass training method.
     * @throws IllegalArgumentException If the dataset has a positive count of anomalous events.
     */
    @Override
    public LibSVMModel<Event> train(Dataset<Event> dataset, Map<String, Provenance> instanceProvenance) {
        final String anomalousName = EventType.ANOMALOUS.toString();
        for (Pair<String,Long> outputCount : dataset.getOutputInfo().outputCountsIterable()) {
            boolean isAnomalousEntry = outputCount.getA().equals(anomalousName);
            // The count is only inspected for the anomalous entry.
            if (isAnomalousEntry && (outputCount.getB() > 0)) {
                throw new IllegalArgumentException("LibSVMAnomalyTrainer only supports EXPECTED events at training time.");
            }
        }
        return super.train(dataset, instanceProvenance);
    }

    /**
     * Wraps the trained LibSVM models in a {@link LibSVMAnomalyModel}.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param models The trained LibSVM models.
     * @return A new anomaly detection model.
     */
    @Override
    protected LibSVMModel<Event> createModel(ModelProvenance provenance,
                                             ImmutableFeatureMap featureIDMap,
                                             ImmutableOutputInfo<Event> outputIDInfo,
                                             List<svm_model> models) {
        return new LibSVMAnomalyModel("svm-anomaly-detection-model", provenance, featureIDMap, outputIDInfo, models);
    }

    /**
     * Builds an {@code svm_problem} from the features and the first row of
     * {@code outputs}, validates the parameters, and trains a single model.
     * <p>
     * Note this method writes to {@code curParams}: when its gamma is exactly
     * zero, gamma is replaced with {@code 1.0 / numFeatures} before validation.
     * @param curParams The LibSVM parameters (gamma may be updated in place).
     * @param numFeatures The number of features, used for the default gamma.
     * @param features The training features, one row per example.
     * @param outputs The training outputs, only {@code outputs[0]} is read.
     * @return A singleton list containing the trained model.
     * @throws IllegalArgumentException If LibSVM's parameter check returns an error message.
     */
    @Override
    protected List<svm_model> trainModels(svm_parameter curParams, int numFeatures, svm_node[][] features, double[][] outputs) {
        svm_problem svmProblem = new svm_problem();
        final double[] targets = outputs[0];
        svmProblem.l = targets.length;
        svmProblem.x = features;
        svmProblem.y = targets;

        if (curParams.gamma == 0) {
            curParams.gamma = 1.0 / numFeatures;
        }

        final String parameterError = svm.svm_check_parameter(svmProblem, curParams);
        if (parameterError == null) {
            return Collections.singletonList(svm.svm_train(svmProblem, curParams));
        }
        throw new IllegalArgumentException("Error checking SVM parameters: " + parameterError);
    }

    /**
     * Converts a dataset into the LibSVM feature rows and a single row of
     * output values, one entry per example in iteration order.
     * @param data The dataset to convert.
     * @param outputInfo The output domain (not read by this implementation).
     * @param featureMap The feature domain used to convert each example.
     * @return A pair of the feature rows and a {@code 1 x data.size()} output array.
     */
    @Override
    protected Pair<svm_node[][], double[][]> extractData(Dataset<Event> data, ImmutableOutputInfo<Event> outputInfo, ImmutableFeatureMap featureMap) {
        final int numExamples = data.size();
        final svm_node[][] featureRows = new svm_node[numExamples][];
        final double[][] targets = new double[1][numExamples];
        // The same list instance is handed to every exampleToNodes call.
        final List<svm_node> nodeBuffer = new ArrayList<>();

        int row = 0;
        for (Example<Event> example : data) {
            targets[0][row] = extractOutput(example.getOutput());
            featureRows[row] = exampleToNodes(example, featureMap, nodeBuffer);
            row++;
        }
        return new Pair<>(featureRows, targets);
    }

    /**
     * Converts an output into a double for use in training.
     * <p>
     * By convention {@link EventType#EXPECTED} is 1.0, other events are -1.0.
     * @param output The output to convert.
     * @return The double value.
     */
    protected double extractOutput(Event output) {
        return (output.getType() == Event.EventType.EXPECTED) ? 1.0 : -1.0;
    }
}