## LIME for text

In [None]:
! pip install datasets

In [None]:
#@title Python library imports go here
import torch
from transformers import AutoTokenizer, BertForSequenceClassification
from datasets import load_dataset

from sklearn.linear_model import Ridge

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LinearSegmentedColormap
import matplotlib.cm as cm

**Question 0:** The cell below runs a simple inference procedure with a [BERT](https://huggingface.co/docs/transformers/model_doc/bert) classifier fine-tuned on [Yelp data](https://huggingface.co/datasets/fancyzhx/yelp_polarity). Experiment with it to understand its components, referring to the documentation as needed. Test each part separately (for example, print the tokenizer outputs) and try different sentences, including both positive and negative examples.

In [None]:
# load bert tokenizer
tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-yelp-polarity")
# load bert model
model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-yelp-polarity")

# input text to tokenize
text = "Update! Went back last night for dinner, this place is still awesome. I had the Las Vegas Rolls, they were pure deep fried goodness."
# tokenize your input
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)

# inference on your input
with torch.no_grad():
    logits = model(**inputs).logits

# output the predicted class (0 = negative review, 1 = positive review)
predicted_class_id = logits.argmax().item()
print(model.config.id2label[predicted_class_id])

# load dataset
dataset = load_dataset("fancyzhx/yelp_polarity")
# Access the train and test splits
train_dataset = dataset['train']
test_dataset = dataset['test']
print(train_dataset[0])
print(test_dataset[1])

**Question 1:** Code a function get_perturbations() that generates $N$ perturbed samples of a given text. For more details, see lecture slides 110-112. (Hint: default $N$ = 1000)

In [None]:
# HINT: use torch.randint to sample s_i (see the slides for this notation)
# HINT: use torch.multinomial or torch.randperm (or numpy.random.choice, which is easier to use) to sample S_i (see the slides for this notation)
# HINT: the index of the [UNK] token is 100
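
One possible starting point for the two-step sampling in the hints is sketched below with NumPy. The slides are not reproduced here, so this assumes that s_i is the number of tokens to mask (drawn uniformly from 1 to the text length) and that S_i is a uniformly random set of s_i positions. The parameter names `token_ids`, `n_samples`, `unk_id`, and `seed` are illustrative choices, not names from the lecture.

```python
import numpy as np

def get_perturbations(token_ids, n_samples=1000, unk_id=100, seed=None):
    """Return (masks, perturbed_ids), each of shape (n_samples, d).

    masks[i, j] is 1 if token j is kept in sample i and 0 if it is
    replaced by [UNK]. Assumes token_ids holds only perturbable tokens.
    """
    rng = np.random.default_rng(seed)
    token_ids = np.asarray(token_ids)
    d = token_ids.shape[0]
    masks = np.ones((n_samples, d), dtype=np.int64)
    for i in range(n_samples):
        # s_i: how many tokens to mask (assumed uniform on 1..d)
        s_i = rng.integers(1, d + 1)
        # S_i: which positions to mask, sampled without replacement
        S_i = rng.choice(d, size=s_i, replace=False)
        masks[i, S_i] = 0
    # Keep the original id where the mask is 1, otherwise use [UNK]
    perturbed_ids = np.where(masks == 1, token_ids, unk_id)
    return masks, perturbed_ids

# Example with four arbitrary token ids
masks, perturbed_ids = get_perturbations([2023, 2173, 2003, 12476], n_samples=5, seed=0)
```

With this sampling a row can end up fully masked (s_i equal to the text length), which is worth keeping in mind when computing distances later.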

**Question 2:** Code a function get_weights() that computes the weights of the perturbed samples using the cosine distance. For more details, see lecture slide 113.
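
As a rough sketch of what such a weighting could look like: the code below measures the cosine distance between each binary mask and the all-ones vector (the unperturbed text), then applies an exponential kernel $\exp(-d^2 / \sigma^2)$. The kernel form, the default `bandwidth=0.25`, and the convention that a fully masked sample has distance 1 are assumptions made here; check them against the slide.

```python
import numpy as np

def get_weights(masks, bandwidth=0.25):
    """Weight each perturbed sample by its proximity to the original text.

    masks: (n_samples, d) binary array, 1 = token kept, 0 = token masked.
    """
    masks = np.asarray(masks, dtype=float)
    original = np.ones(masks.shape[1])           # unperturbed text keeps every token
    dot = masks @ original
    norms = np.linalg.norm(masks, axis=1) * np.linalg.norm(original)
    # A fully masked row has zero norm; treat its similarity as 0 (assumption)
    cos_sim = np.divide(dot, norms, out=np.zeros_like(dot), where=norms > 0)
    distance = 1.0 - cos_sim
    return np.exp(-(distance ** 2) / bandwidth ** 2)

weights = get_weights([[1, 1, 1, 1], [1, 0, 0, 0], [0, 0, 0, 0]])
```

The closer a perturbed sample is to the original text, the closer its weight is to 1; a smaller bandwidth makes the weights drop off faster.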

**Question 3:** Using the previous functions, implement the LIME algorithm as outlined in lecture slide 114 to explain the predictions of the BERT model.
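
The surrogate-fitting step can be sketched independently of BERT. The code below assumes that the outputs of get_perturbations() and get_weights() are already available as a binary mask matrix and a weight vector, and that `probs` holds the classifier's probability for the explained class on each perturbed sample (for example, a softmax over the logits). It solves a weighted ridge regression in closed form with an unpenalized intercept; `Ridge(alpha=alpha).fit(masks, probs, sample_weight=weights)` from the notebook's imports is the scikit-learn counterpart. The function name `fit_lime_surrogate` and the toy black box are hypothetical.

```python
import numpy as np

def fit_lime_surrogate(masks, probs, weights, alpha=1.0):
    """Fit a weighted ridge regression of probs on the binary masks.

    Returns (coefs, intercept); coefs[j] is the LIME coefficient of token j.
    """
    X = np.asarray(masks, dtype=float)
    y = np.asarray(probs, dtype=float)
    w = np.asarray(weights, dtype=float)
    # Weighted centering leaves the intercept unpenalized
    x_mean = np.average(X, axis=0, weights=w)
    y_mean = np.average(y, weights=w)
    Xc, yc = X - x_mean, y - y_mean
    # Normal equations: (Xc^T W Xc + alpha I) beta = Xc^T W yc
    A = Xc.T @ (Xc * w[:, None]) + alpha * np.eye(X.shape[1])
    b = Xc.T @ (w * yc)
    coefs = np.linalg.solve(A, b)
    intercept = y_mean - x_mean @ coefs
    return coefs, intercept

# Toy stand-in for the classifier: token 0 helps, token 2 hurts
rng = np.random.default_rng(0)
toy_masks = rng.integers(0, 2, size=(500, 4))
toy_probs = 0.5 + 0.3 * toy_masks[:, 0] - 0.2 * toy_masks[:, 2]
coefs, intercept = fit_lime_surrogate(toy_masks, toy_probs, np.ones(500), alpha=1e-6)
```

On this toy example the recovered coefficients should be close to the ones used to generate `toy_probs`, which is a quick sanity check before plugging in the real model.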

**Question 4:** Plot a LIME heatmap for a given text, using a color map based on the LIME coefficients of each token.

In [None]:
# HINT: use integer array indexing to generate the heatmap
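
One way to read the hint, sketched under assumptions: build a lookup table of 256 colors from a diverging colormap, convert each coefficient to an integer index, and pick all colors at once with integer array indexing. The red/white/green palette, the symmetric scaling around zero, and the function name `plot_lime_heatmap` are choices made for this illustration only.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

def plot_lime_heatmap(tokens, coefs, ax=None):
    """Draw each token on a background colored by its LIME coefficient."""
    coefs = np.asarray(coefs, dtype=float)
    cmap = LinearSegmentedColormap.from_list("lime_rwg", ["red", "white", "green"])
    lut = cmap(np.linspace(0.0, 1.0, 256))        # (256, 4) RGBA lookup table
    vmax = max(np.abs(coefs).max(), 1e-12)        # symmetric range around zero
    scaled = (coefs + vmax) / (2.0 * vmax)        # map [-vmax, vmax] to [0, 1]
    idx = np.round(scaled * 255).astype(int)
    colors = lut[idx]                             # integer array indexing
    if ax is None:
        _, ax = plt.subplots(figsize=(max(2.0, 0.8 * len(tokens)), 1.0))
    ax.axis("off")
    ax.set_xlim(0, len(tokens))
    ax.set_ylim(0, 1)
    for j, (tok, color) in enumerate(zip(tokens, colors)):
        ax.text(j + 0.5, 0.5, tok, ha="center", va="center",
                bbox=dict(facecolor=tuple(color), edgecolor="none"))
    return colors

colors = plot_lime_heatmap(["great", "food", "awful"], [0.4, 0.0, -0.4])
```

Tokens with positive coefficients appear green, negative ones red, and near-zero ones white. For long reviews the tokens would need to be wrapped over several rows.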

**Question 5:** Generate a box plot of the LIME coefficients for the top 10 tokens (i.e., those with the highest LIME coefficients) by repeating the LIME method on the same text 10 times. See lecture slide 119 for an example.

In [None]:
# HINT: use the matplotlib.pyplot.boxplot() method.
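
A possible shape for the plotting step, assuming the coefficients from the repeated runs have been stacked into an array of shape (n_runs, n_tokens). Ranking tokens by their mean coefficient across runs is an assumption; the function name `plot_top_k_boxplot` and the synthetic data are placeholders for real LIME outputs.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_top_k_boxplot(tokens, coef_runs, k=10, ax=None):
    """Box plot of the k tokens with the highest mean LIME coefficient."""
    coef_runs = np.asarray(coef_runs, dtype=float)   # (n_runs, n_tokens)
    mean_coefs = coef_runs.mean(axis=0)
    top = np.argsort(mean_coefs)[::-1][:k]           # indices, highest mean first
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 4))
    ax.boxplot(coef_runs[:, top])                    # one box per selected token
    ax.set_xticks(np.arange(1, len(top) + 1))
    ax.set_xticklabels([tokens[j] for j in top], rotation=45, ha="right")
    ax.set_ylabel("LIME coefficient")
    return top

# Synthetic stand-in for 10 LIME runs over 12 tokens
rng = np.random.default_rng(0)
demo_tokens = list("abcdefghijkl")
demo_runs = np.arange(12) / 10 + 0.001 * rng.normal(size=(10, 12))
top = plot_top_k_boxplot(demo_tokens, demo_runs)
```

The spread of each box shows how stable a token's coefficient is across the random perturbation draws.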

**Question 6:** Compare the outputs of your LIME implementation with those from the official LimeTextExplainer class available at https://github.com/marcotcr/lime. Specifically, plot the two LIME heatmaps side-by-side and generate box plots for the top 10 tokens from both implementations.

## Bonus: Play with LIME

**Question 1:** Create a LIME heatmap for a given text by perturbing each token. For each perturbation, replace the token with its nearest neighbor in the embedding space.
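
The nearest-neighbor lookup can be sketched on a toy embedding matrix. The code assumes cosine similarity as the notion of "nearest" and excludes the token itself from the search. In the notebook, the matrix would presumably come from the model's input embeddings (for example via `model.get_input_embeddings().weight`); the function name `nearest_neighbor_ids` is hypothetical.

```python
import numpy as np

def nearest_neighbor_ids(embeddings, token_ids):
    """For each token id, return the id of its closest other row by cosine similarity."""
    E = np.asarray(embeddings, dtype=float)
    # Normalize rows so that dot products are cosine similarities
    E_unit = E / np.clip(np.linalg.norm(E, axis=1, keepdims=True), 1e-12, None)
    token_ids = np.asarray(token_ids)
    sims = E_unit[token_ids] @ E_unit.T                     # (n_tokens, vocab_size)
    sims[np.arange(len(token_ids)), token_ids] = -np.inf    # never pick the token itself
    return sims.argmax(axis=1)

# Toy vocabulary of four 2-D embeddings
toy_embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
neighbors = nearest_neighbor_ids(toy_embeddings, [0, 1, 2])
```

A perturbed sample then replaces the selected token ids with their neighbors instead of the [UNK] id.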

**Question 2:** Create a LIME heatmap for a given text by perturbing each token. Use torch.bernoulli() to perturb the text (as you did for images) instead of the torch.randint()-based sampling.
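
A minimal sketch of Bernoulli-based masking, assuming each token is kept independently with probability `keep_prob` and replaced by the [UNK] id (100) otherwise. The function name, `keep_prob=0.5`, and the `seed` argument are illustrative assumptions.

```python
import torch

def get_perturbations_bernoulli(token_ids, n_samples=1000, keep_prob=0.5,
                                unk_id=100, seed=None):
    """Return (masks, perturbed_ids), each of shape (n_samples, d)."""
    gen = None
    if seed is not None:
        gen = torch.Generator().manual_seed(seed)
    token_ids = torch.as_tensor(token_ids, dtype=torch.long)
    # Every entry is an independent Bernoulli(keep_prob) draw: 1 = keep, 0 = mask
    probs = torch.full((n_samples, token_ids.shape[0]), keep_prob)
    masks = torch.bernoulli(probs, generator=gen).long()
    # Broadcast the original ids over the rows and swap in [UNK] where masked
    perturbed_ids = torch.where(masks.bool(), token_ids, torch.full_like(masks, unk_id))
    return masks, perturbed_ids

masks, perturbed_ids = get_perturbations_bernoulli([2023, 2173, 2003, 12476],
                                                   n_samples=2000, seed=0)
```

Unlike the two-step scheme, the number of masked tokens here follows a binomial distribution, so samples concentrate around `keep_prob` times the text length.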

**Question 3:** Plot multiple box plots on a single figure, each representing a different combination of bandwidth (a parameter of the get_weights() function) and perturbed sample size (a parameter of the get_perturbations() function). For each combination, run the LIME method 10 times to capture the distribution of coefficients, and display each (bandwidth, sample size) tuple as a separate box plot on the shared figure to visualize the impact of these parameters on the LIME coefficients.

**Question 4:** Compute the test accuracy of the BERT classifier on the Yelp Polarity dataset.