PyTorch实现Seq2Seq机器翻译

摘要:
Seq2Seq简介Seq2Seq由编码器和解码器组成,编码器和解码器由RNN组成。Seq2Seq在训练阶段和预测阶段略有不同。简而言之,注意机制允许我们在输出时关注输入序列的某些部分,即使输入单词具有不同的贡献。这里只描述了使用注意机制的翻译器的部分代码。完整代码如下https://gitee.com/dogecheng/python/blob/master/pytorch/Seq2SeqForTranslation.ipynb在计算关注值之后,解码器将对编码器输出的隐藏状态执行加权平均,以获得上下文向量上下文。然后,将拼接解码器当前时间步长的上下文和隐藏状态,并在tanh之后。

Seq2Seq简介

Seq2Seq由Encoder和Decoder组成,Encoder和Decoder又由RNN构成。Encoder负责将输入编码为一个向量。Decoder根据这个向量,和上一个时间步的预测结果作为输入,预测我们需要的内容。

PyTorch实现Seq2Seq机器翻译第1张

Seq2Seq在训练阶段和预测阶段稍有差异。如果Decoder第一个预测预测的输出就错了,它会导致“蝴蝶效应“,影响后面全部内容。为了解决这个问题,在训练时,Decoder每个时间步的输入不全是上一个时间步的输出,而以一定的概率选择真实值作为输入。

PyTorch实现Seq2Seq机器翻译第2张

通常,Encoder的输入序列需要添加一个终止符“<eos>”,可以不需要起始符“<sos>”。Decoder输入序列在训练时则需要添加一个起始符和终止符,在预测时,Decoder接收一个起始符“<sos>”,它类似一个信号,告诉Decoder可以开始工作了,当输出终止符时我们就可以停下来(通常可以再设置一个最大输出长度,防止Decoder一直不输出终止符)。

终止符和起始符只要不会出现在原始序列中就可以了,也可以用<start>和<stop>,<bos>和<eos>,<s>和</s>等等

Attention机制

这里介绍的是LuongAttention

整个输入序列的信息被Encoder编码为固定长度的向量,类似”有损压缩”。这个向量无法完全表达整个输入序列的信息。另外,随着输入长度的增加,这个固定长度的向量,会逐渐丢失更多信息。

以英中翻译任务为例,我们翻译的时候,虽然要考虑上下文,但每个时间步的输出,不同单词的贡献是不同的。考虑下面这个句子对:

She doesn't like soccer.

她不喜欢足球。

我们翻译“她”时,其实只需要考虑“She”就好了,“足球”也是同理。简单说,Attention机制让我们的输出时,关注输入序列中的某一些部位就可以了,即让输入的单词有不同的贡献。

PyTorch实现Seq2Seq机器翻译第3张

根据原始论文,我们定义以下符号:在每个时间步$t$,Decoder当前时间步的隐藏状态$h_t$,整个Encoder输出的隐藏状态$ar h_s$​,权重数值​$a_t$​,上下文向量​$c_t$。

注意力值通过以下方式计算:

$$
score(h_t,ar h_s)=
egin{cases}
h_t^Tar h_s & ext{dot} \
h_t^TW_aar h_s & ext{general} \
v_a^T anh (W_a[h_t;ar h_s]) & ext{concat}
end{cases}
$$

其中权重根据以下公式计算(其实就是用softmax归一化了)

$$
a_t(s)=align(h_t, ar h_s)=frac {exp(score(h_t, ar h_s))}{sum_{s'} exp(score(h_t, ar h_{s'}))}
$$

上下文向量根据权重,对​Encoder输出隐藏状态的每个时间步进行加权平均

$$
c_t=sum_s a_t(s) cdot ar h_s
$$

与Decoder当前时间步的隐藏状态拼接,计算一个注意力隐藏状态,其计算公式如下

$$
ilde h_t = anh (W_c[c_t;h_t])
$$

再根据这个注意力隐藏状态预测输出结果

$$
y = ext{softmax}(W_s ilde h_t)
$$

部分代码

参考了官方文档和github上的一些代码,使用Attention机制和不使用Attention机制的翻译器都实现了一下。这里只对使用了Attention机制的翻译器的部分代码进行说明,完整代码如下

https://gitee.com/dogecheng/python/blob/master/pytorch/Seq2SeqForTranslation.ipynb

在计算出注意力值后,Decoder将其与Encoder输出的隐藏状态进行加权平均,得到上下文向量context.

再将context与Decoder当前时间步的隐藏状态拼接,经过tanh。最后用softmax预测最终的输出概率。

class Decoder(nn.Module):
    def forward(self, token_inputs, last_hidden, encoder_outputs):
        ...
        # encoder_outputs = [input_lengths, batch, hid_dim * n directions]
        attn_weights = self.attn(gru_output, encoder_outputs)
        # attn_weights = [batch, 1, sql_len]
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # [batch, 1, hid_dim * n directions]

        gru_output = gru_output.squeeze(0) # [batch, n_directions * hid_dim]
        context = context.squeeze(1)       # [batch, n_directions * hid_dim]
        concat_input = torch.cat((gru_output, context), 1)  # [batch, n_directions * hid_dim * 2]
        concat_output = torch.tanh(self.concat(concat_input))  # [batch, n_directions*hid_dim]
        output = self.out(concat_output) # [batch, output_dim]
        output = self.softmax(output)
        ...

训练时,根据use_teacher_forcing设置的阈值,决定下一时间步的输入是上一时间步的预测结果还是来自数据的真实值

if self.predict:
    """
    预测代码
    """
    ...

else:
    max_target_length = max(target_lengths)
    all_decoder_outputs = torch.zeros((max_target_length, batch_size, self.decoder.output_dim), device=self.device)

    for t in range(max_target_length):
        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
        if use_teacher_forcing:
            # decoder_output = [batch, output_dim]
            # decoder_hidden = [n_layers*n_directions, batch, hid_dim]
            decoder_output, decoder_hidden, decoder_attn = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            all_decoder_outputs[t] = decoder_output
            decoder_input = target_batches[t]  # 下一个输入来自训练数据
        else:
            decoder_output, decoder_hidden, decoder_attn = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # [batch, 1]
            topv, topi = decoder_output.topk(1)
            all_decoder_outputs[t] = decoder_output
            decoder_input = topi.squeeze(1).detach()  # 下一个输入来自模型预测

损失函数通过使用设置ignore_index不计padding部分的损失

loss_fn = nn.NLLLoss(ignore_index=PAD_token)
loss = loss_fn(
    all_decoder_outputs.reshape(-1, self.decoder.output_dim),  # [batch*seq_len, output_dim]
    target_batches.reshape(-1)               # [batch*seq_len]
)

Seq2Seq在预测阶段每次只输入一个样本,输出其翻译结果,对应forward()函数中的内容如下,当Decoder输出终止符或输出长度达到所设定的阈值时便停止。

class Seq2Seq(nn.Module):
    ...
    def forward(self, input_batches, input_lengths, target_batches=None, target_lengths=None, teacher_forcing_ratio=0.5):
        ...
        if self.predict:
            # 一次只输入一句话
            assert batch_size == 1, "batch_size of predict phase must be 1!"
            output_tokens = []

            while True:
                decoder_output, decoder_hidden, decoder_attn = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs
                )
                # [1, 1]
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(1).detach()
                output_token = topi.squeeze().detach().item()
                if output_token == EOS_token or len(output_tokens) == self.max_len:
                    break
                output_tokens.append(output_token)
            return output_tokens

        else:
            """
            训练代码
            """
            ...

部分实验结果,具体可以在notebook里看

PyTorch实现Seq2Seq机器翻译第4张

参考资料

NLP FROM SCRATCH: TRANSLATION WITH A SEQUENCE TO SEQUENCE NETWORK AND ATTENTION

DEPLOYING A SEQ2SEQ MODEL WITH TORCHSCRIPT

Practical PyTorch: Translation with a Sequence to Sequence Network and Attention

1 - Sequence to Sequence Learning with Neural Networks

免责声明:文章转载自《PyTorch实现Seq2Seq机器翻译》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇vue 支持 超大上G,多附件上传laravel框架——验证码(第二种方法)下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

ISD9160学习笔记05_ISD9160语音识别代码分析

前言 语音识别是特别酷的功能,ISD9160的核心卖点就是这个语音识别,使用了Cybron VR 算法。 很好奇这颗10块钱以内的IC是如何实现人家百来块钱的方案。且听如下分析。 本文作者twowinter,转载请注明:http://blog.csdn.net/iotisan/ 功能分析 语音识别例程中做了21条语音识别模型,只要识别到对应的语音,就从串...

自然语言处理-中文语料预处理

自然语言处理——中文文本预处理 近期,在自学自然语言处理,初次接触NLP觉得十分的难,各种概念和算法,而且也没有很强的编程基础,学着稍微有点吃力。不过经过两个星期的学习,已经掌握了一些简单的中文、英文语料的预处理操作。写点笔记,记录一下学习的过程。 1、中文语料的特点   第一点:中文语料中词与词之间是紧密相连的,这一点不同与英文或者其它语种的语料,因此在...

nltk安装配置以及语料库的安装配置

一 nltk的安装   nltk的安装个人推荐使用pip安装 直接在pycharm的Termial中安装即可    其中 安装语句为 pip3 install nltk (如有python版本不同 可尝试pip install nltk)   此处我的已经安装过所以显示的是安装位置  在安装时如果很慢 可以使用其他的源路径 如 阿里云 :-i http:...

自然语言处理(nlp)比计算机视觉(cv)发展缓慢,而且更难!

https://mp.weixin.qq.com/s/kWw0xce4kdCx62AflY6AzQ 1.抢跑的nlp nlp发展的历史非常早,因为人从计算机发明开始,就有对语言处理的需求。各种字符串算法都贯穿于计算机的发展历史中。伟大的乔姆斯基提出了生成文法,人类拥有的处理语言的最基本框架,自动机(正则表达式),随机上下文无关分析树,字符串匹配算法KMP,...

斯坦福大学自然语言处理第五课“拼写纠错(Spelling Correction)”

一、课程介绍 斯坦福大学于2012年3月在Coursera启动了在线自然语言处理课程,由NLP领域大牛Dan Jurafsky 和 Chirs Manning教授授课:https://class.coursera.org/nlp/ 以下是本课程的学习笔记,以课程PPT/PDF为主,其他参考资料为辅,融入个人拓展、注解,抛砖引玉,欢迎大家在“我爱公开课”上...

词向量之word2vec实践

首先感谢无私分享的各位大神,文中很多内容多有借鉴之处。本次将自己的实验过程记录,希望能帮助有需要的同学。 一、从下载数据开始     现在的中文语料库不是特别丰富,我在之前的文章中略有整理,有兴趣的可以看看。本次实验使用wiki公开数据,下载地址如下:         wiki英文数据下载:https://dumps.wikimedia.org/enwik...