NLP Subword 之 BBPE(Byte-level BPE) 算法原理
本文将介绍以下内容:
- 1. BBPE 算法原理
- 2. BBPE 算法流程
- 2.1 构建初始词表
- 2.2 统计频率
- 2.3 合并频率最高的字节对
- 2.4 重复合并步骤
- 2.5 分词
- 2.6 解码
- 3. BBPE 算法源码实现 Demo
通过使用发现BPE理论上还是会出现OOV的,当词汇表的大小受限时,一些较少频繁出现的子词和没有在训练过程中见过的子词,就会无法进入词汇表出现OOV,而Byte-level BPE(BBPE)理论上是不会出现这个情况的。
由于BBPE相较BPE是使用UTF-8字节编码的字节级别作为最小单位,所以在开始之前,最好先阅读笔者这边文章:
看懂 Unicode 与 UTF-8 编码全过程
1. BBPE 算法原理
Byte-level BPE(BBPE)是一种基于字节对编码(Byte Pair Encoding,BPE)的改进算法,它解决了传统BPE在处理某些特殊字符或OOV(Out-of-Vocabulary,超出词汇表)问题时的不足。与传统BPE基于字符级别的粒度不同,BBPE使用UTF-8字节编码的字节级别作为最小单位,因此能够处理所有可能的字符,而不依赖于特定语言的字符集或词汇表。
关键特性:
- 字节级别表示: BBPE使用UTF-8字节表示每个字符,从而解决了OOV问题。每个字节由0到255的整数表示,能够涵盖所有字符,包括多字节的Unicode字符。
- 避免OOV: 由于所有字符都被映射为字节,BBPE算法不再依赖于字符集的定义,避免了传统BPE中可能出现的未见过的子词(OOV)问题。
- 合并策略: BBPE通过合并频率最高的字节对来扩展词汇表,生成新的子词单元。随着合并次数的增加,字节对的频率分布逐渐变化,直到达到指定的词汇表大小。
BBPE与BPE的区别:
- BPE最小词汇表是字符级,而BBPE是字节级别,通过UTF-8的编码方式这一个字节的256位的范围,理论上可以表示这个世界上的所有字符。
2. BBPE 算法流程
2.1 构建初始词表
BBPE的初始词表由所有可能的字节(0至255)组成。每个字节被映射为一个唯一的符号,可以通过bytes_to_unicode
函数实现:
def bytes_to_unicode():bs = list(range(33, 127)) + list(range(161, 173)) + list(range(174, 256))cs = bs[:]n = 0for b in range(256):if b not in bs:bs.append(b)cs.append(256 + n)n += 1cs = [chr(c) for c in cs]return dict(zip(bs, cs))
- 字节到Unicode映射: 通过
bytes_to_unicode
函数,0至255的字节值被映射到可见的Unicode字符。 - 字节到符号的映射: 每个字节(0-255)都映射到一个符号(Unicode字符),为后续的训练准备了初步的词汇表。
2.2 统计频率
在初始词表构建之后,BBPE通过遍历文本,统计所有字节对(subword)的出现频率。这一过程类似于传统BPE中的“统计步骤”,但粒度变为字节级别。每个字节对的出现频率在get_stats
函数中进行计算:
def get_stats(self):pairs = defaultdict(int)for word, freq in self.vocab.items():for i in range(len(word)-1):pairs[(word[i], word[i+1])] += freqreturn pairs
- 字节对统计: 遍历文本中的每个词,统计每个相邻字节对(byte-pair)的出现频率。
2.3 合并频率最高的字节对
BBPE会从统计出的字节对中选择出现频率最高的字节对,并将其合并成一个新的符号。这个过程对应了BPE中的“合并步骤”,但在BBPE中,合并的单位是字节对。
def merge_vocab(self, pair):a, b = pairtoken_a = self._id_to_token(a)token_b = self._id_to_token(b)new_token = token_a + token_bif new_token not in self.token2id:self.token2id[new_token] = self.next_idself.id2token[self.next_id] = new_tokenself.next_id += 1new_id = self.token2id[new_token]return (token_a, token_b)
- 合并频率最高的字节对: 选择频率最高的字节对pair,将其合并成一个新的符号,并更新词汇表。
2.4 重复合并步骤
BBPE的核心思想是不断重复合并操作,直到达到预定的词汇表大小。每次合并后,都会更新词汇表和字节对的频率统计。重复执行合并直到词汇表大小符合要求或没有更多的可合并的字节对为止。
def train(self, save_merges="merges.txt", save_vocab="vocab.txt"):num_merges = self.vocab_size - len(set(ch for word in self.vocab for ch in word))for i in range(num_merges):pairs = self.get_stats()if not pairs:breakbest = max(pairs, key=pairs.get)merged = self.merge_vocab(best)self.merges.append(merged)
- 重复合并: 通过train方法控制合并的次数,直到词汇表大小达到设定值。
2.5 分词
训练完成后,BBPE可以根据训练得到的词汇表对新的文本进行分词。分词的过程中,BBPE会将文本中的字节序列映射到词汇表中的符号,直到文本被分解成一个个基本的子词单元。
def encode_word(self, word):word = list(word.encode("utf-8"))word = [self.token2id[self._id_to_token(b)] for b in word]pairs = self.get_pairs(word)while pairs:bigram = min(pairs, key=lambda p: self.bpe_ranks.get((self._id_to_token(p[0]), self._id_to_token(p[1])), float("inf")))new_word = []i = 0new_token = self._id_to_token(bigram[0]) + self._id_to_token(bigram[1])new_id = self.token2id[new_token]while i < len(word):if i < len(word)-1 and (word[i], word[i+1]) == bigram:new_word.append(new_id)i += 2else:new_word.append(word[i])i += 1word = new_wordif len(word) == 1:breakpairs = self.get_pairs(word)return word
- 字节序列转子词: 通过encode_word方法对每个词进行编码,将其转换为一个或多个子词单元。
2.6 解码
BBPE还实现了将分词结果转换回原始文本的解码功能。通过decode_word函数,BBPE能够将分词后的符号转换回字节,再转为字符,最终恢复为原始的UTF-8文本。
def decode_word(self, tokens):text = "".join([self.id2token[t] for t in tokens])byte_decoder = {v: k for k, v in self.byte_encoder.items()}byte_seq = [byte_decoder.get(ch, ord(ch)) for ch in text]return bytes(byte_seq).decode("utf-8", errors="ignore")
- 符号转换回字节并解码: decode_word将BBPE的符号转换回字节序列,并解码为原始的UTF-8文本。
注:BBPE能够基于字节级的粒度进行高效的分词和词汇表生成,避免了传统BPE可能出现的OOV问题,具有更好的语言通用性,特别适用于多语言和特殊字符处理。
3. BBPE 算法源码实现 Demo
import re
from collections import defaultdict# ----------------- GPT-2 风格的 bytes_to_unicode -----------------
def bytes_to_unicode():"""GPT-2 中用于把字节(0-255)映射成可见的 Unicode 字符。避免直接用不可见/控制字符。"""bs = list(range(33, 127)) + list(range(161, 173)) + list(range(174, 256))cs = bs[:]n = 0for b in range(256):if b not in bs:bs.append(b)cs.append(256 + n)n += 1cs = [chr(c) for c in cs]return dict(zip(bs, cs))# ----------------- Byte-level BPE -----------------
class BBPE:def __init__(self, vocab_size=100):self.vocab_size = vocab_sizeself.vocab = {} # word -> frequencyself.merges = [] # list of mergesself.bpe_ranks = {} # pair -> rankself.token2id = {} # token(str) -> idself.id2token = {} # id -> token(str)self.next_id = 0 # 下一个分配的idself.byte_encoder = bytes_to_unicode() # 字节到可见符号的映射# ---------- 构建初始词表 ----------def build_vocab(self, corpus):vocab = defaultdict(int)for line in corpus:words = line.strip().split()for word in words:print(f"word:{word}")byte_seq = list(word.encode("utf-8")) # 直接字节,不加 </w>print(f"byte_seq:{byte_seq}")vocab[tuple(byte_seq)] += 1self.vocab = dict(vocab)print(f"初始vocab大小={len(self.vocab)}")print(f"初始vocab {self.vocab}")# 初始化 token 映射:0-255 字节print(f"self.byte_encoder:{self.byte_encoder}")for b in range(256):s = self.byte_encoder.get(b, chr(b))self.token2id[s] = self.next_idself.id2token[self.next_id] = sself.next_id += 1# ---------- 统计 pair ----------def get_stats(self):pairs = defaultdict(int)for word, freq in self.vocab.items():for i in range(len(word)-1):pairs[(word[i], word[i+1])] += freqreturn pairs# ---------- 合并 ----------def merge_vocab(self, pair):a, b = pairtoken_a = self._id_to_token(a)token_b = self._id_to_token(b)new_token = token_a + token_bif new_token not in self.token2id:self.token2id[new_token] = self.next_idself.id2token[self.next_id] = new_tokenself.next_id += 1new_id = self.token2id[new_token]new_vocab = {}for word, freq in self.vocab.items():print(f"word, freq:{word, freq}")new_word = []i = 0while i < len(word):if i < len(word)-1 and (word[i], word[i+1]) == pair:new_word.append(new_id)i += 2else:new_word.append(word[i])i += 1new_vocab[tuple(new_word)] = freqprint(f"old dself.vocab{self.vocab}")self.vocab = new_vocabprint(f"new dself.vocab{self.vocab}")return (token_a, token_b)# ---------- 训练 ----------def train(self, save_merges="merges.txt", save_vocab="vocab.txt"):assert self.byte_encoder == self.id2tokenalphabet = set(ch for word in self.vocab for ch in word)num_merges = self.vocab_size - len(alphabet)print(f"初始alphabet大小={len(alphabet)},目标vocab_size={self.vocab_size},合并次数≈{num_merges}")for i in range(num_merges):pairs = self.get_stats()print(f"pairs:{pairs}")if not pairs:breakbest = max(pairs, key=pairs.get)print(f"best:{best}")merged = self.merge_vocab(best)self.merges.append(merged)self.bpe_ranks = dict(zip(self.merges, range(len(self.merges))))print("训练完成 ✅")# 保存 mergeswith open(save_merges, "w", encoding="utf-8") as f:for a, b in self.merges:f.write(f"{a} {b}\n")# 保存 vocabwith open(save_vocab, "w", encoding="utf-8") as f:for token, idx in self.token2id.items():f.write(f"{token} {idx}\n")print(f"✅ merges 保存到 {save_merges}, vocab 保存到 {save_vocab}")# ---------- 编码 ----------def encode_word(self, word):word = list(word.encode("utf-8")) # 不加 </w>word = [self.token2id[self._id_to_token(b)] for b in word]pairs = self.get_pairs(word)while pairs:bigram = min(pairs, key=lambda p: self.bpe_ranks.get((self._id_to_token(p[0]), self._id_to_token(p[1])), float("inf")))if (self._id_to_token(bigram[0]), self._id_to_token(bigram[1])) not in self.bpe_ranks:breaknew_word = []i = 0new_token = self._id_to_token(bigram[0]) + self._id_to_token(bigram[1])new_id = self.token2id[new_token]while i < len(word):if i < len(word)-1 and (word[i], word[i+1]) == bigram:new_word.append(new_id)i += 2else:new_word.append(word[i])i += 1word = new_wordif len(word) == 1:breakpairs = self.get_pairs(word)return worddef get_pairs(self, word):return set((word[i], word[i+1]) for i in range(len(word)-1))# ---------- 解码 ----------def decode_word(self, tokens):text = "".join([self.id2token[t] for t in tokens])# 替换 bytes_to_unicode 回字节byte_decoder = {v: k for k, v in self.byte_encoder.items()}byte_seq = [byte_decoder.get(ch, ord(ch)) for ch in text]return bytes(byte_seq).decode("utf-8", errors="ignore")# ---------- 工具 ----------def _id_to_token(self, idx):if idx in self.id2token:return self.id2token[idx]else:return chr(idx)# ================== 示例 ==================
if __name__ == "__main__":corpus = ["deep learning is the future of ai","see my eyes first","see my dogs","you are the best","you are the fast","machine learning can be applied to natural language processing","深度学习是人工智能的未来","机器学习可以应用于自然语言处理","人工智能改变世界","学习深度神经网络在图像识别中表现优秀"]bbpe = BBPE(vocab_size=100)bbpe.build_vocab(corpus)bbpe.train("merges.txt", "vocab.txt")print("\n=== 单词测试 ===")for w in ["lowest", "人工智能", "深度学习"]:tokens = bbpe.encode_word(w)print(f"{w} -> {tokens} -> {bbpe.decode_word(tokens)}")print("\n=== 句子测试 ===")sentence = "lowest人工智能深度学习"tokens_list = [bbpe.encode_word(w) for w in sentence.split()]print(tokens_list)print("解码:", " ".join(bbpe.decode_word(toks) for toks in tokens_list))