Byte Pair Encoding

A Gentle Dive into Byte Pair Encoding!

By Suraj Jha

This post explores how byte pair encoding (BPE) works. The code used in this blog can be found in the GitHub repo.

What is it used for?

BPE is a foundational subword tokenization technique used in modern large language models like GPT-2 and GPT-3.

What is so special about it?

Traditional word-based tokenization struggles with rare words, forcing either a huge vocabulary or out-of-vocabulary (OOV) tokens.

The key advantage of BPE is that it operates directly on UTF-8 bytes, ensuring any Unicode string can be represented without unknown (UNK) tokens.

We start with a base vocabulary of all 256 possible bytes (0-255). We then iteratively merge the most frequent adjacent byte pairs into new tokens, building a vocabulary of merged subwords.

Core intuition: any Unicode character can be represented as a sequence of bytes (via UTF-8 encoding).
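To make this concrete, here is a tiny illustration (plain Python, not from the repo):

```python
# Any Unicode text maps to a sequence of byte values in the range 0-255.
text = "hello 😊"
byte_ids = list(text.encode("utf-8"))
print(byte_ids)  # [104, 101, 108, 108, 111, 32, 240, 159, 152, 138]
# The emoji alone occupies the last 4 bytes -- no special handling needed.
```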


Implementation

The full implementation is in the GitHub repo; here I will focus on summarizing it:

1. Initialization

Vocabulary starts with 256 single-byte tokens

def __init__(self):
        self.vocab = {i: bytes([i]) for i in range(256)}  # token ID -> raw bytes
        self.vocab_size = 256
        self.merges = {}  # (id_a, id_b) -> merged token ID

2. Training:

# This is the tokenizer training data
train_bpe("hello 😊 hello 😊 suraj is learningAI. suraj is excited. learningAI is fun")

Step 1. Convert the text to UTF-8 bytes.

Step 2. Repeatedly find the most frequent adjacent byte pair.

Step 3. Merge it into a new token (IDs starting from 256).

Step 4. Update the byte representation of the training text.

Step 5. Stop when no pair appears ≥ 2 times (a simple stopping criterion; in practice, use a fixed number of merges or a target vocab size).
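The five steps above can be sketched end to end like this (my own minimal illustration of the algorithm, not the repo's `train_bpe`; the structure is an assumption):

```python
from collections import Counter

def train_bpe(text):
    """Minimal BPE training sketch: merge the most frequent pair until none repeats."""
    ids = list(text.encode("utf-8"))           # Step 1: UTF-8 bytes
    merges = {}                                # (id_a, id_b) -> merged token ID
    next_id = 256
    while True:
        pairs = Counter(zip(ids, ids[1:]))     # Step 2: count adjacent pairs
        if not pairs:
            break
        pair, count = pairs.most_common(1)[0]
        if count < 2:                          # Step 5: stop when no pair repeats
            break
        merges[pair] = next_id                 # Step 3: assign a new token ID
        # Step 4: replace every occurrence of the pair in the ID sequence
        out, i = [], 0
        while i < len(ids):
            if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
                out.append(next_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        ids = out
        next_id += 1
    return merges, ids
```

On the classic toy string "aaabdaaabac", this learns three merges ((97, 97) → 256, (256, 97) → 257, (257, 98) → 258) and compresses the 11 bytes down to 5 tokens.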

3. Encoding:

Greedily apply all learned merges in order.

def encode(self, text):
        byte_text = self.text_to_bytes(text)  # text -> list of byte IDs
        for merge_pair in self.merges:  # dicts preserve insertion order, so merges apply in the order learned
            byte_text = self.find_and_replace(byte_text, merge_pair)
        return byte_text
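The helpers `text_to_bytes` and `find_and_replace` are not shown in the post; here is one possible standalone sketch of `find_and_replace` (an assumption about its behavior — the repo version may differ, and I pass the merged ID in explicitly rather than looking it up on the class):

```python
def find_and_replace(ids, pair, new_id):
    """Replace every adjacent occurrence of `pair` in `ids` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2  # skip both members of the merged pair
        else:
            out.append(ids[i])
            i += 1
    return out

print(find_and_replace([104, 105, 104, 105, 120], (104, 105), 256))  # [256, 256, 120]
```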

4. Decoding:

Reconstruct bytes from token IDs and decode to string.

def decode(self, encoded_text):
        bytes_list = [self.vocab[token_id] for token_id in encoded_text]  # token IDs -> raw bytes
        full_bytes = b''.join(bytes_list)
        return full_bytes.decode('utf-8', errors='replace')  # 'replace' guards against invalid byte sequences

Encoding a new sentence:

encoded = bpe.encode("hello suraj is learningAI")
print(encoded) 

This encodes to [263, 273, 282]: the whole sentence is compressed into just three tokens, using the merges learned during training.

Unseen Variations:

bpe.decode(bpe.encode("hello worlf"))
bpe.encode("hindi")

  1. "worlf" is a minor typo, but it does not cause an OOV error for us.

  2. "hindi" never appeared in the corpus, so we fall back to raw bytes.
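Why does the fallback work? A minimal standalone demo (plain Python, not the repo's API): the base vocabulary already covers all 256 byte values, so even text with zero matching merges round-trips losslessly.

```python
vocab = {i: bytes([i]) for i in range(256)}  # base vocabulary only, no merges
ids = list("hindi".encode("utf-8"))          # no learned merge applies -> raw byte IDs
print(ids)                                   # [104, 105, 110, 100, 105]
decoded = b"".join(vocab[i] for i in ids).decode("utf-8")
print(decoded)                               # hindi
```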

Wrapping it all up

Better (more extensive/diverse) training data will result in better merges and compression.

Hope this was an informative read. Cheers!

Tags: BPE