指针网络 Pointer Network

传统的 Seq2Seq 模型中 Decoder 输出的目标数量是固定的,例如翻译时 Decoder 预测的目标数量等于字典的大小。这导致 Seq2Seq 不能用于一些组合优化的问题,例如凸包问题,三角剖分,旅行商问题 (TSP) 等。Pointer Network 可以解决输出字典大小可变的问题,Pointer Network 的输出字典大小等于 Encoder 输入序列的长度并修改了 Attention 的方法,根据 Attention 的值从 Encoder 的输入中选择一个作为 Decoder 的输出。

1.Pointer Network

Seq2Seq 模型是一种包含 Encoder 和 Decoder 的模型,可以将一个序列转成另外一个序列。但是 Seq2Seq 模型的预测输出目标大小是固定的,对于一些输出目标大小会变的情况,例如很多组合优化问题。

组合优化问题的输出目标的数量依赖于输入序列的长度,例如旅行商问题中包含5个城市 (1, 2, 3, 4, 5),输出预测的时候目标数量为 5。Pointer Network 改变了传统 Attention 的方式,从而可以用于这些组合优化的问题,Pointer Network 在预测输出时会根据 Attention 得到输入序列中每一个城市的概率 (即输出从输入中选择)。

传统 Attention

传统的 Attention 会根据 Attention 值融合 Encoder 的每一个时刻的输出,然后和 Decoder 当前时刻的输出混在一起再预测输出。如下面的公式所示,e 表示 Encoder 的输出,d 表示 Decoder 的输出,Wv 都是可以学习的参数。

指针网络 Pointer Network

Seq2Seq 的 attention

Pointer Network 的 Attention

Pointer Network 计算 Attention 值之后不会把 Encoder 的输出融合,而是将 Attention 作为输入序列 P 中每一个位置输出的概率。

指针网络 Pointer Network

Pointer Network 的 Attention

Pointer Network 和 Seq2Seq 的区别如下图所示,图中展示了凸包问题。Seq2Seq 的 Decoder 会预测每一个位置的输出 (但是输出目标的数量是固定的),而 Pointer Network 的 Decoder 直接根据 Attention 得到输入序列中每一个位置的概率,取概率最大的输入位置作为当前输出。

指针网络 Pointer Network

Seq2Seq 和 Pointer Network

2.实验结果

指针网络 Pointer Network

图中底部的都是 Pointer Network 的实验结果,m 是训练数据中点的个数,n 是测试数据中点的个数。图 (a) 中是使用 LSTM 的 Seq2Seq,Seq2Seq 训练和测试必须使用相同点的个数。

3.参考文献

Pointer Networks


分享到:


相關文章: