[译] Understanding Incremental Decoding in fairseq

Category 碎碎念

近来一直在使用 fairseq 做项目,因为其功能较多而源码也比较复杂,光靠官方文档也难以完全理解。ankur6ue 的一篇文章对 fairseq 中的增量解码(Incremental Decoding)操作做了详尽的介绍,于是我节选了其中的一部分,将其译为中文,希望对和我一样在读源码的朋友有所帮助。

推理过程中的增量解码

在语言翻译任务的推理过程,解码器逐步地输出目标语言词汇的概率分布。最简单的翻译算法通过贪心策略直接选择概率最高的目标词,这个方法和训练过程中计算损失函数的方式一致。另一种方法则是保存所有可能的目标序列,再从中选出最小化对数似然的的结果。但这种方法需要基于似然搜索所有的可能序列,同时由于词表大小通常在数百上千的规模,导致计算开销随着序列长度指数上升。

集束搜索(Beam Search)在两种极端策略间取得了平衡,由于网络上已经有许多非常好的教程,本文不在此展开介绍集束搜索的工作原理。因为集束搜索在每个步骤只考虑 \(B\) 个前缀序列,搜索空间就由 \(V\times V\) 下降至 \(B\times V\)(其中 \(B\) 为集束宽度,\(V\) 为目标词表的大小),所以相比暴力搜索的方法显著高效。解码结果缺乏多样性是集束搜索的一项缺陷,因为一条输入序列可能具有多条正确翻译,这项缺陷就会影响翻译任务。针对该问题也提出了许多解决方法,例如 Diverse Beam Search 在标准的集束分数中添加了一个差异项,通过对先前步骤已使用过的词施加惩罚并使用 top-k 随机取样在下一步的生成中随机选出前 k 个最有可能的备选项(代替了集束搜索中永远选择前 \(B\) 个备选项的取样方式),从而产生更多样的结果。

尽管集束搜索比蛮力搜索更高效,但由于每个步骤都要重新计算所有前缀 token(prefix tokens),其计算开销会随着解码序列长度的增加而线性增长。

n

在这个例子中,用 A、B、C 等字母表示 token,在推理时,集束会扩展成由翻译出句子所构成的 batch。如果输入 batch 由 2 个句子构成,同时将集束宽度设置为 3,最终得到 batch 的大小就为 6。在计算过程中,每个集束都作为 batch 中的元素并行计算。

当模型完成前面一部分的 token 的解码计算后,我们就会思考:是否可以重复利用这些计算结果?其实增量解码正是这个思路的实现。增量解码使用名为增量状态(incremental state)的数据结构保存先前计算结果,用于后续的卷积计算。在每个计算步骤中,解码器只需对当前 token 做计算,若是模型中的某些层需要先前 token 的信息(例如卷积层),则从增量状态中取出所需结果。而在编码器的计算过程中,编码器与解码出的目标序列无关,它只在一开始时计算输入序列并产生每个输入字词的编码,这些编码本就会被解码器重复使用。

增量解码如何节省计算开销?

增量解码的具体实现稍有些复杂,希望下图能够帮助读者更好地理解整个过程。我推荐读者尝试使用 Python Debugger 在以下代码中设置断点,相信能够更容易理解每一步所做的操作。

n

 在第(1)步中,input_buffer 中每个卷积模块对应的值都是 None,其内存大小由 beam_size (3)、conv_kernel_width (3) 和 conv_kernel_input_dimension (512) 分配,初始化为 0。

我们假设输入 1 条句子,那么 batch 的大小与集束宽度相等(该样例中为 3)。首先,每个集束都由开始标记构成(BOS),因而输入卷积层的嵌入向量相等。

n

 在完成创建和初始化后,input_buffer 如上图所示

然后 input_buffer 左移 1 个位置并将输入添加到最后一列。由于 input_buffer 全由 0 填充,左移后并没有明显的变化,不过我们在下一步就会看到它的作用。

n

 在最右侧填入输入列后的 incremental_state 如上图所示,因为在第(1)步中的所有输入 token 都是 BOS token,填入每个集束的输入向量也都相同。

接着,输入数据与卷积滤波器做计算,将计算结果传递给后续层——即 GLU 和注意力层。

n

现在我们考虑下一个线性卷积层模块,进入该层的输入来源于前一个模块。正如前一个部分中的卷积层,该卷积层也有自己的增量状态,同样由 0 初始化并填入输入数据。

n

n

和先前一样,输入数据与卷积核做计算并传入 GLU 和注意力层,解码器中的所有结构都重复这一过程,最后的输出就为每个集束的词表概率分布向量。集束搜索算法从该结果中得到最优的 token,用于下一步的计算,在这里用 A、B、C 表示解码结果。

接下来我们考虑在步骤(2)所做的操作,我们目前的集束为

n

再一次重复第一层的卷积操作,由于每个集束的输入 token 不同(A、B 和 C),其嵌入向量也不再相等。此外,由于步骤(1)中的初始化过程,input_buffer 也不再全为 0。

n

接着,input_buffer 左移并将新的输入添加至最后一列。

n

每个集束的 input_buffer 与卷积核做计算后传入到后续层中,与先前步骤中相同。

n

整个过程中,input_buffer 作为内存记录先前步骤给出的结果,用于计算卷积结果。input_buffer 同时节省了计算开销,这一点可以在下一个卷积操作中看出。

n

由于 input_buffer 中保存了先前步骤的输入,在当前步骤中可以直接用于完成卷积计算。最重要的是,先前步骤的输入是再前一个步骤的计算结果,保存的计算结果就避免了解码过程中的重复计算,从而节省计算开销。后续过程与前文所述一样,左移并填入输入,完成卷积计算。

n

为什么需要重排增量状态?

在每一步开始前,generator 都会重新排列解码器和编码器的 incremental_state

n

这是由于集束搜索会导致每个集束中前缀 token 的顺序发生改变,通过一个简单的样例描述这个过程。假设下图是步骤(2)得到的集束状态,其中展示了每个步骤得到的 token 和预测分数。

n

到了步骤(3)时,通过预测分数得到 N、P 和 S,箭头表示了每个结果 token 的来源。

n

(译者案:作者在这里描述得不是很清晰,需要额外补充一些说明。集束搜索过程如下图示意,其主要操作是在每一步中只取 top-k 个预测分数最高的结果作为下一步的前缀 token,其他分支中止不再计算。在该例中,输入 BOS token 后,在若干结果中取 top-3 分数最高的结果,分别为 A、B 和 C。那么下一步的输入就为 [BOS A]、[BOS B] 和 [BOS C],再取 top-3 分数最高的结果。由 A 生成的结果分数无法位于 top-3,A token 所属的分支就被中止,后续不会再计算,在 buffer 中存储其状态也是无用的了,因此要将其替换为有效的前缀 token。)

n

于是就如下图所示,重新排列每个集束。

n

当对当前 token N、P 和 S 执行解码操作(预测下一个 token)时,我们必须重排 incremental_state 使卷积操作能够使用正确的前缀 token。这个操作可能不能马上明白,需要花些时间仔细理解。

另外还有一点,fairseq 的代码也重新排列了编码器的状态,然而由于编码器状态只取决于输入 token,并不会随集束状态改变,其重排也就不是必需的,至少在本文的例子中不需要这样的操作。

为什么集束搜索返回的 token 数量是集束数量的两倍?

在 fairseq 中,集束搜索返回输出 token 的数量是集束数量的两倍。这是由于集束搜索中的部分集束可能会返回表示句子结束的 EOS token,而我们不想要集束搜索太早就停止。当 EOS token 出现在结果的前半部分时,可以将预测总分与其他已有结果的分数相比较从而完成句子。下图展示了相关代码并附上了一些注释,希望能有助读者理解。

n

 (a)返回表示集束中具有 EOS token 的掩码;(b)具有 EOS token 集束对应的索引,只有在 EOS 出现在前半部分的情况下(:beam_size)。注意集束搜索返回 2 * beam_size 个结果;(c)对于前 beam_size 个具有 EOS 的集束,组合预测结果并判断是否完成句子。如果是,减少剩余句子的数量,注意我们处理的是一整个 batch 的输入句子;(d)如果剩余句子的数量是 0,完成;(e)如果能够完成一整个 batch 的目标句子,从 batch 中移除元素并调整 batch 索引。