NLP Subword 之 BPE(Byte Pair Encoding) 算法原理
本文将介绍以下内容:
- 1. BPE 算法核心原理
- 2. BPE 算法流程
- 3. BPE 算法源码实现Demo
BPE最早是一种数据压缩算法,由Sennrich等人于2015年引入到NLP领域并很快得到推广。该算法简单有效,因而目前它是最流行的方法。GPT-2和RoBERTa使用的Subword算法都是BPE。
1. BPE 算法核心原理:
它的主要思想是:
- 使用频率统计来逐步合并高频的字符/子词对。
- 从最小的单位(字符)开始,逐渐学习得到一套子词词表,使模型能够兼顾 常见词的完整表示 和 罕见词的组合表示。
在大语言模型时代,最常用的分词方法是Byte-Pair Encoding(BPE)和Byte-level BPE(BBPE)。该算法的核心思想是逐步合并出频率最高的子词对而不是像wordpiece一样通过计算合并分数。
2. BPE 算法流程:
(1)计算初始词表:通过训练语料获得或者最初的英文种26个字母加上各种符号以及常见中文字符,这些作为初始词表。
(2)构建频率统计:统计所有子词单元对在文本中的出现频率。
(3)合并频率最高的子词对:选择出现频率最高的子词对,将它们合并成一个新的子词单元,并更新词汇表。
(4)重复合并步骤:不断重复步骤2和步骤3,直到达到预定的词汇表大小、合并次数。
(5)分词:使用训练得到的词汇表对文本进行分词。
3. 算法源码实现Demo
import re
from collections import defaultdictclass BPE:def __init__(self, vocab_size=100):self.vocab_size = vocab_sizeself.vocab = {} # word -> frequencyself.merges = [] # list of mergesself.bpe_ranks = {} # pair -> rank# ---------- 构建初始词表 ----------def build_vocab(self, corpus):"""corpus: list[str],输入语料英文: 用空格分词中文: 逐字处理"""vocab = defaultdict(int)for line in corpus:words = line.strip().split()for word in words:chars = list(word) + ["</w>"] # 加上词边界vocab[tuple(chars)] += 1self.vocab = dict(vocab)# ---------- 统计 pair ----------def get_stats(self):"""统计 pair 的频率"""pairs = defaultdict(int)for word, freq in self.vocab.items():for i in range(len(word)-1):pairs[(word[i], word[i+1])] += freqreturn pairs# ---------- 合并 ----------def merge_vocab(self, pair):"""执行一次合并"""new_vocab = {}bigram = re.escape(" ".join(pair))pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')for word, freq in self.vocab.items():word_str = " ".join(word)new_word = tuple(pattern.sub("".join(pair), word_str).split())new_vocab[new_word] = freqself.vocab = new_vocab# ---------- 训练 ----------def train(self, save_merges="merges.txt", save_vocab="vocab.txt"):# 初始 alphabet 大小alphabet = set(ch for word in self.vocab for ch in word)num_merges = self.vocab_size - len(alphabet)print(f"初始alphabet大小={len(alphabet)},目标vocab_size={self.vocab_size},合并次数≈{num_merges}")for i in range(num_merges):pairs = self.get_stats()if not pairs:breakbest = max(pairs, key=pairs.get)self.merges.append(best)self.merge_vocab(best)# 构建 bpe_ranksself.bpe_ranks = dict(zip(self.merges, range(len(self.merges))))print(f"self.bpe_ranks:{self.bpe_ranks}")# 保存 mergeswith open(save_merges, "w", encoding="utf-8") as f:for a, b in self.merges:f.write(f"{a} {b}\n")# 保存 vocabvocab_tokens = set()for word in self.vocab:for token in word:vocab_tokens.add(token)with open(save_vocab, "w", encoding="utf-8") as f:for token in sorted(vocab_tokens):f.write(token + "\n")print(f"✅ merges 保存到 {save_merges}, vocab 保存到 {save_vocab}")# ---------- 推理 ----------def get_pairs(self, word):"""获取当前词的所有pair"""pairs = set()prev_char = word[0]for char in word[1:]:pairs.add((prev_char, char))prev_char = charreturn pairsdef encode_word(self, word):"""BPE 编码单个词"""word = tuple(list(word) + ["</w>"])pairs = self.get_pairs(word)if not pairs:return [word]while True:# 找到rank最小的pairbigram = min(pairs, key=lambda p: self.bpe_ranks.get(p, float("inf")))if bigram not in self.bpe_ranks:breaknew_word = []i = 0while i < len(word):if i < len(word)-1 and word[i] == bigram[0] and word[i+1] == bigram[1]:new_word.append(word[i] + word[i+1])i += 2else:new_word.append(word[i])i += 1word = tuple(new_word)if len(word) == 1:breakpairs = self.get_pairs(word)return list(word)def decode_word(self, tokens):"""还原单词"""word = "".join(tokens)if word.endswith("</w>"):word = word[:-4]return worddef encode_sentence(self, sentence):"""BPE 编码整句"""return [self.encode_word(w) for w in sentence.strip().split()]def decode_sentence(self, tokens_list):"""解码整句"""return " ".join(self.decode_word(toks) for toks in tokens_list)# ================== 示例 ==================
if __name__ == "__main__":corpus = ["deep learning is the future of ai","see my eyes first","see my dogs","you are the best","you are the fast","machine learning can be applied to natural language processing","深度学习是人工智能的未来","机器学习可以应用于自然语言处理","人工智能改变世界","学习深度神经网络在图像识别中表现优秀"]# 训练bpe = BPE(vocab_size=100)bpe.build_vocab(corpus)bpe.train("merges.txt", "vocab.txt")# 测试推理print("\n=== 单词测试 ===")for w in ["lowest", "newer", "人工智能", "深度学习"]:tokens = bpe.encode_word(w)print(f"{w} -> {tokens} -> {bpe.decode_word(tokens)}")print("\n=== 句子测试 ===")sentence = "lowest newer 人工智能深度学习"tokens_list = bpe.encode_sentence(sentence)print(tokens_list)print(bpe.decode_sentence(tokens_list))# === 单词测试 ===
# lowest -> ['l', 'o', 'w', 'e', 's', 't', '</w>'] -> lowest
# newer -> ['n', 'e', 'w', 'e', 'r', '</w>'] -> newer
# 人工智能 -> ['人工智能', '</w>'] -> 人工智能
# 深度学习 -> ['深度', '学习', '</w>'] -> 深度学习# === 句子测试 ===
# [['l', 'o', 'w', 'e', 's', 't', '</w>'], ['n', 'e', 'w', 'e', 'r', '</w>'], ['人工智能', '王', '赞', '深度', '学习', '</w>']]
# lowest newer 人工智能深度学习