This post explores how byte pair encoding (BPE) works. The code used in this blog can be found in the GitHub repo.
What is it used for?
BPE is a foundational subword tokenization technique used in modern large language models like GPT-2 and GPT-3.
What is so special about it?
Traditional word-based tokenization struggles with rare words, leading to large vocabularies or OOV tokens.
The key advantage of BPE is that it operates directly on UTF-8 bytes, ensuring any Unicode text can be represented without unknown (UNK) tokens.
We start with a base vocabulary of all 256 possible bytes (0-255). We then iteratively merge the most frequent adjacent byte pairs into new tokens, building a vocabulary of merged subwords.
Core intuition: any Unicode character can be represented as a sequence of bytes (via UTF-8 encoding).
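This intuition is easy to verify in Python: ASCII characters map to a single byte, while characters like emoji expand to several bytes, all drawn from the same 256 possible values.

```python
# Every Unicode character maps to one or more UTF-8 bytes,
# so a 256-entry byte vocabulary can represent any text.
text = "h😊"
raw = text.encode("utf-8")
print(list(raw))  # [104, 240, 159, 152, 138] — 'h' is one byte, the emoji is four
```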
Implementation
The full implementation is in the GitHub repo; here I will focus on summarizing it:
1. Initialization
Vocabulary starts with 256 single-byte tokens
def __init__(self):
    self.vocab = {i: bytes([i]) for i in range(256)}
    self.vocab_size = 256
    self.merges = {}
2. Training:
# This is the tokenizer training data
train_bpe("hello 😊 hello 😊 suraj is learningAI. suraj is excited. learningAI is fun")
Step 1. Convert the text to UTF-8 bytes.
Step 2. Repeatedly find the most frequent adjacent byte pair.
Step 3. Merge it into a new token (IDs starting from 256).
Step 4. Update the byte representation of the training text.
Step 5. Stop when no pair appears ≥ 2 times (a simple stopping criterion; in practice, a fixed number of merges or a target vocab size is used).
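The steps above can be sketched as a standalone function (names and the `min_count` parameter are illustrative, not the exact code from the repo):

```python
from collections import Counter

def train_bpe(text, min_count=2):
    # Step 1: start from raw UTF-8 bytes; new token IDs begin at 256.
    ids = list(text.encode("utf-8"))
    merges = {}  # (id_a, id_b) -> new token ID, in merge order
    next_id = 256
    while True:
        # Step 2: count adjacent pairs in the current sequence.
        pairs = Counter(zip(ids, ids[1:]))
        if not pairs:
            break
        pair, count = pairs.most_common(1)[0]
        # Step 5: stop when no pair repeats often enough.
        if count < min_count:
            break
        # Step 3: register the merge under a fresh token ID.
        merges[pair] = next_id
        # Step 4: rewrite the sequence with the merged token.
        out, i = [], 0
        while i < len(ids):
            if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
                out.append(next_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
        next_id += 1
    return merges, ids
```

On the classic toy corpus "aaabdaaabac", this learns three merges: "aa" → 256, "aaa" → 257, then "aaab" → 258.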
3. Encoding:
Greedily apply all learned merges in order.
def encode(self, text):
    byte_text = self.text_to_bytes(text)
    for merge_pair in self.merges:
        byte_text = self.find_and_replace(byte_text, merge_pair)
    return byte_text
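The `find_and_replace` helper is not shown above; a minimal standalone sketch (its exact signature in the repo may differ, so treat this as an assumption) scans left to right and substitutes each occurrence of the pair:

```python
def find_and_replace(ids, pair, new_id):
    # Replace every non-overlapping occurrence of `pair`
    # (a tuple of two token IDs) with the merged token `new_id`.
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out
```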
4. Decoding:
Reconstruct bytes from token IDs and decode to string.
def decode(self, encoded_text):
    bytes_list = [self.vocab[token_id] for token_id in encoded_text]
    full_bytes = b''.join(bytes_list)
    return full_bytes.decode('utf-8', errors='replace')
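The decode path can be exercised in isolation using only the base byte vocabulary (no merges), which is enough to round-trip any text:

```python
# Map token IDs back to bytes via the base vocabulary, then decode UTF-8.
vocab = {i: bytes([i]) for i in range(256)}
ids = list("hello 😊".encode("utf-8"))
decoded = b"".join(vocab[t] for t in ids).decode("utf-8", errors="replace")
print(decoded)  # hello 😊
```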
Encoding a new sentence:
encoded = bpe.encode("hello suraj is learningAI")
print(encoded)
This encodes into [263, 273, 282] (compressed into fewer tokens, using the merges learned in the previous step).
Unseen Variations:
bpe.decode(bpe.encode("hello worlf"))
bpe.encode("hindi")
- "worlf" is a minor typo, but it doesn't cause an OOV error for us.
- "hindi" was not in the corpus, so we fall back to bytes.
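The no-OOV property can be checked at the byte level without the trained tokenizer: any string, typos included, survives a round trip through its UTF-8 byte IDs.

```python
# Typos and unseen words still round-trip losslessly through byte IDs.
for s in ("hello worlf", "hindi"):
    ids = list(s.encode("utf-8"))       # token IDs (all < 256 here)
    assert bytes(ids).decode("utf-8") == s
```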
Wrapping it all up
Better (more extensive/diverse) training data will result in better merges and compression.
Hope this was an informative read. Cheers!