小白的进阶之路系列之十八----人工智能从初步到精通pytorch综合运用的讲解第十一部分
从零开始的NLP:使用序列到序列网络和注意力机制进行翻译
我们将编写自己的类和函数来预处理数据以完成我们的 NLP 建模任务。
在这个项目中,我们将训练一个神经网络将法语翻译成英语。
[KEY: > input, = target, < output]> il est en train de peindre un tableau .
= he is painting a picture .
< he is painting a picture .> pourquoi ne pas essayer ce vin delicieux ?
= why not try that delicious wine ?
< why not try that delicious wine ?> elle n est pas poete mais romanciere .
= she is not a poet but a novelist .
< she not not a poet but a novelist .> vous etes trop maigre .
= you re too skinny .
< you re all alone .
… 取得不同程度的成功。
这得益于简单而强大的序列到序列网络思想,其中两个循环神经网络协同工作以将一个序列转换为另一个序列。编码器网络将输入序列压缩成一个向量,解码器网络将该向量展开成一个新的序列。
为了改进这个模型,我们将使用一个注意力机制,它允许解码器学习集中关注输入序列的特定范围。
你还会发现之前的从零开始的 NLP:使用字符级 RNN 分类名称 和 从零开始的 NLP:使用字符级 RNN 生成名称 教程非常有用,因为这些概念分别与编码器和解码器模型非常相似。
要求
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import randomimport torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as Fimport numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSamplerdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
加载数据文件
本项目的数据集包含数千对英法翻译对。
Open Data Stack Exchange 上的这个问题 指引我找到了开放翻译网站 https://tatoeba.org/,该网站在 https://tatoeba.org/eng/downloads 提供下载——更妙的是,有人做了额外的工作,将语言对分成独立的文本文件,在此处提供:https://www.manythings.org/anki/
英法翻译对文件太大,无法包含在仓库中,请在继续之前下载到 data/eng-fra.txt
。该文件是以制表符分隔的翻译对列表
I am cold. J'ai froid.
注意
从这里下载数据并将其解压到当前目录。
与字符级 RNN 教程中使用的字符编码类似,我们将语言中的每个词表示为一个独热向量,或者一个巨大的零向量,除了一个位置为一(该词的索引)。与语言中可能存在的几十个字符相比,词的数量要多得多,因此编码向量更大。然而,我们将稍微做一些妥协,仅使用每种语言的几千个词来修剪数据。
我们稍后需要为每个词设置一个唯一的索引,用作网络的输入和目标。为了跟踪所有这些,我们将使用一个名为 Lang
的辅助类,它包含 word → index (word2index
) 和 index → word (index2word
) 字典,以及每个词的计数 word2count
,这将用于稍后替换罕见词。
SOS_token = 0
EOS_token = 1class Lang:def __init__(self, name):self.name = nameself.word2index = {}self.word2count = {}self.index2word = {0: "SOS", 1: "EOS"}self.n_words = 2 # Count SOS and EOSdef addSentence(self, sentence):for word in sentence.split(' '):self.addWord(word)def addWord(self, word):if word not in self.word2index:self.word2index[word] = self.n_wordsself.word2count[word] = 1self.index2word[self.n_words] = wordself.n_words += 1else:self.word2count[word] += 1
所有文件都是 Unicode 格式,为了简化,我们将 Unicode 字符转换为 ASCII,全部转换为小写,并去除大部分标点符号。
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn')# Lowercase, trim, and remove non-letter characters
def normalizeString(s):s = unicodeToAscii(s.lower().strip())s = re.sub(r"([.!?])", r" \1", s)s = re.sub(r"[^a-zA-Z!?]+", r" ", s)return s.strip()
为了读取数据文件,我们将文件按行分割,然后将行分割成对。文件都是英语 → 其他语言,因此如果我们要从其他语言 → 英语翻译,我添加了 reverse
标志来反转翻译对。
def readLangs(lang1, lang2, reverse=False):print("Reading lines...")# Read the file and split into lineslines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\read().strip().split('\n')# Split every line into pairs and normalizepairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]# Reverse pairs, make Lang instancesif reverse:pairs = [list(reversed(p)) for p in pairs]input_lang = Lang(lang2)output_lang = Lang(lang1)else:input_lang = Lang(lang1)output_lang = Lang(lang2)re