Adding “attention” to Machine Translation (full python code and explanation)

Amay Gada
Published in Analytics Vidhya · 4 min read · Feb 8, 2022


In this article, we extend the seq2seq model we built in part 1 and add ‘attention’ to it. Part 1 covered why sequence-to-sequence models are needed and how to model them with RNNs.

Note that this article has snippets of code. For the full code base click here.

I strongly recommend going through part 1 first in order to ease into part 2.

Why attention?

We want the decoder to have access to the important words of the input sentence, the ones that most affect how it decodes the encoding. For this, we use the attention mechanism.

Building an Intuition

Let’s consider the following translation :

ENGLISH
I have a blue cat who ate a can of dog food two days ago.
SPANISH
Tengo un gato azul que se comió una lata de comida para perros hace dos días.

For the model to predict each Spanish word in the decoder, it needs to know which English words most affect the next word that the decoder will predict.

Since we need to predict 'gato' ('cat' in Spanish), the English word 'cat' has the highest weight and hence the most importance during prediction. Without attention, this information would have been lost to vanishing gradients, since the word 'cat' sits near the beginning of the input sentence.

Getting into the details

Consider the model illustrated above: an encoder reads a 7-word input sentence and a decoder produces the translation. The notation used is:

α1, α2, ..., α7  ->  the attention weights
h1, h2, ..., h7  ->  the output of each encoder RNN step
S0               ->  the hidden state of the last encoder RNN step
H = [h1, h2, ..., h7]
S = [S0]

Computing the attention weights (alphas)

We introduce two trainable parameters: WQ (the query weight matrix) and WK (the key weight matrix).

computing the key, K = H · WK (can be seen as a dense layer)
computing the query, Q = S · WQ (can be seen as a dense layer)
computing the attention weights, alpha = softmax over the scores between Q and K
(a code sketch of these three steps follows the shape discussion below)

Discussing Shapes

Here the batch size is 64, the input sentence length is 16 and the RNN has 1024 hidden units:

H     -> (64, 16, 1024)
WK    -> (1024, 1024)
K     -> (64, 16, 1024)
S     -> (64, 1, 1024)
WQ    -> (1024, 1024)
Q     -> (64, 1, 1024)
alpha -> (64, 16, 1)
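Putting these shapes together, a minimal sketch of the three steps (written here in TensorFlow/Keras with stand-in random tensors; see the full code base for the exact version) looks like this:

```python
import tensorflow as tf

# Stand-in tensors matching the shapes above (batch=64, time=16, units=1024).
batch, time, units = 64, 16, 1024
H = tf.random.normal((batch, time, units))   # encoder outputs, (64, 16, 1024)
S = tf.random.normal((batch, 1, units))      # decoder hidden state, (64, 1, 1024)

# WK and WQ as dense layers (their kernels play the role of the weight matrices)
key_layer = tf.keras.layers.Dense(units, use_bias=False)    # WK
query_layer = tf.keras.layers.Dense(units, use_bias=False)  # WQ

K = key_layer(H)      # (64, 16, 1024)
Q = query_layer(S)    # (64, 1, 1024)

# dot-product score of the query against every key, then softmax over the time axis
score = tf.reduce_sum(Q * K, axis=-1, keepdims=True)  # (64, 16, 1)
alpha = tf.nn.softmax(score, axis=1)                   # (64, 16, 1)
```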

Computing the context vector

The attention weights (alpha) are multiplied with the RNN outputs (H) and summed over the time axis to give the context vector. Intuitively, each word’s output is scaled by some ‘importance’ weight before being combined.

computing the context vector
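Continuing the same sketch, the context vector is the attention-weighted sum of the encoder outputs:

```python
# scale each encoder output by its attention weight and sum over the time axis
context = tf.reduce_sum(alpha * H, axis=1)   # (64, 1024)
```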

Moving on to the Decoder

The word embedding for the current word in the decoder (E) is concatenated with the Context Vector (C). This joint vector is fed as the input to the decoder RNN at that time step (see the decoder sketch below).

Note that before each new word is predicted by the decoder, a new context vector is computed.

Coding up the Dot Product Attention
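A minimal TensorFlow/Keras sketch of such a layer, wrapping the steps above (the class name and exact structure here are illustrative; refer to the full code base for the version used in the results below):

```python
import tensorflow as tf

class DotProductAttention(tf.keras.layers.Layer):
    """Dot-product attention with trainable query/key weights."""

    def __init__(self, units):
        super().__init__()
        self.WQ = tf.keras.layers.Dense(units, use_bias=False)  # query weight
        self.WK = tf.keras.layers.Dense(units, use_bias=False)  # key weight

    def call(self, S, H):
        # S: previous decoder hidden state, (batch, units)
        # H: encoder outputs,               (batch, time, units)
        S = tf.expand_dims(S, 1)                               # (batch, 1, units)
        Q = self.WQ(S)                                         # (batch, 1, units)
        K = self.WK(H)                                         # (batch, time, units)
        score = tf.reduce_sum(Q * K, axis=-1, keepdims=True)   # (batch, time, 1)
        alpha = tf.nn.softmax(score, axis=1)                   # (batch, time, 1)
        context = tf.reduce_sum(alpha * H, axis=1)             # (batch, units)
        return context, alpha
```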

Building the decoder with Attention

Code
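A sketch of a decoder that, at each step, asks the attention layer for a context vector and concatenates it with the current word embedding (again illustrative; names and arguments are placeholders):

```python
class Decoder(tf.keras.Model):
    """One decoding step using dot-product attention."""

    def __init__(self, vocab_size, embedding_dim, units):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(units, return_sequences=True, return_state=True)
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = DotProductAttention(units)

    def call(self, x, hidden, enc_output):
        # x: current target word ids, (batch, 1)
        # hidden: previous decoder state, (batch, units)
        # enc_output: H, (batch, time, units)
        context, alpha = self.attention(hidden, enc_output)       # (batch, units)
        x = self.embedding(x)                                     # (batch, 1, embedding_dim)
        # concatenate the context vector (C) with the word embedding (E)
        x = tf.concat([tf.expand_dims(context, 1), x], axis=-1)   # (batch, 1, units + embedding_dim)
        output, state = self.gru(x, initial_state=hidden)         # (batch, 1, units), (batch, units)
        logits = self.fc(tf.squeeze(output, 1))                   # (batch, vocab_size)
        return logits, state, alpha
```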

Training

Code
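A rough teacher-forcing training step, assuming an Encoder from part 1 that returns its outputs and final hidden state, and start_id for the start-of-sentence token (illustrative only; see the code base for the full loop):

```python
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # ignore padding (token id 0) when computing the loss
    mask = tf.cast(tf.math.not_equal(real, 0), pred.dtype)
    return tf.reduce_mean(loss_object(real, pred) * mask)

def train_step(inp, targ, encoder, decoder, start_id):
    loss = 0.0
    with tf.GradientTape() as tape:
        enc_output, enc_hidden = encoder(inp)
        dec_hidden = enc_hidden
        dec_input = tf.expand_dims([start_id] * inp.shape[0], 1)
        # teacher forcing: feed the true previous target word at every step
        for t in range(1, targ.shape[1]):
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            loss += loss_function(targ[:, t], predictions)
            dec_input = tf.expand_dims(targ[:, t], 1)
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss / int(targ.shape[1])
```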

Results

Translate
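A greedy decoding helper, assuming Keras tokenizers with '<start>' and '<end>' markers as in part 1 (all names here are placeholders):

```python
def translate(sentence, encoder, decoder, inp_tokenizer, targ_tokenizer,
              max_len_inp=16, max_len_targ=20):
    """Greedily decode one word at a time until '<end>' is produced."""
    ids = [inp_tokenizer.word_index[w] for w in sentence.lower().split()]
    ids = tf.keras.preprocessing.sequence.pad_sequences(
        [ids], maxlen=max_len_inp, padding='post')
    enc_output, enc_hidden = encoder(tf.convert_to_tensor(ids))
    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([targ_tokenizer.word_index['<start>']], 0)

    result = []
    for _ in range(max_len_targ):
        predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
        predicted_id = int(tf.argmax(predictions[0]))
        word = targ_tokenizer.index_word[predicted_id]
        if word == '<end>':
            break
        result.append(word)
        dec_input = tf.expand_dims([predicted_id], 0)  # feed the prediction back in
    return ' '.join(result)
```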

Do we really need RNNs?

What does attention offer?

  1. Takes into account all the input words (as bidirectional, sequential RNNs do)
  2. Computes a context vector that helps counter vanishing gradients (as GRUs/LSTMs aim to)
  3. Allows parallel computation of the weights during forward and backward propagation, since only dense layers are involved.
We see that attention can do a lot of what RNNs can do.
Hence it is worth exploring an experimental setting where we remove RNNs entirely and use only attention.
It has in fact been shown that attention-only models train faster and perform better than RNN-based ones; this is the central result of the paper ‘Attention Is All You Need’. We'll naively implement pure attention in Part 3 (COMING SOON!)

For the entire Code Base visit

References

  1. https://www.youtube.com/watch?v=pLpzU-xGi2E&t=3028s
  2. https://github.com/YanXuHappygela/NLP-study/blob/master/seq2seq_with_attention.ipynb
  3. https://www.youtube.com/watch?v=B3uws4cLcFw&list=PLgtf4d9zHHO8p_zDKstvqvtkv80jhHxoE
