001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.explanations.lime; 018 019import com.oracle.labs.mlrg.olcut.util.Pair; 020import org.tribuo.CategoricalInfo; 021import org.tribuo.Example; 022import org.tribuo.ImmutableFeatureMap; 023import org.tribuo.Model; 024import org.tribuo.MutableDataset; 025import org.tribuo.OutputFactory; 026import org.tribuo.Prediction; 027import org.tribuo.RealInfo; 028import org.tribuo.SparseModel; 029import org.tribuo.SparseTrainer; 030import org.tribuo.VariableIDInfo; 031import org.tribuo.VariableInfo; 032import org.tribuo.WeightedExamples; 033import org.tribuo.classification.Label; 034import org.tribuo.classification.LabelFactory; 035import org.tribuo.classification.explanations.TabularExplainer; 036import org.tribuo.impl.ArrayExample; 037import org.tribuo.interop.ExternalModel; 038import org.tribuo.math.la.SparseVector; 039import org.tribuo.math.la.VectorTuple; 040import org.tribuo.provenance.SimpleDataSourceProvenance; 041import org.tribuo.regression.RegressionFactory; 042import org.tribuo.regression.Regressor; 043import org.tribuo.regression.evaluation.RegressionEvaluation; 044import org.tribuo.regression.evaluation.RegressionEvaluator; 045import org.tribuo.util.Util; 046 047import java.time.OffsetDateTime; 048import java.util.ArrayList; 049import java.util.Iterator; 050import java.util.List; 051import java.util.Map; 052import java.util.Random; 053import java.util.SplittableRandom; 054import java.util.logging.Logger; 055 056/** 057 * LIMEBase merges the lime_base.py and lime_tabular.py implementations, and deals with simple 058 * matrices of numerical or categorical data. If you want a mixture of text, numerical 059 * and categorical data try {@link LIMEColumnar}. For plain text data use {@link LIMEText}. 060 * <p> 061 * See: 062 * <pre> 063 * Ribeiro MT, Singh S, Guestrin C. 064 * "Why should I trust you?: Explaining the predictions of any classifier" 065 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining 2016. 066 * </pre> 067 */ 068public class LIMEBase implements TabularExplainer<Regressor> { 069 private static final Logger logger = Logger.getLogger(LIMEBase.class.getName()); 070 071 public static final double WIDTH_CONSTANT = 0.75; 072 public static final double DISTANCE_DELTA = 1e-12; 073 074 protected static final OutputFactory<Regressor> regressionFactory = new RegressionFactory(); 075 protected static final RegressionEvaluator evaluator = new RegressionEvaluator(true); 076 077 protected final SplittableRandom rng; 078 079 protected final Model<Label> innerModel; 080 081 protected final SparseTrainer<Regressor> explanationTrainer; 082 083 protected final int numSamples; 084 085 protected final long numTrainingExamples; 086 087 protected final double kernelWidth; 088 089 private final ImmutableFeatureMap fMap; 090 091 /** 092 * Constructs a LIME explainer for a model which uses tabular data (i.e., no special treatment for text features). 093 * @param rng The rng to use for sampling. 094 * @param innerModel The model to explain. 095 * @param explanationTrainer The sparse trainer used to explain predictions. 096 * @param numSamples The number of samples to generate for an explanation. 097 */ 098 public LIMEBase(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples) { 099 if (!(explanationTrainer instanceof WeightedExamples)) { 100 throw new IllegalArgumentException("SparseTrainer must implement WeightedExamples, found " + explanationTrainer.toString()); 101 } 102 if (!innerModel.generatesProbabilities()) { 103 throw new IllegalArgumentException("LIME requires the model generate probabilities."); 104 } 105 if (innerModel instanceof ExternalModel) { 106 throw new IllegalArgumentException("LIME requires the model to have been trained in Tribuo. Found " + innerModel.getClass() + " which is an external model."); 107 } 108 this.rng = rng; 109 this.innerModel = innerModel; 110 this.explanationTrainer = explanationTrainer; 111 this.numSamples = numSamples; 112 this.numTrainingExamples = innerModel.getOutputIDInfo().getTotalObservations(); 113 this.kernelWidth = Math.pow(innerModel.getFeatureIDMap().size() * WIDTH_CONSTANT, 2); 114 this.fMap = innerModel.getFeatureIDMap(); 115 } 116 117 @Override 118 public LIMEExplanation explain(Example<Label> example) { 119 return explainWithSamples(example).getA(); 120 } 121 122 protected Pair<LIMEExplanation,List<Example<Regressor>>> explainWithSamples(Example<Label> example) { 123 // Predict using the full model, and generate a new example containing that prediction. 124 Prediction<Label> prediction = innerModel.predict(example); 125 Example<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction),example,1.0f); 126 127 // Sample a dataset. 128 List<Example<Regressor>> sample = sampleData(example); 129 130 // Generate a sparse model on the sampled data. 131 SparseModel<Regressor> model = trainExplainer(labelledExample,sample); 132 133 // Test the sparse model against the predictions of the real model. 134 List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample)); 135 predictions.add(model.predict(labelledExample)); 136 RegressionEvaluation evaluation = evaluator.evaluate(model,predictions,new SimpleDataSourceProvenance("LIMEColumnar sampled data",regressionFactory)); 137 138 return new Pair<>(new LIMEExplanation(model,prediction,evaluation),sample); 139 } 140 141 /** 142 * Sample a dataset based on the input example. 143 * <p> 144 * The sampled dataset uses the feature dimensions from the {@link Model}. 145 * <p> 146 * The outputs are the probability values of each class from the underlying Model, 147 * rather than ground truth outputs. The distance is measured using the 148 * {@link LIMEBase#measureDistance} function, transformed through a kernel and used 149 * as the sampled Example's weight. 150 * @param example The example to sample from. 151 * @return A sampled dataset. 152 */ 153 private List<Example<Regressor>> sampleData(Example<Label> example) { 154 List<Example<Regressor>> output = new ArrayList<>(); 155 156 SparseVector exampleVector = SparseVector.createSparseVector(example,fMap,false); 157 158 Random innerRNG = new Random(rng.nextLong()); 159 for (int i = 0; i < numSamples; i++) { 160 // Sample a new Example. 161 Example<Label> sample = samplePoint(innerRNG,fMap,numTrainingExamples,exampleVector); 162 163 //logger.fine("Itr " + i + " sampled " + sample.toString()); 164 165 // Label it using the full model. 166 Prediction<Label> samplePrediction = innerModel.predict(sample); 167 168 // Measure the distance between this point and the input, to be used as a weight. 169 double distance = measureDistance(fMap,numTrainingExamples,exampleVector, SparseVector.createSparseVector(sample,fMap,false)); 170 171 // Transform distance through the kernel function. 172 distance = kernelDist(distance,kernelWidth); 173 174 // Generate the new sample with the appropriate label and weight. 175 Example<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction),sample,(float)distance); 176 output.add(labelledSample); 177 } 178 179 return output; 180 } 181 182 /** 183 * Samples a single example from the supplied feature map and input vector. 184 * @param rng The rng to use. 185 * @param fMap The feature map describing the domain of the features. 186 * @param numTrainingExamples The number of training examples the fMap has seen. 187 * @param input The input sparse vector to use. 188 * @return An Example sampled from the supplied feature map and input vector. 189 */ 190 public static Example<Label> samplePoint(Random rng, ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input) { 191 ArrayList<String> names = new ArrayList<>(); 192 ArrayList<Double> values = new ArrayList<>(); 193 194 for (VariableInfo info : fMap) { 195 int id = ((VariableIDInfo)info).getID(); 196 double inputValue = input.get(id); 197 198 if (info instanceof CategoricalInfo) { 199 // This one is tricksy as categorical info essentially implicitly includes a zero. 200 CategoricalInfo catInfo = (CategoricalInfo) info; 201 double sample = catInfo.frequencyBasedSample(rng,numTrainingExamples); 202 // If we didn't sample zero. 203 if (Math.abs(sample) > 1e-10) { 204 names.add(info.getName()); 205 values.add(sample); 206 } 207 } else if (info instanceof RealInfo) { 208 RealInfo realInfo = (RealInfo) info; 209 // As realInfo is sparse we sample from the mixture distribution, 210 // either 0 or N(inputValue,variance). 211 // This assumes realInfo never observed a zero, which is enforced from v2.1 212 // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab? 213 // If it's not zero do we want to? 214 int count = realInfo.getCount(); 215 double threshold = count / ((double)numTrainingExamples); 216 if (rng.nextDouble() < threshold) { 217 double variance = realInfo.getVariance(); 218 double sample = (rng.nextGaussian() * Math.sqrt(variance)) + inputValue; 219 names.add(info.getName()); 220 values.add(sample); 221 } 222 } else { 223 throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName()); 224 } 225 } 226 227 return new ArrayExample<>(LabelFactory.UNKNOWN_LABEL,names.toArray(new String[0]),Util.toPrimitiveDouble(values)); 228 } 229 230 /** 231 * Trains the explanation model using the supplied sampled data and the input example. 232 * <p> 233 * The labels are usually the predicted probabilities from the real model. 234 * @param target The input example to explain. 235 * @param samples The sampled data around the input. 236 * @return An explanation model. 237 */ 238 protected SparseModel<Regressor> trainExplainer(Example<Regressor> target, List<Example<Regressor>> samples) { 239 MutableDataset<Regressor> explanationDataset = new MutableDataset<>(new SimpleDataSourceProvenance("explanationDataset", OffsetDateTime.now(), regressionFactory), regressionFactory); 240 explanationDataset.add(target); 241 explanationDataset.addAll(samples); 242 243 SparseModel<Regressor> explainer = explanationTrainer.train(explanationDataset); 244 245 return explainer; 246 } 247 248 /** 249 * Calculates an RBF kernel of a specific width. 250 * @param input The input value. 251 * @param width The width of the kernel. 252 * @return sqrt ( exp ( - input*input / width)) 253 */ 254 public static double kernelDist(double input, double width) { 255 return Math.sqrt(Math.exp(-(input*input) / width)); 256 } 257 258 /** 259 * Measures the distance between an input point and a sampled point. 260 * <p> 261 * This distance function takes into account categorical and real values. It uses 262 * the hamming distance for categoricals and the euclidean distance for real values. 263 * @param fMap The feature map used to determine if a feature is categorical or real. 264 * @param numTrainingExamples The number of training examples the fMap has seen. 265 * @param input The input point. 266 * @param sample The sampled point. 267 * @return The distance between the two points. 268 */ 269 public static double measureDistance(ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input, SparseVector sample) { 270 double score = 0.0; 271 272 Iterator<VectorTuple> itr = input.iterator(); 273 Iterator<VectorTuple> otherItr = sample.iterator(); 274 VectorTuple tuple; 275 VectorTuple otherTuple; 276 while (itr.hasNext() && otherItr.hasNext()) { 277 tuple = itr.next(); 278 otherTuple = otherItr.next(); 279 //after this loop, either itr is out or tuple.index >= otherTuple.index 280 while (itr.hasNext() && (tuple.index < otherTuple.index)) { 281 score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value); 282 tuple = itr.next(); 283 } 284 //after this loop, either otherItr is out or tuple.index <= otherTuple.index 285 while (otherItr.hasNext() && (tuple.index > otherTuple.index)) { 286 score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value); 287 otherTuple = otherItr.next(); 288 } 289 if (tuple.index == otherTuple.index) { 290 //the indices line up, do the calculation. 291 score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value,otherTuple.value); 292 } else { 293 // Now consume both the values as they'll be gone next iteration. 294 // Consume the value in tuple. 295 score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value); 296 // Consume the value in otherTuple. 297 score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value); 298 } 299 } 300 while (itr.hasNext()) { 301 tuple = itr.next(); 302 score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value); 303 } 304 while (otherItr.hasNext()) { 305 otherTuple = otherItr.next(); 306 score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value); 307 } 308 309 return Math.sqrt(score); 310 } 311 312 /** 313 * Calculates the distance between two values for a single feature. 314 * <p> 315 * Assumes the other value is zero as the example is sparse. 316 * @param fMap The feature map which knows if a feature is categorical or real. 317 * @param numTrainingExamples The number of training examples this feature map observed. 318 * @param index The id number for this feature. 319 * @param value One feature value. 320 * @return The distance from zero to the supplied value. 321 */ 322 private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double value) { 323 VariableInfo info = fMap.get(index); 324 if (info instanceof CategoricalInfo) { 325 return 1.0; 326 } else if (info instanceof RealInfo) { 327 RealInfo rInfo = (RealInfo) info; 328 // Fudge the distance calculation so it doesn't overpower the categoricals. 329 double curScore = value * value; 330 double range; 331 // This further fudge is because the RealInfo may have observed a zero if it's sparse, but it might not. 332 if (numTrainingExamples != info.getCount()) { 333 range = Math.max(rInfo.getMax(),0.0) - Math.min(rInfo.getMin(),0.0); 334 } else { 335 range = rInfo.getMax() - rInfo.getMin(); 336 } 337 return curScore / (range*range); 338 } else { 339 throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName()); 340 } 341 } 342 343 /** 344 * Calculates the distance between two values for a single feature. 345 * 346 * @param fMap The feature map which knows if a feature is categorical or real. 347 * @param numTrainingExamples The number of training examples this feature map observed. 348 * @param index The id number for this feature. 349 * @param firstValue The first feature value. 350 * @param secondValue The second feature value. 351 * @return The distance between the two values. 352 */ 353 private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double firstValue, double secondValue) { 354 VariableInfo info = fMap.get(index); 355 if (info instanceof CategoricalInfo) { 356 if (Math.abs(firstValue - secondValue) > DISTANCE_DELTA) { 357 return 1.0; 358 } else { 359 // else the values are the same so the hamming distance is zero. 360 return 0.0; 361 } 362 } else if (info instanceof RealInfo) { 363 RealInfo rInfo = (RealInfo) info; 364 // Fudge the distance calculation so it doesn't overpower the categoricals. 365 double tmp = firstValue - secondValue; 366 double range; 367 // This further fudge is because the RealInfo may have observed a zero if it's sparse, but it might not. 368 if (numTrainingExamples != info.getCount()) { 369 range = Math.max(rInfo.getMax(),0.0) - Math.min(rInfo.getMin(),0.0); 370 } else { 371 range = rInfo.getMax() - rInfo.getMin(); 372 } 373 return (tmp*tmp) / (range*range); 374 } else { 375 throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName()); 376 } 377 } 378 379 /** 380 * Transforms a {@link Prediction} for a multiclass problem into a {@link Regressor} 381 * output which represents the probability for each class. 382 * <p> 383 * Used as the target for LIME Models. 384 * @param prediction A multiclass prediction object. Must contain probabilities. 385 * @return The n dimensional probability output. 386 */ 387 public static Regressor transformOutput(Prediction<Label> prediction) { 388 Map<String,Label> outputs = prediction.getOutputScores(); 389 390 String[] names = new String[outputs.size()]; 391 double[] values = new double[outputs.size()]; 392 393 int i = 0; 394 for (Map.Entry<String,Label> e : outputs.entrySet()) { 395 names[i] = e.getKey(); 396 values[i] = e.getValue().getScore(); 397 i++; 398 } 399 400 return new Regressor(names,values); 401 } 402 403}