Attention & transformers
Tokenization, embeddings, scaled dot-product attention, multi-head attention, the transformer block, and how the same architecture powers modern NLP, vision (ViT), and even graph and audio models.
Tokenization
Models don't see text — they see integer token IDs. The tokenizer converts strings to IDs and back.
- Word-level: simple but breaks on unseen words and inflates vocabulary.
- Character-level: tiny vocabulary, very long sequences.
- Subword (BPE / WordPiece / SentencePiece): the modern default. Common words stay whole, rare words split into subword pieces. Vocabulary 32K–128K typical.
# conceptual sketch — a BPE-style tokenizer (GPT-2's byte-level BPE via Hugging Face, as one concrete choice)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "Transformers tokenize text into integer IDs."
tokens = tokenizer.encode(text)  # e.g. [1532, 18, 9241, ...] (illustrative IDs)
back = tokenizer.decode(tokens)  # "Transformers tokenize ..."
Embeddings
Each token ID is looked up in an embedding table to produce a dense vector (typically 256–4096 dimensions).
import torch
import torch.nn as nn

vocab_size = 32_000
d_model = 512
token_embed = nn.Embedding(vocab_size, d_model)

ids = torch.tensor([[1, 5, 9, 4]])  # batch=1, seq_len=4
x = token_embed(ids)                # shape (1, 4, 512)
Embeddings are learned during training — semantically similar tokens end up nearby in the vector space.
Positional embeddings. Self-attention is permutation-equivariant — it doesn't know token order. So position is injected explicitly: learned or sinusoidal positional embeddings are added to the input vectors, while rotary (RoPE) embeddings are applied to the query and key vectors inside attention.
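For instance, the sinusoidal variant can be computed in closed form and added to the token embeddings. A minimal sketch (the formula is the one from the original transformer paper; variable names continue the embedding snippet above):

import math

def sinusoidal_positions(max_len, d_model):
    # even dimensions get sin(pos / 10000^(2i/d)), odd dimensions get cos
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
    div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe                                                    # (max_len, d_model)

x = x + sinusoidal_positions(ids.size(1), d_model)               # inject order information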
Scaled dot-product attention
The core mechanism. Given three matrices — Queries Q, Keys K, Values V — attention computes softmax(Q Kᵀ / √dₖ) · V:
- Q Kᵀ measures how strongly each query "matches" each key.
- ÷ √dₖ prevents the dot products from growing too large at high dimension (which would saturate the softmax).
- softmax turns scores into a probability distribution over positions.
- · V takes a weighted average of value vectors according to those probabilities.
import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    # Q, K, V: (batch, n_heads, seq_len, d_head)
    d_k = Q.size(-1)
    scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return weights @ V
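For example, to stop queries from attending to padding, pass a mask that is 0 at padded key positions. Sizes here are arbitrary; the mask broadcasts over heads and query positions:

Q = K = V = torch.randn(1, 2, 5, 16)        # batch=1, heads=2, seq_len=5, d_head=16
pad_mask = torch.tensor([[1, 1, 1, 0, 0]])  # last two positions are padding
out = attention(Q, K, V, mask=pad_mask[:, None, None, :])  # (1, 1, 1, 5) broadcasts over scores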
Multi-head attention
Instead of one attention with full d_model dimensions, split into h heads of size d_model / h. Each head learns to look at a different relationship; their outputs are concatenated and linearly projected.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, D = x.shape
        # project and reshape to (B, n_heads, T, d_head)
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        out = attention(Q, K, V, mask)  # (B, n_heads, T, d_head)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)
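A quick shape check with arbitrary example sizes:

mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)   # batch=2, seq_len=10
print(mha(x).shape)           # torch.Size([2, 10, 512]) — same shape in, same shape out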
The transformer block
A single transformer block combines self-attention, a feed-forward MLP, residual connections, and LayerNorm. The original paper applied LayerNorm after each residual (post-norm); the code below uses the now-common pre-norm variant, which normalizes before each sublayer and trains more stably.
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.drop(self.attn(self.norm1(x), mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x
Stack N of these (typically 6–96) and you have a transformer encoder. Add a causal mask to the attention and you have a decoder (used in GPT-style language models).
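A minimal sketch of that decoder-style wiring, with arbitrary example sizes — torch.tril builds the causal mask, and the blocks are simply applied in sequence:

T = 10
causal_mask = torch.tril(torch.ones(T, T))  # (T, T): 1 on/below the diagonal, 0 above
blocks = nn.ModuleList([TransformerBlock(d_model=512, n_heads=8, d_ff=2048) for _ in range(6)])

x = torch.randn(2, T, 512)                  # stand-in for token + positional embeddings
for block in blocks:
    x = block(x, mask=causal_mask)          # each position attends only to itself and the past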
Pre-training & fine-tuning
- Pre-training. Train on a massive unlabeled corpus with a self-supervised objective:
- Causal (autoregressive) — predict the next token. Used by GPT-family models.
- Masked language modeling (MLM) — mask 15% of tokens and predict them. Used by BERT-family models.
- Fine-tuning. Take a pre-trained model and continue training on a small labeled dataset for the target task (classification, QA, sentiment).
- Parameter-efficient fine-tuning (PEFT). LoRA, adapters — freeze the base model and train only a small set of additional parameters. Cheap, fast, almost as good (a minimal sketch follows this list).
- Prompting. For sufficiently large models, the task is described in natural language and no weights change at all (zero-shot / few-shot in context).
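To make the LoRA idea from the PEFT bullet concrete, here is a minimal sketch. The rank r, the init, and the class name are illustrative; real libraries such as Hugging Face's peft add scaling factors and dropout:

class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update."""
    def __init__(self, base: nn.Linear, r: int = 8):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                              # freeze pre-trained weights
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.02)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # B=0 → starts as the base model

    def forward(self, x):
        return self.base(x) + x @ self.A.T @ self.B.T            # only A and B receive gradients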
NLP applications
- Text classification. Sentence in → encoder → pool → linear head → label (sketched after this list).
- Named entity recognition (NER). Token-level classification.
- Question answering. Predict start & end token positions of the answer in a context.
- Summarization & translation. Encoder-decoder transformer.
- Generation. Decoder-only autoregressive sampling.
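A sketch of the text-classification recipe from the first bullet, assuming encoder maps (B, T, d_model) → (B, T, d_model) — e.g. a stack of the TransformerBlocks above. Mean pooling is used here; pooling a [CLS] token is the other common choice:

class TextClassifier(nn.Module):
    def __init__(self, encoder, d_model, n_classes):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        h = self.encoder(x)        # (B, T, d_model) contextualized token vectors
        pooled = h.mean(dim=1)     # pool over the sequence dimension
        return self.head(pooled)   # (B, n_classes) logits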
Computer vision
- CNNs — convolutional backbones (ResNet, EfficientNet). Still dominant for small data and edge deployment.
- Vision Transformers (ViT) — chop the image into 16×16 patches, treat each patch as a token, run a standard transformer (see the patch-embedding sketch after this list). Outperforms CNNs given large-scale training data.
- Object detection — predict bounding boxes + class labels. Two-stage (Faster R-CNN) vs. single-stage (YOLO, SSD); DETR frames detection as transformer-based set prediction.
- Segmentation — per-pixel labels. U-Net, Mask R-CNN.
- Generative models — autoencoders (compress and reconstruct), VAEs (probabilistic), GANs (adversarial), diffusion models (iteratively denoise from noise).
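The ViT patch-embedding step is just a strided convolution — a sketch using ViT-Base's standard sizes (16×16 patches, width 768) as example values:

patch_embed = nn.Conv2d(3, 768, kernel_size=16, stride=16)  # one 768-dim vector per 16×16 patch

img = torch.randn(1, 3, 224, 224)             # a 224×224 RGB image
patches = patch_embed(img)                    # (1, 768, 14, 14)
tokens = patches.flatten(2).transpose(1, 2)   # (1, 196, 768): 196 patch tokens for the transformer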
Common pitfalls
- Forgetting positional information. Without positions, the transformer treats input as a bag of tokens.
- Wrong attention mask. For causal models you need a lower-triangular mask (each position may attend to itself and earlier positions, never later ones); for padded batches you need to mask out padding tokens.
- Numerical instability in softmax. Always divide by √dₖ before softmax to avoid saturation.
- Memory blows up. Attention is O(n²) in sequence length. Long sequences need tricks (sparse attention, sliding window, FlashAttention) — a sliding-window mask is sketched after this list.
- Overfitting on small fine-tuning datasets. Use a small learning rate, few epochs, and a warmup schedule.
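As a sketch of the sliding-window idea from the memory bullet (the window size w is an arbitrary example, and this shows only the mask, not an efficient kernel):

T, w = 10, 3
i = torch.arange(T)
offset = i[:, None] - i[None, :]                    # offset[q, k] = q - k
window_mask = ((offset >= 0) & (offset < w)).int()  # causal, and at most w-1 steps back
# each query attends to at most w keys, so cost scales as O(n·w) instead of O(n²)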