使用 Transformer 进行机器翻译
一、transformer 简介 transformer 是 Google 在 2017 年发表的文章 Attention Is All You Need 中提出的网络架构。transformer 中只使用了注意力,实现了序列数据的处理,而未使用之前常用的 RNN 或 CNN。 对 nlp 问题,我们希望的是尽可能的获取句子的整体含义。使用 RNN,我们必须逐词获取语义,因此容易导致开头词汇词义的遗忘;使用 CNN,我们必须通过增加层数来扩大获取信息的范围。这两种方法都有很大的局限。 注意力方法则可以直接获得全局信息。方法是对一条序列,求其对于本身的注意力,这被称为自注意力。 transformer 的原理和模型较为复杂,在这里只是简单说明。 二、数据集 此为训练模型所用的数据集。设定英文为源语言,中文为要翻译成的语言。 (1)Dataset 类编写 我们根据路径打开文件,获取中英文序列和单词表。并将序列直接转化为 tensor,方便读取。 class TranslateDataset(Dataset): def __init__(self, en_path, zh_path): en_seqs, self.en_vocab = get_seq_and_vocab(en_path, get_tokenizer("basic_english")) zh_seqs, self.zh_vocab = get_seq_and_vocab(zh_path, zh_simple_tokenizer) self.items = [] for i in range(len(en_seqs)): en_seq = en_seqs[i] zh_seq = zh_seqs[i] src = en_seq tgt = zh_seq[:-1] pdt = zh_seq[1:] self....