/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.explanations.lime;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.CategoricalInfo;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.RealInfo;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.VariableIDInfo;
import org.tribuo.VariableInfo;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.explanations.TabularExplainer;
import org.tribuo.impl.ArrayExample;
import org.tribuo.interop.ExternalModel;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluation;
import org.tribuo.regression.evaluation.RegressionEvaluator;
import org.tribuo.util.Util;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.logging.Logger;

/**
 * LIMEBase merges the lime_base.py and lime_tabular.py implementations, and deals with simple
 * matrices of numerical or categorical data. If you want a mixture of text, numerical
 * and categorical data try {@link LIMEColumnar}. For plain text data use {@link LIMEText}.
 * <p>
 * See:
 * <pre>
 * Ribeiro MT, Singh S, Guestrin C.
 * "Why should I trust you?: Explaining the predictions of any classifier"
 * Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining 2016.
 * </pre>
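 * <p>
 * A minimal usage sketch (assuming {@code model} is a probabilistic Tribuo classifier,
 * {@code trainer} is a {@link SparseTrainer} which also implements {@link WeightedExamples},
 * and {@code example} is the {@link Example} to explain):
 * <pre>{@code
 * LIMEBase lime = new LIMEBase(new SplittableRandom(1L), model, trainer, 1000);
 * LIMEExplanation explanation = lime.explain(example);
 * }</pre>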
 */
public class LIMEBase implements TabularExplainer<Regressor> {
    private static final Logger logger = Logger.getLogger(LIMEBase.class.getName());

    /** The scaling constant applied to the number of features when computing the kernel width. */
    public static final double WIDTH_CONSTANT = 0.75;
    /** The tolerance below which two feature values are considered equal when computing distances. */
    public static final double DISTANCE_DELTA = 1e-12;

    protected static final OutputFactory<Regressor> regressionFactory = new RegressionFactory();
    protected static final RegressionEvaluator evaluator = new RegressionEvaluator(true);

    protected final SplittableRandom rng;

    protected final Model<Label> innerModel;

    protected final SparseTrainer<Regressor> explanationTrainer;

    protected final int numSamples;

    protected final long numTrainingExamples;

    protected final double kernelWidth;

    private final ImmutableFeatureMap fMap;

    /**
     * Constructs a LIME explainer for a model which uses tabular data (i.e., no special treatment for text features).
     * @param rng The rng to use for sampling.
     * @param innerModel The model to explain.
     * @param explanationTrainer The sparse trainer used to explain predictions.
     * @param numSamples The number of samples to generate for an explanation.
     */
    public LIMEBase(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples) {
        if (!(explanationTrainer instanceof WeightedExamples)) {
            throw new IllegalArgumentException("SparseTrainer must implement WeightedExamples, found " + explanationTrainer.toString());
        }
        if (!innerModel.generatesProbabilities()) {
            throw new IllegalArgumentException("LIME requires the model to generate probabilities.");
        }
        if (innerModel instanceof ExternalModel) {
            throw new IllegalArgumentException("LIME requires the model to have been trained in Tribuo. Found " + innerModel.getClass() + " which is an external model.");
        }
        this.rng = rng;
        this.innerModel = innerModel;
        this.explanationTrainer = explanationTrainer;
        this.numSamples = numSamples;
        this.numTrainingExamples = innerModel.getOutputIDInfo().getTotalObservations();
        this.kernelWidth = Math.pow(innerModel.getFeatureIDMap().size() * WIDTH_CONSTANT, 2);
        this.fMap = innerModel.getFeatureIDMap();
    }

    @Override
    public LIMEExplanation explain(Example<Label> example) {
        return explainWithSamples(example).getA();
    }

    protected Pair<LIMEExplanation,List<Example<Regressor>>> explainWithSamples(Example<Label> example) {
        // Predict using the full model, and generate a new example containing that prediction.
        Prediction<Label> prediction = innerModel.predict(example);
        Example<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction),example,1.0f);

        // Sample a dataset.
        List<Example<Regressor>> sample = sampleData(example);

        // Generate a sparse model on the sampled data.
        SparseModel<Regressor> model = trainExplainer(labelledExample,sample);

        // Test the sparse model against the predictions of the real model.
        List<Prediction<Regressor>> predictions = new ArrayList<>(model.predict(sample));
        predictions.add(model.predict(labelledExample));
        RegressionEvaluation evaluation = evaluator.evaluate(model,predictions,new SimpleDataSourceProvenance("LIME sampled data",regressionFactory));

        return new Pair<>(new LIMEExplanation(model,prediction,evaluation),sample);
    }

    /**
     * Sample a dataset based on the input example.
     * <p>
     * The sampled dataset uses the feature dimensions from the {@link Model}.
     * <p>
     * The outputs are the probability values of each class from the underlying Model,
     * rather than ground truth outputs. The distance is measured using the
     * {@link LIMEBase#measureDistance} function, transformed through a kernel and used
     * as the sampled Example's weight.
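     * <p>
     * Concretely, each sampled example's weight is
     * {@code kernelDist(measureDistance(fMap,numTrainingExamples,input,sample),kernelWidth)}.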
     * @param example The example to sample from.
     * @return A sampled dataset.
     */
    private List<Example<Regressor>> sampleData(Example<Label> example) {
        List<Example<Regressor>> output = new ArrayList<>();

        SparseVector exampleVector = SparseVector.createSparseVector(example,fMap,false);

        Random innerRNG = new Random(rng.nextLong());
        for (int i = 0; i < numSamples; i++) {
            // Sample a new Example.
            Example<Label> sample = samplePoint(innerRNG,fMap,numTrainingExamples,exampleVector);

            //logger.fine("Itr " + i + " sampled " + sample.toString());

            // Label it using the full model.
            Prediction<Label> samplePrediction = innerModel.predict(sample);

            // Measure the distance between this point and the input, to be used as a weight.
            double distance = measureDistance(fMap,numTrainingExamples,exampleVector, SparseVector.createSparseVector(sample,fMap,false));

            // Transform distance through the kernel function.
            distance = kernelDist(distance,kernelWidth);

            // Generate the new sample with the appropriate label and weight.
            Example<Regressor> labelledSample = new ArrayExample<>(transformOutput(samplePrediction),sample,(float)distance);
            output.add(labelledSample);
        }

        return output;
    }

    /**
     * Samples a single example from the supplied feature map and input vector.
     * @param rng The rng to use.
     * @param fMap The feature map describing the domain of the features.
     * @param numTrainingExamples The number of training examples the fMap has seen.
     * @param input The input sparse vector to use.
     * @return An Example sampled from the supplied feature map and input vector.
     */
    public static Example<Label> samplePoint(Random rng, ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input) {
        ArrayList<String> names = new ArrayList<>();
        ArrayList<Double> values = new ArrayList<>();

        for (VariableInfo info : fMap) {
            int id = ((VariableIDInfo)info).getID();
            double inputValue = input.get(id);

            if (info instanceof CategoricalInfo) {
                // This one is tricksy as categorical info implicitly includes a zero.
                CategoricalInfo catInfo = (CategoricalInfo) info;
                double sample = catInfo.frequencyBasedSample(rng,numTrainingExamples);
                // If we didn't sample zero.
                if (Math.abs(sample) > 1e-10) {
                    names.add(info.getName());
                    values.add(sample);
                }
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                // As realInfo is sparse we sample from the mixture distribution,
                // either 0 or N(inputValue,variance).
                // This assumes realInfo never observed a zero, which is enforced from v2.1
                // TODO check this makes sense. If the input value is zero do we still want to sample spike and slab?
                // If it's not zero do we want to?
                int count = realInfo.getCount();
                double threshold = count / ((double)numTrainingExamples);
                if (rng.nextDouble() < threshold) {
                    double variance = realInfo.getVariance();
                    double sample = (rng.nextGaussian() * Math.sqrt(variance)) + inputValue;
                    names.add(info.getName());
                    values.add(sample);
                }
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
            }
        }

        return new ArrayExample<>(LabelFactory.UNKNOWN_LABEL,names.toArray(new String[0]),Util.toPrimitiveDouble(values));
    }

    /**
     * Trains the explanation model using the supplied sampled data and the input example.
     * <p>
     * The labels are usually the predicted probabilities from the real model.
     * @param target The input example to explain.
     * @param samples The sampled data around the input.
     * @return An explanation model.
     */
    protected SparseModel<Regressor> trainExplainer(Example<Regressor> target, List<Example<Regressor>> samples) {
        MutableDataset<Regressor> explanationDataset = new MutableDataset<>(new SimpleDataSourceProvenance("explanationDataset", OffsetDateTime.now(), regressionFactory), regressionFactory);
        explanationDataset.add(target);
        explanationDataset.addAll(samples);

        SparseModel<Regressor> explainer = explanationTrainer.train(explanationDataset);

        return explainer;
    }

    /**
     * Calculates an RBF kernel of a specific width.
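     * <p>
     * For example, {@code kernelDist(2.0, 4.0)} returns {@code sqrt(exp(-4.0/4.0))},
     * which is approximately 0.6065.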
     * @param input The input value.
     * @param width The width of the kernel.
     * @return {@code sqrt(exp(-(input*input)/width))}.
     */
    public static double kernelDist(double input, double width) {
        return Math.sqrt(Math.exp(-(input*input) / width));
    }

    /**
     * Measures the distance between an input point and a sampled point.
     * <p>
     * This distance function takes into account categorical and real values. It uses
     * the Hamming distance for categoricals and the Euclidean distance for real values.
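     * <p>
     * For example, if two points differ in one categorical feature and in one real
     * feature (values {@code a} and {@code b}), the distance is
     * {@code sqrt(1.0 + ((a-b)*(a-b))/(range*range))}, where {@code range} is the
     * observed range of the real feature.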
     * @param fMap The feature map used to determine if a feature is categorical or real.
     * @param numTrainingExamples The number of training examples the fMap has seen.
     * @param input The input point.
     * @param sample The sampled point.
     * @return The distance between the two points.
     */
    public static double measureDistance(ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input, SparseVector sample) {
        double score = 0.0;

        Iterator<VectorTuple> itr = input.iterator();
        Iterator<VectorTuple> otherItr = sample.iterator();
        VectorTuple tuple;
        VectorTuple otherTuple;
        while (itr.hasNext() && otherItr.hasNext()) {
            tuple = itr.next();
            otherTuple = otherItr.next();
            //after this loop, either itr is out or tuple.index >= otherTuple.index
            while (itr.hasNext() && (tuple.index < otherTuple.index)) {
                score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value);
                tuple = itr.next();
            }
            //after this loop, either otherItr is out or tuple.index <= otherTuple.index
            while (otherItr.hasNext() && (tuple.index > otherTuple.index)) {
                score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value);
                otherTuple = otherItr.next();
            }
            if (tuple.index == otherTuple.index) {
                //the indices line up, do the calculation.
                score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value,otherTuple.value);
            } else {
                // Now consume both the values as they'll be gone next iteration.
                // Consume the value in tuple.
                score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value);
                // Consume the value in otherTuple.
                score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value);
            }
        }
        while (itr.hasNext()) {
            tuple = itr.next();
            score += calculateSingleDistance(fMap,numTrainingExamples,tuple.index,tuple.value);
        }
        while (otherItr.hasNext()) {
            otherTuple = otherItr.next();
            score += calculateSingleDistance(fMap,numTrainingExamples,otherTuple.index,otherTuple.value);
        }

        return Math.sqrt(score);
    }

    /**
     * Calculates the distance between two values for a single feature.
     * <p>
     * Assumes the other value is zero as the example is sparse.
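     * <p>
     * Categorical features contribute a Hamming distance of 1.0, while real features
     * contribute {@code (value*value)/(range*range)}.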
     * @param fMap The feature map which knows if a feature is categorical or real.
     * @param numTrainingExamples The number of training examples this feature map observed.
     * @param index The id number for this feature.
     * @param value One feature value.
     * @return The distance from zero to the supplied value.
     */
    private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double value) {
        VariableInfo info = fMap.get(index);
        if (info instanceof CategoricalInfo) {
            return 1.0;
        } else if (info instanceof RealInfo) {
            RealInfo rInfo = (RealInfo) info;
            // Fudge the distance calculation so it doesn't overpower the categoricals.
            double curScore = value * value;
            double range;
            // This further fudge is because the RealInfo may have observed a zero if it's sparse, but it might not.
            if (numTrainingExamples != info.getCount()) {
                range = Math.max(rInfo.getMax(),0.0) - Math.min(rInfo.getMin(),0.0);
            } else {
                range = rInfo.getMax() - rInfo.getMin();
            }
            return curScore / (range*range);
        } else {
            throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
        }
    }

    /**
     * Calculates the distance between two values for a single feature.
     *
     * @param fMap The feature map which knows if a feature is categorical or real.
     * @param numTrainingExamples The number of training examples this feature map observed.
     * @param index The id number for this feature.
     * @param firstValue The first feature value.
     * @param secondValue The second feature value.
     * @return The distance between the two values.
     */
    private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double firstValue, double secondValue) {
        VariableInfo info = fMap.get(index);
        if (info instanceof CategoricalInfo) {
            if (Math.abs(firstValue - secondValue) > DISTANCE_DELTA) {
                return 1.0;
            } else {
                // else the values are the same so the Hamming distance is zero.
                return 0.0;
            }
        } else if (info instanceof RealInfo) {
            RealInfo rInfo = (RealInfo) info;
            // Fudge the distance calculation so it doesn't overpower the categoricals.
            double tmp = firstValue - secondValue;
            double range;
            // This further fudge is because the RealInfo may have observed a zero if it's sparse, but it might not.
            if (numTrainingExamples != info.getCount()) {
                range = Math.max(rInfo.getMax(),0.0) - Math.min(rInfo.getMin(),0.0);
            } else {
                range = rInfo.getMax() - rInfo.getMin();
            }
            return (tmp*tmp) / (range*range);
        } else {
            throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
        }
    }

    /**
     * Transforms a {@link Prediction} for a multiclass problem into a {@link Regressor}
     * output which represents the probability for each class.
     * <p>
     * Used as the target for LIME Models.
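     * <p>
     * For example, a prediction with scores {@code {A=0.7, B=0.2, C=0.1}} maps to a
     * {@code Regressor} with dimensions A, B and C taking values 0.7, 0.2 and 0.1
     * respectively.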
     * @param prediction A multiclass prediction object. Must contain probabilities.
     * @return The n dimensional probability output.
     */
    public static Regressor transformOutput(Prediction<Label> prediction) {
        Map<String,Label> outputs = prediction.getOutputScores();

        String[] names = new String[outputs.size()];
        double[] values = new double[outputs.size()];

        int i = 0;
        for (Map.Entry<String,Label> e : outputs.entrySet()) {
            names[i] = e.getKey();
            values[i] = e.getValue().getScore();
            i++;
        }

        return new Regressor(names,values);
    }

}