{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# `text` plot\n", "\n", "This notebook is designed to demonstrate (and so document) how to use the `shap.plots.text` function. It uses a distilled PyTorch BERT model from the transformers package to do sentiment analysis of IMDB movie reviews.\n", "\n", "Note that the prediction function we define takes a list of strings and returns a logit value for the positive class." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": false }, "outputs": [], "source": [ "import nlp\n", "import numpy as np\n", "import scipy as sp\n", "import torch\n", "import transformers\n", "\n", "import shap\n", "\n", "# load a BERT sentiment analysis model\n", "tokenizer = transformers.DistilBertTokenizerFast.from_pretrained(\n", "    \"distilbert-base-uncased\"\n", ")\n", "model = transformers.DistilBertForSequenceClassification.from_pretrained(\n", "    \"distilbert-base-uncased-finetuned-sst-2-english\"\n", ").cuda()\n", "\n", "\n", "# define a prediction function\n", "def f(x):\n", "    tv = torch.tensor(\n", "        [\n", "            tokenizer.encode(v, padding=\"max_length\", max_length=500, truncation=True)\n", "            for v in x\n", "        ]\n", "    ).cuda()\n", "    outputs = model(tv)[0].detach().cpu().numpy()\n", "    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T\n", "    val = sp.special.logit(scores[:, 1])  # use one vs rest logit units\n", "    return val\n", "\n", "\n", "# build an explainer using a token masker\n", "explainer = shap.Explainer(f, tokenizer)\n", "\n", "# explain the model's predictions on IMDB reviews\n", "imdb_train = nlp.load_dataset(\"imdb\")[\"train\"]\n", "shap_values = explainer(imdb_train[:10], fixed_context=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Single instance text plot\n", "\n", "When we pass a single instance to the text plot we get the importance of each token overlaid on the original text that corresponds to that token. 
Red regions correspond to parts of the text that increase the output of the model when they are included, while blue regions decrease the output of the model when they are included. In the context of the sentiment analysis model here, red corresponds to a more positive review and blue to a more negative review.\n", "\n", "Note that importance values returned for text models are often hierarchical and follow the structure of the text. Nonlinear interactions between groups of tokens are often saved and can be used during the plotting process. If the Explanation object passed to the text plot has a `.hierarchical_values` attribute, then small groups of tokens with strong nonlinear effects among them will be auto-merged together to form coherent chunks. When the `.hierarchical_values` attribute is present, it also means that the explainer may not have completely enumerated all possible token perturbations and so has treated chunks of the text as essentially a single unit. This happens because we often want to explain a text model while evaluating it fewer times than the number of tokens in the document. Whenever a region of the input text is not split by the explainer, it is shown by the text plot as a single unit.\n", "\n", "The force plot above the text is designed to provide an overview of how all the parts of the text combine to produce the model's output. See the [force plot]() notebook for more details, but the general structure of the plot is that positive red features \"push\" the model output higher while negative blue features \"push\" the model output lower. The force plot provides much more quantitative information than the text coloring. 
Hovering over a chunk of text will underline the portion of the force plot that corresponds to that chunk of text, and hovering over a portion of the force plot will underline the corresponding chunk of text.\n", "\n", "Note that clicking on any chunk of text will show the sum of the SHAP values attributed to the tokens in that chunk (clicking again will hide the value)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "