# **Multilingual Natural Language Processing: Exercise 3**
## Masked Language Modeling Pre-Training with Transformers

## Part 1: __Pen&Paper Exercises__

Given the following model: $f(x) = w_3x^3+w_2x^2+w_1x+w_0$, optimize the model's parameters (i.e., $w_0,w_1,w_2,w_3$) using stochastic gradient descent, given the following training examples (x,y): {(1,3),(-1,-5),(0,-3)} and the squared error loss $(y - f(x))^2$. Update the model's parameters after each example. Use a learning rate $\eta = 0.1$ and initialize all parameters to $1$ (i.e., $w_0=w_1=w_2=w_3=1$).
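Once you have solved the exercise by hand, you can check your calculation with the following plain-Python sketch. The helper names `f` and `sgd` are our own. It uses the gradient $\partial (y-f(x))^2 / \partial w_i = -2\,(y - f(x))\,x^i$, computes all four gradients before updating any parameter, and relies on Python evaluating `0**0` as 1, which matches $x^0 = 1$ for the bias $w_0$.

```python
def f(x: float, w: list[float]) -> float:
    # w = [w0, w1, w2, w3]; w[i] is the coefficient of x**i
    return sum(w_i * x**i for i, w_i in enumerate(w))


def sgd(examples: list[tuple[float, float]], w: list[float], lr: float = 0.1) -> list[float]:
    w = list(w)
    for x, y in examples:                       # one update per training example
        residual = y - f(x, w)
        # d/dw_i (y - f(x))^2 = -2 * (y - f(x)) * x**i
        grads = [-2 * residual * x**i for i in range(len(w))]
        w = [w_i - lr * g_i for w_i, g_i in zip(w, grads)]
        print(f"after example {(x, y)}: w = {[round(v, 4) for v in w]}")
    return w


weights = sgd([(1, 3), (-1, -5), (0, -3)], w=[1.0, 1.0, 1.0, 1.0])
```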

## Part 2: __Introduction__

#### __Scope of Tutorial__

Throughout this Jupyter notebook, we will develop an understanding of why transfer learning works by closely examining the most common pre-training paradigm of transformer models, namely Masked Language Modeling (MLM). To that end, this notebook combines a conceptual review of MLM with building the required modelling intuitions for the transformer architecture, accompanied by some illustrative code.

#### __Further Materials__

Before covering Masked Language Modeling in depth, we point you to a number of easily and freely accessible resources that complement this notebook from additional angles.
Please note that these references are provided for your own benefit and do not directly relate to the exam. Naturally, some of these resources will help you hone your understanding of matters that may be immediately relevant, but none of the materials referred to here are explicitly part of the exam material.

##### __(a) Pre-Training__

That said, this notebook contains a lot of very dense information, and you might not yet have internalised some of the prerequisites to a sufficient extent. Below, you can find many great resources to revisit key elements of this walkthrough:
* [Andrej Karpathy's language modelling series](https://karpathy.ai/zero-to-hero.html): excellent walkthrough from scratch to GPT
    * [Causal Language Modelling (i.e., next word prediction) with GPT-style transformer](https://www.youtube.com/watch?v=kCc8FmEb1nY)
* [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/): well known blog entry covering transformers on a very high level
* [Yannic Kilcher](https://www.youtube.com/channel/UCZHmQk67mSJgfCCTn7xBfew): ETH PhD student who summarises all kinds of deep learning papers in easy-to-understand and digestible videos
    * [Attention is All You Need](https://youtu.be/iDulhoQ2pro): original work to establish transformer architecture
    * [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://youtu.be/-9evrZnBorM)
    * [GPT-3: Language Models are Few-Shot Learners](https://youtu.be/SY5PvZrJhLE)
   
* [Dive into Deep Learning](https://d2l.ai/): open-source repository covering deep learning and related topics in the required depth by combining code and theory in illustrative notebooks


##### __(b) Fine-Tuning__

Once you have established an understanding of the transformer architecture and pre-training more generally, downstream fine-tuning follows rather naturally. "Dive into Deep Learning" has a really accessible section on [fine-tuning BERT](https://d2l.ai/chapter_natural-language-processing-applications/finetuning-bert.html).


In practice, Huggingface [transformers](https://huggingface.co/transformers/) provides general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch. Fine-tuning pre-trained models on common tasks is rather straightforward and extensively covered in a plethora of tutorials, like [this](https://huggingface.co/transformers/custom_datasets.html) one provided by Huggingface themselves. In particular, Huggingface has tutorials for
* [Masked Language Modelling](https://huggingface.co/docs/transformers/tasks/masked_language_modeling)
* [Sequence Classification](https://huggingface.co/docs/transformers/tasks/sequence_classification)
* [Question Answering](https://huggingface.co/docs/transformers/tasks/question_answering)

and many more applications.

The key take-away here is that there are plenty of libraries available that make fine-tuning pretrained models simple and straightforward for all kinds of (supervised) downstream tasks.

## __Agenda__

* __Prerequisites: intuitive understanding of inner products__
* __Masked Language Modeling in Transformers__
    1. _All-in one overview_
    2. _Pre-processing (steps)_
    3. _Embedding Layer in Transformers_
    4. _Transformer as a series of Transformer Blocks_
        * Self-Attention
        * Feed-Forward Network
    5. _Classification Head: Masked Language Modeling as Multi-Class Classification over (BPE) Vocabulary_

## __Prerequisites: Intuitive Understanding of Inner Products as a Compatibility Function of Word Vectors__

Linear algebra is at the heart of deep learning. For completeness, we therefore need an intuitive understanding of basic matrix algebra to understand more broadly how the transformer architecture models natural language. No need to get scared! Luckily, we can restrict ourselves to the fundamental concept of inner products and get a grasp of how inner products encapsulate the compatibility of word representations.

To that end, the inner product of two vectors X and Z (think of rows and columns of matrices with matching dimensions) can be seen as a measure of compatibility between the two. Geometrically, $\langle X, Z \rangle = \lVert X \rVert \, \lVert Z \rVert \cos\theta$, where $\theta$ is the angle between the two vectors. Intuitively, assuming X and Z are of comparable length, the projection (i.e., how much X points in the direction of Z) therefore mostly depends on the angle. For instance, should X and Z be orthogonal to one another, their inner product is zero, as X (Z) is projected onto the origin of Z (X) (cf. right-hand side of illustration).
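A tiny numeric check of this geometric intuition, as a PyTorch sketch with made-up 2-D vectors: the inner product is positive for a vector pointing in the same direction, zero for an orthogonal one, and negative for an opposite one, while dividing by both lengths yields the cosine of the angle between them.

```python
import torch

x = torch.tensor([1.0, 0.0])
candidates = {
    "same direction": torch.tensor([2.0, 0.0]),
    "orthogonal": torch.tensor([0.0, 2.0]),
    "opposite direction": torch.tensor([-2.0, 0.0]),
}

for name, z in candidates.items():
    inner = torch.dot(x, z)                  # <x, z> = |x| |z| cos(theta)
    cosine = inner / (x.norm() * z.norm())   # inner product of the length-normalized vectors
    print(f"{name}: inner product = {inner.item():.1f}, cosine = {cosine.item():.1f}")
```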

The illustration of matrix multiplication highlights that a matrix product can be viewed as pair-wise inner products of row (e.g., green) and column (e.g., gray) vectors, which here will be representations of tokens in our input sequences. We build these matrices by stacking the word vectors of our input sequences, so that each row of the matrix corresponds to one token representation. Conversely, in the transpose of that matrix, each column is a token representation. Hence, multiplying the stacked word embeddings of one sentence with the transpose of the stacked word embeddings of another sentence (cf. illustration) efficiently computes all pair-wise inner products between their token representations. In practice, an intuitive perspective on inner products of token vectors is that, whenever two (length-normalized) word vectors have a large inner product, we think of them as representing similar semantics, since they point in comparable directions in vector space. Our discussion of representation learning via Masked Language Modeling in transformers will tie deeply into this critical intuition.

<img src="./img/inner_product.png" alt="algebra" width="1600"/>
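A minimal PyTorch sketch of the same idea, with random stand-in vectors for two hypothetical sentences: a single matrix product `X @ Z.T` yields every pair-wise inner product between their token representations, which we verify against an explicit double loop.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 8)   # 4 tokens of sentence A, stacked as rows (embedding dim 8)
Z = torch.randn(3, 8)   # 3 tokens of sentence B, stacked as rows

scores = X @ Z.T        # shape (4, 3): scores[i, j] = <X[i], Z[j]>

# the same numbers, computed one pair at a time
loop_scores = torch.stack([torch.stack([torch.dot(x, z) for z in Z]) for x in X])
print(torch.allclose(scores, loop_scores, atol=1e-6))
```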

## __Walkthrough of Masked Language Modeling__

#### Task Description

Masked Language Modeling (MLM) is a fill-in-the-blank task, where a deep (transformer) model uses the context words surrounding a mask token to try to predict what the masked word should be. For an input that contains one or more mask tokens, the model will generate the most likely substitution for each.

__Example:__
* Input: "I have watched this \[MASK\] and it was awesome."
* Output: "I have watched this movie and it was awesome."

In addition, think about which alternative tokens to "movie" the model might predict just as well. We would expect that words like "film", "play", (sports) "match", and so forth, would probably also be assigned a high likelihood, as these tokens also naturally occur in the context surrounding the \[MASK\] token. We will develop an understanding of how these regularities in natural language help us embed tokens in a vector space, such that comparable tokens in related contexts are embedded by similar vector representations.

More broadly, Masked Language Modeling is a great way to train a language model in a self-supervised setting (without human-annotated labels). Such a model can then be fine-tuned to accomplish various supervised NLP tasks. However, successfully pre-training very deep models on such a self-supervised objective requires both a very large training corpus (as a rule of thumb, 200M+ tokens) and a large number of parameters, often in excess of 300M. Therefore, many pre-trained models are made publicly available, for instance by Google, Facebook, or Huggingface.

#### High-level Steps in MLM

At a high level, Masked Language Modeling pre-training works as follows:
1. Preprocess the input sequence (detailed in Section 1 below) to get the corresponding identifiers for our BPE-tokenized sequence
2. Index the rows of the embedding layer with our identifiers to retrieve the corresponding BPE vectors
3. Feed our BPE vectors into the transformer model, which transforms the BPE vector representations informed by their context
4. Fetch the representation of our masked BPE token and classify which original BPE the masked token corresponded to (multi-class classification over the BPE vocabulary)

As part of the tutorial, we will carefully review all steps as presented in the below illustration of MLM. 

<img src="./img/mlm.png" alt="mlm" width="1800"/>

One aspect becomes clear from the right-hand side of the above illustration (step 4): __Masked Language Modeling is multi-class classification over the entire BPE vocabulary.__ Intuitively, we are classifying which token we had originally masked. The pressing question becomes: how does the transformer model (learn to) represent the masked token such that the unnormalized (log-)probability, i.e. the inner product of the masked representation and the embedding of "fox", is maximal for "fox"?\*\* Moreover, how does pre-training models on such an objective help us in downstream tasks like sequence or sentiment classification? Throughout this tutorial, we will conceptually review the above pipeline to develop an intuitive understanding of how and why pre-training works.

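\*\* Notice that the softmax activation turns our inner products into a probability vector: it exponentiates each score and divides by the sum of all exponentiated scores, so that the outputs are positive and sum to 1. Hence, the inner products can be considered unnormalized log-probabilities: taking the logarithm of the softmax output recovers them up to a shared additive constant.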
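A small PyTorch check of this footnote, with made-up scores for a three-token vocabulary: the log of the softmax output differs from the raw scores only by one shared constant (the log of the normalizer, computed by `torch.logsumexp`), which is also why adding a constant to all scores leaves the probabilities unchanged.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.1])      # made-up inner products for a 3-token vocabulary
probabilities = logits.softmax(dim=0)        # exp(logits) / exp(logits).sum()
log_probabilities = probabilities.log()

# logits = log-probabilities + one shared constant (the log-normalizer)
shift = logits - log_probabilities
print(shift, torch.logsumexp(logits, dim=0))

# hence shifting all logits by a constant leaves the probabilities unchanged
print(torch.allclose((logits + 5.0).softmax(dim=0), probabilities))
```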

### __The Transformer Architecture__

To that end, we first take a more architectural perspective by examining the transformer architecture at a high level. In doing so, we see that transformer models like BERT typically follow certain modeling patterns. In particular, we have our input token and positional embedding layers, whose outputs are subsequently fed into the encoder. The encoder stacks $N$ transformer blocks, which themselves are composed of Multi-Head Attention and Feed-Forward sub-networks. Finally, the resulting transformed representations flow into the classification head, which is made up of the transposed token embedding layer (a linear layer) and a softmax activation that outputs a probability vector over our vocabulary conditioned on our input sequence.

In what follows, we will review the (i) textual preprocessing, (ii) input embeddings, (iii) sequentially examine self-attention and feed-forward networks within transformer blocks, and finally walk through the (iv) MLM classification in detail. 

<img src="./img/transformer.png" alt="transformer" width="1800"/>

Consider both flow charts above carefully to understand how they relate to each other. In particular, the embedding layer corresponds to the token and positional embeddings, the $N$ transformer blocks are captured by the model segment, and, finally, the classification head is the transpose of the token embedding layer paired with a softmax activation.

Also note that the parameters of these layers are initialized with small random values. Updating them in a meaningful way, such that the learned representations generalize well across downstream tasks, is the key purpose of the Masked Language Modeling pre-training objective.

#### __The Add & Norm Blocks__

Should you have inspected the transformer chart carefully, you will have realized that we have not yet covered the entire architecture. In particular, the _Add & Norm_ layers are less important for the modeling intuitions, as they mainly represent architectural choices that facilitate training (in particular, gradient flow during training).

1. __Add:__ The Add operation encapsulates residual connections, i.e. connections that skip the intermediate layers. Residual representation learning captures the idea that, as opposed to learning $g(x) = f(x)$, we rather learn $g(x) = x + f(x)$, where you can think of $g$ and $f$ as functions (layers) that are part of our model. Layers are thus intended to amend the original representation step by step in a meaningful way rather than learning entirely new intermediate representation spaces. More importantly, in the backward pass, gradients can flow directly through the identity path and thereby bypass the $f(x)$ layer, which improves gradient flow.
2. __Layer normalization:__ Norm abbreviates layer normalization layers that, at a high level, normalize the learned representations of each token individually (to zero mean and unit variance across the hidden dimension, followed by a learned scale and shift). Notice how layer normalization therefore ties into our discussion of inner products as measures of compatibility: as vector lengths become more comparable, the resulting inner products depend increasingly on the angle between the two token vectors.
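A minimal PyTorch sketch of one Add & Norm step, assuming a single `nn.Linear` as a stand-in for the sub-layer $f$ (attention or feed-forward). We feed in token vectors of deliberately different lengths; after layer normalization (with its default initialization of scale 1 and shift 0), every token vector has a length of roughly $\sqrt{d}$, so the lengths become comparable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
sequence_length, embedding_dim = 5, 64

# token vectors whose lengths grow with their position (rows scaled by 1, 2, ..., 5)
scales = torch.arange(1, sequence_length + 1, dtype=torch.float32).unsqueeze(1)
x = torch.randn(sequence_length, embedding_dim) * scales

sublayer = nn.Linear(embedding_dim, embedding_dim)  # stand-in for f (attention or feed-forward)
layer_norm = nn.LayerNorm(embedding_dim)             # learned scale = 1, shift = 0 at initialization

with torch.no_grad():
    residual = x + sublayer(x)   # Add: g(x) = x + f(x)
    out = layer_norm(residual)   # Norm: each token vector gets zero mean and unit variance

print(x.norm(dim=1))    # very different lengths
print(out.norm(dim=1))  # all approximately sqrt(64) = 8
```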

### __1. Preprocessing Steps__

Before commencing training of our transformer model, we need to represent the textual input as something the model can understand. Ever since the advent of neural NLP, vector-based representations of tokens have become prevalent due to their desirable properties. Specifically, parameter sharing and end-to-end training of vector-based representations enable token representation learning, such that tokens with similar semantics are embedded nearby in vector space. Notice that this again entails that their length-normalized representations have comparably large inner products, a score you might be familiar with as cosine similarity.

The example pipeline illustrated below highlights how a training corpus is preprocessed into machine-"readable" input in the form of vectors. First, we learn a BPE tokenizer (more below) in an unsupervised way, which frames our word vocabulary as a composition of variable-length sequences of byte-pair encodings (BPE) like "\_jump" and "ed". Such tokenization algorithms facilitate generalization, as tokens not captured in our training corpus can still be represented as a composition of BPEs, a particularly important feature for morphologically rich languages. We break our input sequence into its BPEs with our learned tokenizer, add the special tokens \[CLS\], \[SEP\], and \[MASK\], and convert the tokens into the ids that correspond to the respective rows in our BPE token embedding layer, such that the model understands the input. In other words, row 3,721 of our token embedding matrix (always) corresponds to the vector for "over". In addition, do not get confused by the preceding "\_": these so-called metaspaces are added to help us keep track of which BPEs denote the beginning of a word.

<img src="./img/preprocessing.png" alt="preprocessing-steps" width="1800"/>
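A toy sketch of these preprocessing steps in plain Python. The vocabulary, its ids, and the BPE segmentation of the sentence are all hypothetical (a real tokenizer learns them from a corpus, so "over" would get an id like the 3,721 above rather than 10); the sketch only illustrates the order of operations: add special tokens, mask a token while remembering its original id as the label, and map tokens to rows of the embedding matrix.

```python
# hypothetical toy vocabulary: special tokens first, then BPE units ("_" marks a word beginning)
vocab = ["[PAD]", "[CLS]", "[SEP]", "[MASK]", "_the", "_quick", "_brown",
         "_fox", "_jump", "ed", "_over", "_lazy", "_dog"]
token_to_id = {token: idx for idx, token in enumerate(vocab)}   # token -> row of the embedding matrix

# assumed output of a learned BPE tokenizer for "The quick brown fox jumped over the lazy dog"
bpe_tokens = ["_the", "_quick", "_brown", "_fox", "_jump", "ed", "_over", "_the", "_lazy", "_dog"]

tokens = ["[CLS]"] + bpe_tokens + ["[SEP]"]      # add special tokens
masked_position = tokens.index("_fox")           # position we choose to mask
label_id = token_to_id[tokens[masked_position]]  # MLM target: the original token's id
tokens[masked_position] = "[MASK]"

input_ids = [token_to_id[token] for token in tokens]
print(tokens)
print(input_ids, label_id)
```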

#### Subword tokenization: BPE algorithm

__Byte-Pair Encoding__ (BPE) is one of many approaches to __frame__ an __open token vocabulary as a composition of variable-length character sequences__ (the byte-pair encodings) from a fixed-size vocabulary. To that end, the BPE algorithm iteratively counts all pairs of adjacent symbols within tokens, finds the most frequent pair, for instance, _A_ and _B_, and replaces that pair by its fused symbol _AB_. Thus, the resulting vocabulary consists of the original characters and the fused character n-grams, the latter of which may make up entire tokens. Many variations of the core idea exist, but as BPE balances vocabulary size and decoding efficiency well, the algorithm remains one of the standard ways to represent textual input for transformers and is widely used throughout NLP.
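A minimal plain-Python sketch of BPE merge learning, following the iterative procedure described above. It is simplified: there is no byte-level fallback, and ties between equally frequent pairs are broken by first occurrence. The toy word counts are made up, and `_` marks word beginnings as in the metaspace convention above. New words are segmented by replaying the learned merges in order.

```python
from collections import Counter


def merge_pair(symbols: tuple[str, ...], pair: tuple[str, str]) -> tuple[str, ...]:
    """Replace every adjacent occurrence of `pair` in `symbols` by the fused symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return tuple(merged)


def learn_bpe(word_counts: dict[str, int], num_merges: int) -> list[tuple[str, str]]:
    """Learn up to `num_merges` merges; every word starts as a sequence of characters."""
    vocab = {tuple(word): count for word, count in word_counts.items()}
    merges = []
    for _ in range(num_merges):
        pair_counts = Counter()
        for symbols, count in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += count       # weight each pair by the word's frequency
        if not pair_counts:
            break                                # every word is already a single symbol
        best = max(pair_counts, key=pair_counts.get)
        merges.append(best)
        vocab = {merge_pair(symbols, best): count for symbols, count in vocab.items()}
    return merges


def segment(word: str, merges: list[tuple[str, str]]) -> list[str]:
    """Segment a (possibly unseen) word by applying the merges in learned order."""
    symbols = tuple(word)
    for pair in merges:
        symbols = merge_pair(symbols, pair)
    return list(symbols)


word_counts = {"_low": 5, "_lower": 2, "_newest": 6, "_widest": 3}  # toy corpus statistics
merges = learn_bpe(word_counts, num_merges=5)
print(merges)
print(segment("_lowest", merges))
```

With these toy counts, the learned merges are `e`+`s`, `es`+`t`, and then the pieces of `_low`, so the unseen word `_lowest` is segmented into the known units `['_low', 'est']`.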

BPE thereby bypasses issues arising from linguistic morphology and closed vocabularies. When a model treats tokens as atomic units, it can neither embed out-of-vocabulary tokens nor incorporate relatedness among tokens arising from linguistic morphology a priori. While varying inflections of the same token are typically embedded in close neighborhood, any embeddings-based model first has to learn that property. Out-of-vocabulary tokens naturally occur, as vocabularies were typically restricted to the 60K most frequent tokens. The morphology problem highlights that such a model otherwise disregards the word formation process. For instance, consider agglutinated or compounded words such as the German token "Abwasser|behandlungs|anlage", which segments into "sewage water | treatment | plant" in English. Variable-length rather than fixed-size representations are much more intuitive for these types of tokens, and BPE can express them effectively.

For more information, this [blog post](https://leimao.github.io/blog/Byte-Pair-Encoding/) nicely illustrates how a BPE vocabulary is constructed iteratively.

These recommended libraries implement various subword tokenization algorithms: [Huggingface tokenizers](https://github.com/huggingface/tokenizers), [Google's sentencepiece](https://github.com/google/sentencepiece)

### __2. Embedding Layer__

So far, we have focused on the token embedding layer, but our embedding layer actually consists of two sub-embedding layers: one for tokens and one for their respective positions. As mentioned before, embedding layers stack categorical representations (row vectors) into a matrix. In other words, as shown in the MLM graphic above, we retrieve the embedding for the \[CLS\] token by indexing the corresponding row in the token embedding layer. Accordingly, the positional encoding of a token is the row of the position embedding matrix indexed by the token's offset within the input sentence.

<img src="./img/positional_encodings.png" alt="embeddings" width="1800"/>

Positional encodings are required because the transformations in the network are otherwise insensitive to position: self-attention and the feed-forward sub-network treat the input as an unordered set, so a token would obtain an identical representation regardless of its position. In other words, without positional encodings, the token representations of a fluent sentence and of a randomly shuffled version of it would be identical up to their order. Of course, such invariance does not hold for natural language, and thus a transformer accounts for token position by adding positional encodings element-wise to the initial token representations.
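A minimal sketch of this position-insensitivity, using PyTorch's `nn.MultiheadAttention` as a stand-in for a self-attention module (random weights, toy sizes). Without positional encodings, shuffling the input tokens merely shuffles the outputs, so every token keeps exactly the same representation wherever it appears.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
sequence_length, embedding_dim = 6, 32

attention = nn.MultiheadAttention(embedding_dim, num_heads=4, batch_first=True)
token_embeds = torch.randn(1, sequence_length, embedding_dim)  # (batch, length, dim), no positions added
perm = torch.randperm(sequence_length)                         # a random shuffle of the token order
shuffled = token_embeds[:, perm]

with torch.no_grad():
    out, _ = attention(token_embeds, token_embeds, token_embeds)  # queries = keys = values
    out_shuffled, _ = attention(shuffled, shuffled, shuffled)

# the shuffled input yields the shuffled output: each token's representation is unchanged
print(torch.allclose(out[:, perm], out_shuffled, atol=1e-5))
```

Adding a different positional vector to each position before calling the module breaks this symmetry, which is exactly what the positional encodings are for.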

You can find a pseudo-code implementation of the token and position embeddings below. Here, we create toy token and position embedding layers. Specifically, we define the size of our vocabulary (`N_VOCAB`), the dimensionality of our representations (`embedding_dim`), as well as the maximum sequence length in our training corpus (`maximum_sequence_length`). We then sample 20 (our sequence length) valid token indices (0, ..., 9,999, i.e. part of our vocabulary) and efficiently access their respective representations. Next, we get the positional embeddings, i.e. the first 20 rows of our positional embedding matrix, and add them position by position to the corresponding token representations.

#### Pseudo code: Token Embeddings

In [None]:
# TODO
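One possible way to fill in this cell, as a minimal PyTorch sketch. `embedding_dim = 512` is an illustrative choice, and we draw the 20 token ids with `torch.randperm` so that they are distinct (sampling with replacement would also be valid, but duplicate tokens would make the later argmax checks ambiguous). Indexing `nn.Embedding` with ids is the same as selecting rows of its weight matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N_VOCAB = 10_000          # vocabulary size: valid ids are 0, ..., 9,999
embedding_dim = 512       # illustrative dimensionality of the token representations
sequence_length = 20

token_embedding = nn.Embedding(N_VOCAB, embedding_dim)   # one row vector per token id

# 20 distinct valid token ids as a toy input sequence
input_ids = torch.randperm(N_VOCAB)[:sequence_length]

# looking up the ids == selecting the corresponding rows of the embedding matrix
token_embeds = token_embedding(input_ids)                # shape: (20, 512)
print(token_embeds.shape)
```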

#### Pseudo code: Positional Encodings

In [None]:
# TODO
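A matching sketch for the positional encodings, assuming learned position embeddings stored in a second `nn.Embedding` (as in BERT; the original transformer used fixed sinusoidal encodings instead) and an illustrative `maximum_sequence_length = 128`. The block repeats the toy setup so that it runs on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N_VOCAB = 10_000
embedding_dim = 512
maximum_sequence_length = 128   # longest sequence the model can encode (illustrative)
sequence_length = 20

token_embedding = nn.Embedding(N_VOCAB, embedding_dim)
position_embedding = nn.Embedding(maximum_sequence_length, embedding_dim)  # one row per offset

input_ids = torch.randperm(N_VOCAB)[:sequence_length]
token_embeds = token_embedding(input_ids)                 # (20, 512)

# offsets 0, ..., 19 select the first 20 rows of the position embedding matrix
positions = torch.arange(sequence_length)
position_embeds = position_embedding(positions)           # (20, 512)

# element-wise sum: the token at offset i receives the position vector for offset i
input_embeds = token_embeds + position_embeds             # (20, 512)
print(input_embeds.shape)
```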

The above figure and code highlight how we initially induce our transformer input representations. However, so far our representations merely take each token and its position in the sentence into consideration, which leaves several issues. Each token embedding is an atomic, context-independent representation that only changes with the token's position. Thus, these representations cannot inherently represent, for example, phrases unless explicitly accounted for. Similarly, they cannot factor in, for instance, the polysemy of a token. Moreover, this approach cannot embed linguistic notions such as antonymy, since antonyms appear in similar contexts. While this list of issues is non-exhaustive, it illustrates well what we cannot yet explicitly model with our representations. Therefore, in what follows, we will discuss how transformer blocks, i.e. repeated layers grouping self-attention and feed-forward sub-networks, try to address these modelling issues and represent tokens more adequately by considering the context within the input sequence. In doing so, we will also illustrate these repeated transformations at the vector level to link the input from the code showcase to modelling intuitions. That said, always consider carefully how we model natural language: while pre-trained transformers represent significant progress, not all of the aforementioned problems are entirely remedied.

### __3. Transformers as a Series of Transformer Blocks__

As mentioned, transformers stack $N$ repeated transformer blocks, which themselves are made up of self-attention and feed-forward sub-networks. These transformer blocks are the engine that transforms our input embeddings into representations meaningful for MLM pre-training and subsequent fine-tuning tasks. Moreover, in doing so, the resulting token representations address many of the issues elaborated upon in the previous paragraph. Consequently, transformers are the backbone of many state-of-the-art models across NLP tasks.

#### __Self-Attention__

In principle, a self-attention module compares every token in the input sequence to every other token, including itself, and reweights the token vectors to incorporate the contextual relevance resulting from this comparison. How does the self-attention module manage to incorporate such a notion into our token vector representations?

At a high level, self-attention over a sequence works as follows:
1. Get Keys, Values, and Queries: linearly transform the representations of our input sequence into Keys, Values, and Queries
2. Compute attention weights: for each query token, take the inner products ("compatibility") with all key tokens and normalize them with a softmax to obtain that query's attention weights (in the original transformer, the inner products are additionally scaled by $1/\sqrt{d_k}$, where $d_k$ is the key dimensionality; remember, Keys and Queries are both derived from the tokens of our input sequence)
3. For each query token, take the element-wise sum of the values weighted by the attention weights to compute the context vector \*\*
4. Add the context vector to the original untransformed vector of the corresponding query (residual connection) to generate the new token representation

Note: a linear transformation is nothing but a linear layer in a neural network.

The figure below showcases attention for a single query. The entire procedure scales seamlessly to full sequences when the Queries, Keys, and Values represent the linear transformations of stacked token representations.

\*\* Recall that an element-wise sum (cf. adding token and positional embeddings) adds vectors entry by entry, so the result has the same dimensionality as each summand. Here, the value vectors are additionally weighted by the attention weights prior to summing them up.

<img src="./img/attention.png" alt="mlm" width="1800"/>

Intuitively, as the model processes each token, self-attention allows it to look at all other tokens, again including itself, in the input sequence for information that can help lead to a better encoding for said token. Thereby, the token representation now additionally embeds information from the context within the input sequence. A couple of considerations will hone our understanding of self-attention. On one hand, a large attention weight on a token itself means the largest share of the original token information is preserved. On the other hand, self-attention infuses contextual information from the sequence also into our \[MASK\] token, which the model will rely upon to recover (classify) the original token it stemmed from.

Furthermore, multi-head attention is a design trick that allows a token representation to attend to multiple tokens in a single self-attention module. To that end, the process above is now performed on evenly sized chunks of the token representation (cf. bottom right of figure). For instance, if we have 8 attention heads, the dimensionality of our token representation has to be divisible by 8 to yield 8 evenly sized chunks. Each chunk then represents a sub-segment of the token representation, for which the model simultaneously and separately performs the self-attention procedure laid out above. Thereby, each token can more easily attend to multiple tokens in a single layer by attending to different tokens in each chunk (i.e., in each attention head).

In [None]:
# TODO: Self-Attention
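A minimal single-head sketch of the four steps above in PyTorch. Sizes, seeds, and the names `W_q`, `W_k`, `W_v`, and `self_attention` are our own illustrative choices, and both computations use the $1/\sqrt{d}$ scaling of the original transformer. The block also computes `attention_weights` directly on the raw token embeddings (no query/key projections), which is the illustrative variant examined in the following cells.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

N_VOCAB = 10_000
embedding_dim = 512
sequence_length = 20

token_embedding = nn.Embedding(N_VOCAB, embedding_dim)
input_ids = torch.randperm(N_VOCAB)[:sequence_length]      # distinct ids, so no exact duplicates
token_embeds = token_embedding(input_ids).detach()          # (20, 512)

# Step 1: linear layers that map token vectors to queries, keys, and values
W_q = nn.Linear(embedding_dim, embedding_dim)
W_k = nn.Linear(embedding_dim, embedding_dim)
W_v = nn.Linear(embedding_dim, embedding_dim)


def self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head scaled dot-product self-attention followed by a residual connection."""
    queries, keys, values = W_q(x), W_k(x), W_v(x)           # each (L, d)
    # Step 2: pair-wise query-key inner products, one row of weights per query
    scores = queries @ keys.T / math.sqrt(x.shape[-1])        # (L, L)
    weights = scores.softmax(dim=-1)                          # each row sums to 1
    # Step 3: per query, the attention-weighted sum of the value vectors
    context = weights @ values                                # (L, d)
    # Step 4: residual connection
    return x + context


contextualized = self_attention(token_embeds)                 # (20, 512)

# illustrative variant: weights computed on the raw (unprojected) token embeddings
attention_weights = (token_embeds @ token_embeds.T / math.sqrt(embedding_dim)).softmax(dim=-1)
```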

To further sharpen our understanding of the connection between algebra and self-attention, ask yourself: which token should get the largest attention weight in our _sample code_?\*\* Here, we use the [argmax](https://pytorch.org/docs/stable/generated/torch.argmax.html) function to find out. In brief, the argmax function returns the index of the largest value. Furthermore, as with other reductions, we can set the dimension along which we want to perform the operation, namely along the columns (`dim=1`), since our rows represent the query tokens.

\*\* Remember, this is merely an illustrative sample to demonstrate some of the connections we have been drawing thus far.

In [None]:
# The argmax function returns the index for the largest value: https://pytorch.org/docs/stable/generated/torch.argmax.html 
attention_weights.argmax(dim=1)

What does the output mean? Well, the argmax of each row corresponds to the offset of the query token itself. This makes sense, as we computed pair-wise inner products on the unaugmented input. Intuitively, without any linear transformation or other influences prior to computing the weights, each token will naturally attend to itself the most. Can you see why? Remember that inner products reflect compatibility: assuming comparable lengths of our token representations, each token points most closely in its own direction, since the angle between a vector and itself is zero. We briefly verify that all tokens have comparable lengths by examining their L2 norm, which is a common measure of vector length.

In [None]:
# we see that vector norms are comparable; hence, the inner product is primarily affected by the angle between vectors
token_embeds.norm(p=2, dim=1)
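Returning to the multi-head attention described earlier, here is a minimal PyTorch sketch of how the chunking into heads works, using random stand-ins for already projected queries, keys, and values. Practical implementations, such as `nn.MultiheadAttention`, additionally apply the Q/K/V projections before splitting and an output projection after concatenating the heads; both are omitted here for brevity.

```python
import math

import torch

torch.manual_seed(0)

sequence_length, embedding_dim, num_heads = 20, 512, 8
assert embedding_dim % num_heads == 0, "embedding_dim must be divisible by num_heads"
head_dim = embedding_dim // num_heads           # 64 dimensions per chunk

queries = torch.randn(sequence_length, embedding_dim)
keys = torch.randn(sequence_length, embedding_dim)
values = torch.randn(sequence_length, embedding_dim)


def split_heads(x: torch.Tensor) -> torch.Tensor:
    """(L, d) -> (num_heads, L, head_dim): head h sees the h-th contiguous chunk of each vector."""
    return x.view(sequence_length, num_heads, head_dim).transpose(0, 1)


q, k, v = split_heads(queries), split_heads(keys), split_heads(values)

# scaled dot-product attention, performed independently per head via batched matmul
weights = (q @ k.transpose(-2, -1) / math.sqrt(head_dim)).softmax(dim=-1)   # (8, 20, 20)
per_head_context = weights @ v                                              # (8, 20, 64)

# concatenate the heads back into one vector per token
context = per_head_context.transpose(0, 1).reshape(sequence_length, embedding_dim)  # (20, 512)
```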

#### __Feed-Forward Sub-Network__

The feed-forward sub-network can be seen as two position-independent recompositions of our token representations resulting from self-attention. To get an understanding of what that means, we go through the statement step by step.

* Position-independence: both the upscaling and downscaling layers within the feed-forward network recompose the inputs independent of the position of the respective token. Say \[CLS\] was shuffled towards the end with an identical input representation, the transformation would be just the same. Why? We see that the immediate outputs are pair-wise inner products of the input representations and the weights with an added bias on top. As all weights for both up and down scaling are shared independently across (token) positions, these transformations (combined, recomposing the input) are identical irrespective of position.
* Recomposition(s): 
    1. Upscaling: the initial upscaling layer intuitively expands the representation space, as the intermediate representation dimensionality, the so-called feed-forward dimension, is typically about 2-4x the input embedding dimensionality.
    2. Non-linearity: the non-linearity is most commonly GELU, a smooth variant of ReLU. This is an empirical choice that has been shown to work well. Non-linearities are required to increase the representational capacity of the model.
    3. Downscaling: the downscaling layer can be seen as a filter over our expanded token representations, mapping the intermediate representation from the feed-forward dimensionality back to the embedding dimensionality.

<img src="./img/ffn.png" alt="ffn" width="1800"/>

To develop an understanding of why recomposing our representations might be important, recall how our token representations have actually been transformed up to this point. Specifically, notice that multi-head attention performs self-attention multiple times simultaneously on even chunks of our input representation. That said, it is not clear a priori how many attention heads our self-attention modules should have; for instance, the optimal number of attention heads might depend on the sequence length. In other words, too many attention heads for too short sequences could intuitively encode repeated information. The feed-forward sub-networks elegantly address such issues by recomposing our self-attended representations.

Below, we again briefly examine a pseudo-code implementation of our feed-forward (sub-)network.

In [None]:
# TODO: Feed-Forward
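A minimal PyTorch sketch of the feed-forward sub-network, assuming a 4x expansion (as used, for example, by BERT) and random stand-ins for the self-attention outputs. The final assertion checks the position independence discussed above: permuting the tokens only permutes the outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

sequence_length, embedding_dim = 20, 512
feed_forward_dim = 4 * embedding_dim            # 2048: the expanded intermediate dimensionality

feed_forward = nn.Sequential(
    nn.Linear(embedding_dim, feed_forward_dim),  # upscaling
    nn.GELU(),                                   # non-linearity
    nn.Linear(feed_forward_dim, embedding_dim),  # downscaling back to the embedding dimension
)

hidden_states = torch.randn(sequence_length, embedding_dim)  # stand-in for self-attention outputs
ffn_out = feed_forward(hidden_states)                         # (20, 512)

# position independence: the same weights are applied to every token separately
perm = torch.randperm(sequence_length)
assert torch.allclose(feed_forward(hidden_states[perm]), ffn_out[perm], atol=1e-5)
```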

#### __Transformer Block as a Whole__

Our discussions on self-attention and the feed-forward sub-network highlight what transformer blocks actually represent: a series of transformations that first refine our representations in the context of the entire sequence (self-attention) and then recompose these resulting representations in the feed-forward sub-network.

While we have looked at the sub-modules individually, one might wonder why we need to repeatedly stack transformer blocks. At a high level, larger parameterization (larger vocabularies, higher embedding and feed-forward dimensionalities) expands our representation space and facilitates training. Deeper models transform the input step by step, such that the resulting representations become (approximately) linearly separable, i.e. linearly classifiable, for our Masked Language Modeling task. Many ongoing research efforts investigate how to improve the parameter efficiency of transformer models. For instance, smaller student transformer networks can be taught the knowledge of larger teacher transformers. Nevertheless, this does not imply that such small models can easily learn equally good parameters directly at training time, as model depth, layer initialization, training data & objective, and so forth, all play an important role during training. A proper discussion of parameters in deep learning, however, is beyond the scope of this tutorial, as we focus on the modeling intuitions of transformer layers and the task of Masked Language Modeling.
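To tie the pieces together, here is a hedged PyTorch sketch of a single transformer block, stacked $N$ times. It follows the post-norm layout of the original transformer diagram (sub-layer, then Add & Norm), reuses `nn.MultiheadAttention` for self-attention, and uses small illustrative sizes; dropout and positional encodings are omitted, and the class name `TransformerBlock` is our own.

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """Post-norm block: self-attention -> Add & Norm -> feed-forward -> Add & Norm."""

    def __init__(self, embedding_dim: int = 256, num_heads: int = 8, feed_forward_dim: int = 1024):
        super().__init__()
        self.attention = nn.MultiheadAttention(embedding_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, feed_forward_dim),   # upscaling
            nn.GELU(),                                    # non-linearity
            nn.Linear(feed_forward_dim, embedding_dim),   # downscaling
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        context, _ = self.attention(x, x, x)          # queries = keys = values = x
        x = self.norm1(x + context)                   # Add & Norm
        x = self.norm2(x + self.feed_forward(x))      # Add & Norm
        return x


torch.manual_seed(0)
N = 4                                                  # number of stacked blocks
encoder = nn.Sequential(*[TransformerBlock() for _ in range(N)])

x = torch.randn(2, 20, 256)   # (batch, sequence length, embedding_dim), e.g. token + position embeddings
out = encoder(x)              # same shape: one contextualized vector per token
print(out.shape)
```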

## __Classification Head: Masked Language Modeling as Multi-Class Classification over (BPE) Vocabulary__

As mentioned before, the classification head performs multi-class classification over the BPE vocabulary for the \[MASK\] input representation. The weights of the classification head (excluding the bias vector) are tied to the transpose of the token embedding layer, meaning they share the parameters. Recall our initial intuition of inner products as a measure of compatibility: the token whose embedding is most compatible with the transformed representation of our \[MASK\] token will be the token that the model predicts, as their inner product will be the largest across the vocabulary. How did we arrive at that representation? We fetched the initial token representation for our \[MASK\] token from the embedding layer, added the positional encoding, and fed the entire sequence through a series of transformer blocks. Thereby, we repeatedly enabled the \[MASK\] token to attend to the sequence and recomposed its resulting representation. In doing so, the transformer infuses the required (sentence-level) semantics and syntax into the representation, such that the embedding becomes similar to those of tokens that would take similar representations in comparable contexts. The resulting inner products of our \[MASK\] token representation and the classification head weights can be interpreted as unnormalized log-probabilities, and the softmax function transforms these scores into a valid probability distribution. Such a formulation allows us to optimize our parameters using the gradients derived from minimizing the cross-entropy between our prediction and the actual label. As the illustration below highlights, we intuitively try to minimize the gaps between the predicted and actual probability distributions: the model will try to increase (reduce) the predicted probability for true (false) tokens.
To that end, more specifically, the gradients of these gaps between predicted and actual probability distributions encode how the model weights have to be updated to perform Masked Language Modeling better.

But why can the transformer learn what the representation of the \[MASK\] token should look like? Intuitively, the model converges towards parameters that embed similar semantics in close neighborhood while also respecting syntax in the embedding vector space; in other words, "dogs" and "cats" take similar representations in comparable sentences, and so forth, as such tokens occur in similar contexts. Such weights coincide with a good solution for MLM, as potential candidates for our masked tokens in all likelihood often occur in these similar contexts. We thereby learn to adequately infuse contextual information into the representations of our masked tokens. Furthermore, also recall that we do not always mask the sampled tokens, but sometimes let the model predict the BPE token that is already in the sequence, as otherwise there would be too large a mismatch between pre-training and downstream training, in which we never mask tokens. In practice, Masked Language Modeling is an extremely difficult task to solve perfectly. While ambiguity in natural language most likely prevents us from accurately predicting (recovering) every masked token, we efficiently seize on the regularities in natural language to learn suitable representations for our tokens.
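A minimal PyTorch sketch of such a masking procedure, assuming the BERT recipe: select about 15% of positions, replace 80% of them with \[MASK\], 10% with a random token, and leave 10% unchanged (but still predict them). The helper `mask_tokens` is hypothetical; Huggingface's `DataCollatorForLanguageModeling` implements a similar scheme. Unselected positions get the label -100, which `torch.nn.CrossEntropyLoss` ignores by default. For brevity, special tokens like \[CLS\] are not excluded from selection here, although real implementations exclude them.

```python
import torch

torch.manual_seed(0)


def mask_tokens(input_ids: torch.Tensor, mask_id: int, vocab_size: int, mask_prob: float = 0.15):
    """Hypothetical helper: BERT-style corruption of a tensor of token ids.

    Returns the corrupted ids and the labels (original ids at selected positions, -100 elsewhere).
    """
    selected = torch.rand(input_ids.shape) < mask_prob            # ~15% of positions are predicted
    labels = input_ids.clone()
    labels[~selected] = -100                                       # ignored by nn.CrossEntropyLoss

    corrupted = input_ids.clone()
    to_mask = selected & (torch.rand(input_ids.shape) < 0.8)       # 80% of selected -> [MASK]
    corrupted[to_mask] = mask_id
    to_random = selected & ~to_mask & (torch.rand(input_ids.shape) < 0.5)  # half of the rest -> random id
    corrupted[to_random] = torch.randint(vocab_size, (int(to_random.sum()),))
    # the remaining ~10% of selected positions keep their original token but are still predicted
    return corrupted, labels


input_ids = torch.tensor([1, 4, 5, 6, 7, 8, 9, 10, 4, 11, 12, 2])  # toy ids
corrupted, labels = mask_tokens(input_ids, mask_id=3, vocab_size=13)
print(corrupted)
print(labels)
```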

<img src="./img/cls_head.png" alt="cls_head" width="1800"/>

Consequently, we now also understand why it is sensible to share parameters between our token embedding and classification layers. Not only do we reduce the effective parameter count of our (typically already very large) transformer model, but we also retain a notion of representational consistency, which ideally facilitates training.

In addition, the aforementioned process of representation learning highlights why pre-training is important. Whenever we fine-tune on downstream tasks for which we do not have an abundant amount of labelled data, the model can now better generalize to unseen input, as it has already learned the underlying relationships between tokens, i.e. sentence-level semantics, to a large degree during pre-training. During fine-tuning, we then learn which semantics in our representation space, i.e. of the sequences, their tokens, and how they relate to each other, are relevant to predict the label of our task. Thanks to pre-training, we can leverage our prior language model knowledge to transfer seamlessly to similar semantics as represented by the transformed token embeddings.

An example implementation of our classification head consequently is quite simple and naturally follows from our discussion. 

In [None]:
# TODO: Classification
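One possible implementation of the classification head as a PyTorch sketch. Since we skip the transformer blocks here, the untransformed token embeddings stand in for the transformer output. Weight tying is done by assigning the embedding's weight `Parameter` to an `nn.Linear` layer, which computes `x @ W.T + b`, i.e. the inner products with every row of the token embedding matrix plus a separate bias. (BERT's actual head additionally applies a small dense + GELU + LayerNorm transform first; we omit it.) The names `input_ids` and `probabilities` match the next cell.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N_VOCAB = 10_000
embedding_dim = 512
sequence_length = 20

token_embedding = nn.Embedding(N_VOCAB, embedding_dim)
input_ids = torch.randperm(N_VOCAB)[:sequence_length]

# stand-in for the transformer output: simply the untransformed token embeddings
hidden_states = token_embedding(input_ids)                # (20, 512)

# classification head whose weight is tied to the token embedding matrix
lm_head = nn.Linear(embedding_dim, N_VOCAB)
lm_head.weight = token_embedding.weight                   # weight tying: the very same Parameter
nn.init.zeros_(lm_head.bias)                              # separate (untied) bias vector

logits = lm_head(hidden_states)          # (20, 10000): inner products with every vocabulary embedding
probabilities = logits.softmax(dim=-1)   # one probability distribution over the vocabulary per position
```

Because every untransformed token vector is most compatible with its own embedding, `probabilities.argmax(1)` should reproduce `input_ids`, mirroring the attention argmax check above.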

In [None]:
print(f"{input_ids=}")
print(f"{probabilities.argmax(1)=}")

In a real training setting, we would pass the unnormalized log-probabilities (logits) and our vector of reference labels to `torch.nn.CrossEntropyLoss`, which computes the loss; calling `backward()` on that loss then performs the above-illustrated backward pass, and an optimizer uses the resulting gradients to update the weights.
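A minimal PyTorch sketch of that final step, with made-up logits and a single masked position (the label id 17 is hypothetical). `nn.CrossEntropyLoss` applies log-softmax internally and, by default, ignores positions labeled -100. The gradient with respect to the logits at the masked position equals the predicted probabilities minus the one-hot target, i.e. exactly the gaps between predicted and actual distributions illustrated above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
sequence_length, vocab_size = 6, 50

logits = torch.randn(sequence_length, vocab_size, requires_grad=True)  # unnormalized log-probabilities
labels = torch.full((sequence_length,), -100, dtype=torch.long)        # -100: position not predicted
labels[2] = 17                                  # position 2 was masked; its original token id is 17

loss_fn = nn.CrossEntropyLoss()                 # log-softmax + negative log-likelihood, ignore_index=-100
loss = loss_fn(logits, labels)
loss.backward()                                 # gradients w.r.t. the logits; an optimizer would use them

# at the masked position: d loss / d logits = softmax(logits) - one_hot(label)
expected_grad = logits.detach()[2].softmax(dim=-1)
expected_grad[17] -= 1.0
print(torch.allclose(logits.grad[2], expected_grad, atol=1e-6))

# ignored positions receive no gradient at all
print(torch.count_nonzero(logits.grad[[0, 1, 3, 4, 5]]).item())
```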