## GPT-style transformer inspection

**OBJECTIVE:** The goal of this session is to become familiar with GPT-style architectures by exploring their activation patterns and attention mechanisms.

In [None]:
! pip install transformer_lens --quiet

In [None]:
import torch
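# Optional: torch_xla is only available on TPU runtimes and is not used below
try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None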
import transformer_lens
from transformer_lens import HookedTransformer, utils

import numpy as np
import matplotlib.pyplot as plt

torch.set_grad_enabled(False)

**Question 0:** Run the following model loading code.

In [None]:
model = HookedTransformer.from_pretrained("phi-3")

# get model config
n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
n_heads = model.cfg.n_heads
d_head = model.cfg.d_head
d_mlp = model.cfg.d_mlp
d_vocab = model.cfg.d_vocab

**Question 1:** Tokenize `input_text` below using `model.tokenizer.encode(?, return_tensors="pt")` and print the result.

In [None]:
input_text = """<|system|>You are a translator from english to french.<|system|>
<|user|> English: I like coffee with milk. <|user|>
<|assistant|>French: J'aime le cafe avec du lait. <|assistant|>
<|user|> English: When the curious cat spotted a bright laser dot, it leaped across the room with lightning speed. <|user|>
<|assistant|> French: """

#####################################################################################################################################

**Question 2:** Use `model.tokenizer.decode()` to convert the previously tokenized sentence back to its original text and display the result.
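Phi-3 uses a SentencePiece-style vocabulary in which a leading `▁` marks a piece that starts a new word. As a rough sketch of what decoding does for plain word pieces (ignoring special tokens and byte-fallback pieces such as `<0x0A>`, which the real tokenizer handles separately), the pieces are concatenated and each `▁` becomes a space. The helper `toy_decode` and the piece list are hypothetical.

```python
def toy_decode(pieces):
    """Join SentencePiece-style pieces: concatenate, map '▁' to a space, drop the leading space."""
    text = "".join(pieces).replace("\u2581", " ")
    return text[1:] if text.startswith(" ") else text


# Hypothetical pieces; the real Phi-3 split may differ
pieces = ["▁I", "▁like", "▁coffee", "▁with", "▁milk", "."]
print(toy_decode(pieces))  # I like coffee with milk.
```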

**Question 3:** Analyze how the words `sunflower`, `sun`, and `flow` are tokenized. What differences do you observe in their tokenization patterns?
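One way to reason about the differences is a toy longest-match-first subword tokenizer. This is only a sketch: Phi-3's SentencePiece tokenizer applies learned BPE merges rather than greedy longest match, and the vocabulary below is hypothetical and tiny. It still shows the key effect: the characters `flow` inside `sunflower` end up in a continuation piece without the `▁` marker, while the standalone word is tokenized as `▁flow`, a different token.

```python
def greedy_tokenize(word, vocab):
    """Toy longest-match-first subword tokenizer (WordPiece-style, not BPE)."""
    text = "\u2581" + word  # mark the start of a word, as SentencePiece does
    tokens = []
    start = 0
    while start < len(text):
        # Try the longest remaining substring first, then shrink it
        for end in range(len(text), start, -1):
            piece = text[start:end]
            if piece in vocab:
                tokens.append(piece)
                start = end
                break
        else:
            raise ValueError(f"cannot tokenize {text[start:]!r}")
    return tokens


# Hypothetical toy vocabulary (the real Phi-3 vocabulary is much larger)
vocab = {"\u2581sun", "\u2581flow", "flower", "flow", "er", "\u2581"} | set("sunflower")

for w in ["sunflower", "sun", "flow"]:
    print(w, "->", greedy_tokenize(w, vocab))
# sunflower -> ['▁sun', 'flower']
# sun -> ['▁sun']
# flow -> ['▁flow']
```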

**Question 4:** Print the model and examine the layers.

**Question 4-bis:** Study the provided forward-hook code, which differs slightly from the one used in previous sessions.

In [None]:
# remove all the hooks of the model
model.reset_hooks()

# global dict for forward activations
activation_dict = {}

# Define the hook function
def record_activation(activation, hook):
    activation_dict[hook.name] = activation.detach().cpu()

# Example forward hook attached to a layer using `.add_hook()`
model.blocks[24].attn.hook_attn_scores.add_hook(record_activation)
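To see why the hook function takes `(activation, hook)` and why results land in `activation_dict` under `hook.name`, here is a minimal pure-Python mock of the mechanism. `ToyHookPoint` is hypothetical and only mirrors the behaviour relevant here: each hook is called with the activation and the hook point itself, and a hook that returns `None` (as `record_activation` does) leaves the activation unchanged. Lists stand in for tensors, so the `.detach().cpu()` step of the real hook is replaced by a plain copy.

```python
class ToyHookPoint:
    """Hypothetical stand-in for a TransformerLens HookPoint: an identity layer that calls hooks."""

    def __init__(self, name):
        self.name = name
        self.hooks = []

    def add_hook(self, hook):
        self.hooks.append(hook)

    def reset_hooks(self):
        self.hooks = []

    def __call__(self, activation):
        for hook in self.hooks:
            result = hook(activation, hook=self)
            # A hook that returns something replaces the activation; None keeps it
            if result is not None:
                activation = result
        return activation


activation_dict = {}


def record_activation(activation, hook):
    # Store a copy keyed by the hook point's name
    activation_dict[hook.name] = list(activation)


hp = ToyHookPoint("blocks.24.attn.hook_attn_scores")
hp.add_hook(record_activation)
out = hp([0.1, 0.2, 0.3])
print(activation_dict)
```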

**Question 5:** Attach forward hooks to the embedding layer and to a selected attention layer. *(Hint: hooks can only be attached to `HookPoint()` modules. For the attention layer, we want the attention scores, so attach the hook to `hook_attn_scores`.)*

**Question 6:** Perform inference on the tokenized output from Question 1 using the `model`, and store the resulting output in the variable `output_tensor`.

**Question 7:** Use `matplotlib.pyplot.scatter()` to visualize the softmaxed logits (i.e. the next-token probabilities over the vocabulary) of selected tokens in `output_tensor`: make one plot for the last token and several plots for mid-sequence tokens.
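The model returns logits of shape `[batch_size, seq_len, d_vocab]`, where position `i` holds the prediction for the token at position `i + 1`. The sketch below shows how they can be turned into probabilities and how positions can be selected, using random numbers as a stand-in for `output_tensor` (only the shape is carried over; the chosen mid-sequence positions are arbitrary). On the real output, a conversion such as `output_tensor.float().cpu().numpy()` would be needed first.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax: subtract the max before exponentiating."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


# Synthetic stand-in for the model output: [batch_size, seq_len, d_vocab]
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 12, 50))

probs = softmax(logits)            # next-token distribution at every position
last_probs = probs[0, -1]          # prediction after the last token
mid_probs = probs[0, [3, 6, 9]]    # a few mid-sequence positions (arbitrary choice)

top5 = np.argsort(last_probs)[::-1][:5]
print("top-5 token ids after the last position:", top5)
# One scatter plot per position, e.g. plt.scatter(np.arange(last_probs.size), last_probs)
```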

**Question 8:** Generate a continuation with argmax (greedy) decoding by calling `model.generate()` with `max_new_tokens=50`, `stop_at_eos=True`, and `do_sample=False`. Then run a single forward pass on the generated text and store the result in the variable `output_on_generated`. From now on, this generated sequence is the primary input used for the analysis.
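With `do_sample=False`, generation is deterministic: at every step the highest-scoring token is appended. The loop below sketches that idea with a hypothetical scoring function (`toy_next_token_logits`, a fixed bigram table) in place of the model. It is not the `model.generate()` implementation, which also works on batched tensors and caches past keys and values; this toy also keeps the EOS token in its output before stopping.

```python
def greedy_generate(prompt_ids, next_token_logits, eos_id, max_new_tokens=50):
    """Argmax (greedy) decoding: always append the single most likely next token."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        next_id = max(range(len(logits)), key=lambda i: logits[i])  # argmax, no sampling
        ids.append(next_id)
        if next_id == eos_id:  # analogous to stop_at_eos=True
            break
    return ids


# Hypothetical toy "model": a fixed bigram score table over a 4-token vocabulary
BIGRAM = {
    0: [0.0, 2.0, 1.0, 0.5],
    1: [0.1, 0.0, 3.0, 0.2],
    2: [0.0, 0.0, 0.0, 5.0],
    3: [1.0, 1.0, 1.0, 1.0],
}
EOS = 3


def toy_next_token_logits(ids):
    return BIGRAM[ids[-1]]


print(greedy_generate([0], toy_next_token_logits, EOS))  # [0, 1, 2, 3]
```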

**Question 9:** Use `matplotlib.pyplot.imshow()` to visualize the output of the embedding layer captured by the forward hook during the forward pass that produced `output_on_generated`.

**Question 10:** Use `matplotlib.pyplot.imshow()` to visualize the attention scores from a single attention head, as captured by the forward hook in Question 5.
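To interpret the single-head matrix, it helps to recall how it is built. In TransformerLens, `hook_attn_scores` is, as far as we know, applied after the causal mask and before the softmax, while `hook_pattern` holds the resulting probabilities. A numpy sketch of the same computation for one head, with random queries and keys (sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_head = 5, 8
queries = rng.normal(size=(seq_len, d_head))  # synthetic queries for one head
keys = rng.normal(size=(seq_len, d_head))     # synthetic keys for the same head

# Scaled dot-product scores: entry [i, j] measures how well query i matches key j
scores = queries @ keys.T / np.sqrt(d_head)

# Causal mask: query i may only attend to keys j <= i
future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
scores = np.where(future, -np.inf, scores)

# Row-wise softmax turns each query's scores into attention probabilities
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
pattern = weights / weights.sum(axis=-1, keepdims=True)
print(np.round(pattern, 2))
```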

**Question 11:** Use the provided code to recreate the previous visualization in a more readable format, similar to the one presented in Lecture Slide 315. What issue do you observe in the resulting plot?



In [None]:
# Taken and modified from PyTorch forum
def visualize_attention_scores(attn_scores, text, head_idx=0, spacing=0.5):
    """
    Plot attention from a specified head by drawing lines
    between two columns of tokens (queries on the left, keys on the right).

    Args:
        attn_scores: Tensor of shape [batch_size, n_heads, seq_len, seq_len] containing
            pre-softmax attention scores (e.g. captured from `hook_attn_scores`).
        text: A list of strings used as token labels (one per position).
        head_idx: Which attention head to visualize.
        spacing: Vertical spacing factor between tokens.
    """
    # Extract scores for the first batch element and the given attention head
    # (.float() avoids half-precision softmax issues on CPU)
    head_scores = attn_scores[0, head_idx].float()  # shape: [seq_len, seq_len]

    # Softmax over the key dimension, so that each query row sums to 1
    head_probs = torch.softmax(head_scores, dim=-1)
    seq_len = head_probs.shape[0]

    # Make sure the labels correspond to the visualized sequence
    assert seq_len == len(text), (f"Mismatch: attention has seq_len={seq_len} but text has {len(text)} labels.")

    # Create y-positions for each token, spaced out by `spacing` (first token at the top)
    y_positions = [spacing * (seq_len - 1 - i) for i in range(seq_len)]

    # Figure settings
    fig_width = 10
    fig_height = max(6, 0.5 * spacing * seq_len)
    plt.figure(figsize=(fig_width, fig_height))
    ax = plt.gca()

    ax.set_xlim([-0.4, 1.4])
    ax.set_ylim([-spacing, spacing * seq_len])

    # Add the token labels in both columns
    for i, token_str in enumerate(text):
        y_pos = y_positions[i]

        # Left column: queries
        plt.text(-0.1, y_pos, token_str, ha='right', va='center', fontsize=10)

        # Right column: keys
        plt.text(1.1, y_pos, token_str, ha='left', va='center', fontsize=10)

    # Draw one line per (query, key) pair, with opacity equal to the attention probability
    for q_idx in range(seq_len):
        for k_idx in range(seq_len):
            alpha = head_probs[q_idx, k_idx].item()

            # Skip lines with very low attention for clarity
            if alpha < 0.01:
                continue

            x1, y1 = 0.0, y_positions[q_idx]
            x2, y2 = 1.0, y_positions[k_idx]

            plt.plot([x1, x2], [y1, y2], color='blue', alpha=alpha, linewidth=2)

    plt.axis('off')
    plt.title(f"Attention Scores (Head {head_idx})", fontsize=14)
    plt.show()


**Question 12:** Use the provided code to group the tokens of the `input_text` from Question 1 into words. The function returns two outputs: a list with one entry per word, where each entry is either a single token index or a list of the indices of the tokens that combine to form that word, and the list of fused word strings. Store them as `token_to_fuse_idx` and `fused_text`, respectively (i.e. `token_to_fuse_idx, fused_text = group_tokens(input_text)`).


In [None]:
def group_tokens(input_text):
    """
    Group the sub-word tokens of `input_text` into words.

    Args:
        input_text: text to preprocess (tokenized with the global `model.tokenizer`).

    Returns:
        fused_tokenIdx_list: one entry per word; an int for a single-token word,
            or a list of token indices for a word made of several tokens.
        fused_text: the corresponding word strings (sub-word tokens concatenated).
    """
    pretokenized_text = model.tokenizer.tokenize(input_text)

    indices_list = list()
    special_tokens = ["<|user|>", "<|assistant|>", "<|system|>", "<s>", "<0x0A>", "<unk>", "</s>", "<|endoftext|>", "<|end|>"]
    punctuations = [".", ",", "!", "?", ":", ";", "\"", "“", "”", "«", "»", "„", "(", ")", "[", "]", "{", "}", "/"]

    ######### Generation of fused_tokenIdx_list ####################################

    for idx, text in enumerate(pretokenized_text):
        # A token that does not start with "▁" and is neither special nor punctuation
        # continues the previous word
        if text[0] != "▁" and text not in special_tokens and text not in punctuations:
            # Exceptions: the first token, or a token right after a special token,
            # still starts a new word
            if idx == 0 or pretokenized_text[idx-1] in special_tokens:
                indices_list.append(idx)
            continue

        indices_list.append(idx)

    # Append len(pretokenized_text) as a sentinel that closes the last word
    indices_list.append(len(pretokenized_text))

    # Group the tokens of a word when the gap to the next word start is 2 or more
    fused_tokenIdx_list = list()
    for i in range(len(indices_list)-1):
        if indices_list[i+1] - indices_list[i] >= 2:
            fused_tokenIdx_list.append(list(range(indices_list[i], indices_list[i+1])))
        else:
            fused_tokenIdx_list.append(indices_list[i])
    #################################################################################

    ######### Generation of `fused_text` by fusing the sub-words when needed ########
    fused_text = list()
    for idx in fused_tokenIdx_list:
        if isinstance(idx, list):
            fused_text.append(''.join(pretokenized_text[i] for i in idx))
        else:
            fused_text.append(pretokenized_text[idx])

    return fused_tokenIdx_list, fused_text
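The grouping rule can be checked without loading the model. Below is a condensed, model-free restatement of the logic in `group_tokens`; the helper name `group_pieces` and the piece sequence are hypothetical. A piece starts a new word if it is the first piece, begins with `▁`, is a special token or punctuation, or directly follows a special token; every other piece is glued to the word before it.

```python
def group_pieces(pieces, special_tokens, punctuations):
    """Return (groups, words): groups[i] is an int or a list of piece indices forming word i."""
    starts = []
    for idx, piece in enumerate(pieces):
        starts_word = (
            idx == 0
            or piece.startswith("\u2581")
            or piece in special_tokens
            or piece in punctuations
            or pieces[idx - 1] in special_tokens
        )
        if starts_word:
            starts.append(idx)
    starts.append(len(pieces))  # sentinel closing the last group

    groups = []
    for a, b in zip(starts[:-1], starts[1:]):
        groups.append(list(range(a, b)) if b - a >= 2 else a)
    words = ["".join(pieces[i] for i in g) if isinstance(g, list) else pieces[g] for g in groups]
    return groups, words


# Hypothetical piece sequence; the real Phi-3 split may differ
pieces = ["<|user|>", "English", ":", "▁sun", "flow", "er", "▁is", "▁bright", "."]
groups, words = group_pieces(pieces, {"<|user|>"}, {".", ":"})
print(groups)  # [0, 1, 2, [3, 4, 5], 6, 7, 8]
print(words)   # ['<|user|>', 'English', ':', '▁sunflower', '▁is', '▁bright', '.']
```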

**Question 13:** Using the token indices obtained in Question 12, aggregate the attention scores of the tokens that form the same word, as demonstrated in lecture slide 314, so that each word gets a single row and a single column in the attention matrix. The provided function sums the key columns and averages the query rows of each multi-token word. Store the fused attention matrix in `fused_attention_tensor`.


In [None]:
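def fusion_score_attention(attention_score, input_text):
    """
    Fuse token-level attention into word-level attention.

    For every multi-token word, the key columns are summed and the query rows are averaged.
    The fusion is done on attention probabilities (softmax of the scores), where these sums
    and means stay meaningful; summing raw, causally masked scores is not (a masked -inf
    entry would, for example, hide a word's attention to itself). The log of the fused
    probabilities is returned, so that a softmax over the last dimension (as applied in
    `visualize_attention_scores`) recovers the fused probabilities.

    Args:
        attention_score: Tensor [batch_size, n_heads, seq_len, seq_len] of attention scores,
            e.g. captured from `hook_attn_scores`.
        input_text: The text whose tokenization produced `attention_score`.
    """
    fused_tokens, _ = group_tokens(input_text)
    seq_len = attention_score.shape[-1]

    # Every token must belong to exactly one word
    n_grouped = sum(len(idx) if isinstance(idx, list) else 1 for idx in fused_tokens)
    assert n_grouped == seq_len, (f"Mismatch: attention has seq_len={seq_len} but {n_grouped} tokens were grouped.")

    # Attention probabilities: each query row sums to 1 (masked positions become 0)
    probs = torch.softmax(attention_score.float(), dim=-1)

    # Indicator matrix: group[w, t] = 1 if token t belongs to word w
    group = torch.zeros(len(fused_tokens), seq_len, device=probs.device)
    for w, idx in enumerate(fused_tokens):
        group[w, idx] = 1.0

    # Sum over the key sub-tokens of each word -> [batch_size, n_heads, seq_len, n_words]
    fused = probs @ group.T
    # Average over the query sub-tokens of each word -> [batch_size, n_heads, n_words, n_words]
    fused = (group / group.sum(dim=-1, keepdim=True)) @ fused

    # Return log-probabilities so that a softmax over the last dimension gives the fused probabilities
    return torch.log(fused)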
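A small worked example makes the aggregation concrete. The sketch below uses a hand-written 3-token probability matrix in which tokens 1 and 2 form one word (hypothetical values): the key columns of that word are summed, then its query rows are averaged, here via an indicator matrix `G`. Because each row of `P` sums to 1, each row of the fused matrix still sums to 1, which is why the fusion is best done on probabilities rather than on raw scores.

```python
import numpy as np

# Causal attention probabilities for 3 tokens (rows = queries, columns = keys)
P = np.array([
    [1.0, 0.0, 0.0],
    [0.4, 0.6, 0.0],
    [0.2, 0.3, 0.5],
])
groups = [0, [1, 2]]  # tokens 1 and 2 form one word (hypothetical)

# Indicator matrix G[w, t] = 1 if token t belongs to word w
G = np.zeros((len(groups), P.shape[0]))
for w, idx in enumerate(groups):
    G[w, idx] = 1.0

summed_keys = P @ G.T                                      # sum over key sub-tokens
fused = (G / G.sum(axis=1, keepdims=True)) @ summed_keys   # average over query sub-tokens
# fused ≈ [[1.0, 0.0], [0.3, 0.7]]
print(fused)
```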

**Question 14:** Using the same code as in Question 11, plot the fused attention score matrix `fused_attention_tensor`. Additionally, pass the fused word strings `fused_text` obtained in Question 12 as the `text` argument of the plotting function to make the visualization more readable.


**Question 15:** Analyze the behaviour of other attention heads and attention layers by visualizing their attention scores, following the visualization method used in the previous question.


**Question 16:** Modify the `input_text` to perform a different task of your choice (e.g., translating from English to German, solving a reasoning problem, etc.). After making this change, analyze the attention heads to determine if your previous observations remain consistent with the new input.


**Question 17:** Combine all attention heads in the specified layer by computing the mean across the head dimension, and visualize the resulting attention matrix using the same approach as in Question 11. Is this a good way to aggregate attention scores across heads?
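Averaging over heads keeps each row a valid distribution when it is done on probabilities, but it can blur heads with very different roles. A toy illustration with two hand-built heads for 3 tokens (hypothetical values): a sharp previous-token head and a uniform causal head. In the mean, the previous-token head's weight of 1.0 at the last query is diluted to about 0.67. Averaging raw scores before the softmax gives yet another matrix, so it also matters which quantity is averaged.

```python
import numpy as np

# Two synthetic heads for 3 tokens, stacked as [n_heads, seq_len, seq_len]
# Head 0 attends sharply to the previous token (query 0 can only attend to itself)
head0 = np.array([
    [1.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
])
# Head 1 attends uniformly to all allowed (non-future) keys
future = np.triu(np.ones((3, 3), dtype=bool), k=1)
head1 = np.where(future, 0.0, 1.0) / np.arange(1, 4)[:, None]
pattern = np.stack([head0, head1])

mean_pattern = pattern.mean(axis=0)  # average over the head dimension
print(np.round(mean_pattern, 2))
```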
