Commit 69b4bdad
Authored July 13, 2017 by Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Further update README.md and add function doc for mt_with_external_memory model.
Parent: 617bc4e5
3 changed files with 236 additions and 34 deletions (+236, -34):

- mt_with_external_memory/README.md (+174, -16)
- mt_with_external_memory/external_memory.py (+6, -12)
- mt_with_external_memory/model.py (+56, -6)
mt_with_external_memory/README.md
@@ -222,7 +222,7 @@ class ExternalMemory(object):
- Input argument `read_key`: the output of some layer, whose information is used for the read head's addressing.
- Returns: the read-out information (directly usable as input to other layers).
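For example (a hedged sketch; `decoder_state` stands for whatever layer supplies the addressing key):

```
# The read head addresses the memory with the key vector and returns the
# read result, which can feed directly into subsequent layers.
context_vector = external_memory.read(decoder_state)
```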
Some key implementation logic:
- The Neural Turing Machine's "external memory matrix" is implemented with `paddle.layer.memory` in sequence form (`is_seq=True`): the length of the sequence gives the number of memory slots, and the sequence's `size` gives the size of each memory slot (vector). The sequence is initialized by an external boot layer, and the number of memory slots is determined by the length of that layer's output sequence. The class can therefore implement not only bounded memory but also unbounded memory (i.e., a variable number of memory slots).
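A minimal sketch of declaring such a sequence-form memory in the Paddle v2 API (the boot layer here is a hypothetical stand-in; in the real model it is derived from the encoder):

```
import paddle.v2 as paddle

# Hypothetical boot layer: its output sequence initializes the memory matrix.
# Its sequence length fixes the number of slots; its width, the slot size.
memory_boot = paddle.layer.data(
    name="memory_boot",
    type=paddle.data_type.dense_vector_sequence(128))

# Inside a recurrent group, a sequence-form memory carries the whole memory
# matrix from one decoding step to the next.
external_memory = paddle.layer.memory(
    name="external_memory",  # must match the layer producing the updated matrix
    size=128,                # size of each memory-slot vector
    is_seq=True,             # sequence form: length == number of slots
    boot_layer=memory_boot)  # initializer; slot count follows its length
```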
@@ -244,21 +244,130 @@ class ExternalMemory(object):
Three main functions are involved:
```
def bidirectional_gru_encoder(input, size, word_vec_dim):
    """Bidirectional GRU encoder.

    :param input: Sequence of word ids.
    :type input: LayerOutput
    :param size: Hidden cell number in encoder rnn.
    :type size: int
    :param word_vec_dim: Word embedding size.
    :type word_vec_dim: int
    :return: Tuple of 1. concatenated forward and backward hidden sequence.
             2. last state of backward rnn.
    :rtype: tuple of LayerOutput
    """
    pass


def memory_enhanced_decoder(input, target, initial_state, source_context, size,
                            word_vec_dim, dict_size, is_generating, beam_size):
    """GRU sequence decoder enhanced with external memory.

    The "external memory" refers to two types of memories.
    - Unbounded memory: i.e. attention mechanism in Seq2Seq.
    - Bounded memory: i.e. external memory in NTM.
    Both types of external memories can be implemented with the ExternalMemory
    class, and are both exploited in this enhanced RNN decoder.

    The vanilla RNN/LSTM/GRU also has a narrow memory mechanism, namely the
    hidden state vector (or cell state in LSTM) carrying information through
    a span of sequence time, which is a successful design enriching the model
    with the capability to "remember" things in the long run. However, such a
    vector state is somewhat limited to a very narrow memory bandwidth. The
    external memory introduced here could easily increase the memory capacity
    with linear complexity cost (rather than quadratic for vector state).

    This enhanced decoder expands its "memory passage" through two
    ExternalMemory objects:
    - Bounded memory for handling long-term information exchange within the
      decoder itself. A direct expansion of the traditional "vector" state.
    - Unbounded memory for handling the source language's token-wise
      information. Exactly the attention mechanism over Seq2Seq.

    Notice that we take the attention mechanism as a particular form of
    external memory, with a read-only memory bank initialized with encoder
    states, and a read head with content-based addressing (attention). From
    this view point, we arrive at a better understanding of the attention
    mechanism itself and of other external memories, as well as a concise and
    unified implementation for them.

    For more details about external memory, please refer to
    `Neural Turing Machines <https://arxiv.org/abs/1410.5401>`_.

    For more details about this memory-enhanced decoder, please refer to
    `Memory-enhanced Decoder for Neural Machine Translation
    <https://arxiv.org/abs/1606.02003>`_. This implementation is highly
    correlated to this paper, but with minor differences (e.g. put "write"
    before "read" to bypass a potential bug in V2 APIs. See
    `issue <https://github.com/PaddlePaddle/Paddle/issues/2061>`_).
    """
    pass


def memory_enhanced_seq2seq(encoder_input, decoder_input, decoder_target,
                            hidden_size, word_vec_dim, dict_size, is_generating,
                            beam_size):
    """Seq2Seq Model enhanced with external memory.

    The "external memory" refers to two types of memories.
    - Unbounded memory: i.e. attention mechanism in Seq2Seq.
    - Bounded memory: i.e. external memory in NTM.
    Both types of external memories can be implemented with the ExternalMemory
    class, and are both exploited in this Seq2Seq model.

    :param encoder_input: Encoder input.
    :type encoder_input: LayerOutput
    :param decoder_input: Decoder input.
    :type decoder_input: LayerOutput
    :param decoder_target: Decoder target.
    :type decoder_target: LayerOutput
    :param hidden_size: Hidden cell number, both in encoder and decoder rnn.
    :type hidden_size: int
    :param word_vec_dim: Word embedding size.
    :type word_vec_dim: int
    :param dict_size: Vocabulary size.
    :type dict_size: int
    :param is_generating: Whether for beam search inference (True) or
                          for training (False).
    :type is_generating: bool
    :param beam_size: Beam search width.
    :type beam_size: int
    :return: Cost layer if is_generating=False; beam search layer if
             is_generating=True.
    :rtype: LayerOutput
    """
    pass
```
`memory_enhanced_seq2seq` defines the entire sequence-to-sequence model with external memory and is the top-level function of the model definition. It first encodes the source language with `bidirectional_gru_encoder`, then decodes with `memory_enhanced_decoder`.

- `bidirectional_gru_encoder` implements a single-layer bidirectional GRU (Gated Recurrent Unit) encoder. It returns two results: a token-level sequence of encoding vectors (forward and backward concatenated), and a sentence-level encoding vector for the whole source sentence (backward only). The former initializes the memory matrix used by the decoder's attention mechanism; the latter initializes the decoder's state vector.
- `memory_enhanced_decoder` implements the GRU decoder enhanced with external memory. It uses the same `ExternalMemory` class to realize both kinds of external memory module:
  - Unbounded external memory, i.e. the traditional attention mechanism: an `ExternalMemory` with the read-only switch on and interpolation addressing off, whose memory matrix is initialized (`boot_layer`) with the first group of encoder outputs. The number of memory slots is therefore dynamic, determined by the number of source tokens:
```
unbounded_memory = ExternalMemory(
    name="unbounded_memory",
    mem_slot_size=size * 2,
    boot_layer=unbounded_memory_init,
    readonly=True,
    enable_interpolation=False)
```
  - Bounded external memory: an `ExternalMemory` with the read-only switch off and interpolation addressing on. Its memory matrix is initialized (`boot_layer`) with the first group of encoder outputs, mean-pooled, expanded to the specified sequence length, and perturbed with random noise (kept identical between training and inference). The number of memory slots is therefore fixed. In the code:
```
bounded_memory = ExternalMemory(
    name="bounded_memory",
    mem_slot_size=size,
    boot_layer=bounded_memory_init,
    readonly=False,
    enable_interpolation=True)
```
Note that in this implementation, the attention mechanism (the unbounded external memory) and the Neural Turing Machine (the bounded external memory) are realized by the same `ExternalMemory` class: the former is **read-only**, while the latter is **readable and writable**. This is done purely to unify our understanding of the "attention mechanism" and the "memory mechanism", and it also yields a more concise and unified implementation. The attention mechanism could equally be implemented with `paddle.networks.simple_attention`.

In addition, this implementation moves the `write` operation of `ExternalMemory` ahead of `read` to sidestep a potential topology-connection limitation; see this [Issue](https://github.com/PaddlePaddle/Paddle/issues/2061). The two orderings are essentially equivalent.
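A minimal sketch of how the two memories might be used inside one decoder step under this write-before-read ordering (`decoder_state` is a hypothetical layer supplying the addressing keys; the two `ExternalMemory` objects are the ones constructed above):

```
# Write first, then read, to avoid the topology limitation noted above.
bounded_memory.write(decoder_state)
bounded_memory_read = bounded_memory.read(decoder_state)
# The unbounded memory is read-only: reading it is exactly attention.
context = unbounded_memory.read(decoder_state)
```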
@@ -278,24 +387,73 @@ def reader():
Users must tokenize the text themselves and build the dictionaries that map tokens to ids.

The PaddlePaddle interface [paddle.paddle.wmt14](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/dataset/wmt14.py) provides by default a preprocessed, smaller-scale [subset of the WMT14 English-French dataset](http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz) (193319 training pairs, 6003 test pairs, dictionary size 30000), along with two reader creator functions:
```
paddle.dataset.wmt14.train(dict_size)
paddle.dataset.wmt14.test(dict_size)
```
When called, these functions return the corresponding `reader()` functions for `paddle.trainer.SGD.train` to consume. To use other data, refer to [paddle.paddle.wmt14](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/dataset/wmt14.py) to construct the corresponding data creators, and substitute their names for `paddle.dataset.wmt14.train` and `paddle.dataset.wmt14.test`.
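For instance, feeding the training reader to the trainer might look like the following sketch (the `trainer` object and `dict_size` are assumed to be set up elsewhere, as in `train.py`):

```
trainer.train(
    reader=paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(dict_size), buf_size=8192),
        batch_size=128),
    num_passes=100)
```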
### Training

Run from the command line:

```
python train.py
```

or customize some of the arguments, e.g.:
```
CUDA_VISIBLE_DEVICES=8,9,10,11 python train.py \
    --dict_size 30000 \
    --word_vec_dim 512 \
    --hidden_size 1024 \
    --memory_slot_num 8 \
    --use_gpu True \
    --trainer_count 4 \
    --num_passes 100 \
    --batch_size 128 \
    --memory_perturb_stddev 0.1
```
to start training. Model checkpoints are saved periodically under the local `./checkpoints` directory. For the meaning of each argument, run
```
python train.py --help
```
### Decoding

Run from the command line:
```
python infer.py
```
or customize some of the arguments, e.g.:
```
CUDA_VISIBLE_DEVICES=8,9,10,11 python infer.py \
    --dict_size 30000 \
    --word_vec_dim 512 \
    --hidden_size 1024 \
    --memory_slot_num 8 \
    --use_gpu True \
    --trainer_count 4 \
    --memory_perturb_stddev 0.1 \
    --infer_num_data 10 \
    --model_filepath checkpoints/params.latest.tar.gz \
    --beam_size 3
```
to run the decoding script and produce sample translations. For the meaning of each argument, run:
```
python infer.py --help
```
## Further Discussion
mt_with_external_memory/external_memory.py
@@ -5,8 +5,7 @@ import paddle.v2 as paddle
class ExternalMemory(object):
    """External neural memory class.

    A simplified Neural Turing Machine (NTM) with only content-based
    addressing (including content addressing and interpolation, but excluding
@@ -76,8 +75,7 @@ class ExternalMemory(object):
            size=self.mem_slot_size)
    def _content_addressing(self, key_vector):
        """Get write/read head's addressing weights via content-based addressing.
        """
        # content-based addressing: a=tanh(W*M + U*key)
        key_projection = paddle.layer.fc(
@@ -104,8 +102,7 @@ class ExternalMemory(object):
        return addressing_weight
    def _interpolation(self, head_name, key_vector, addressing_weight):
        """Interpolate between previous and current addressing weights.
        """
        # prepare interpolation scalar gate: g=sigmoid(W*key)
        gate = paddle.layer.fc(
@@ -126,8 +123,7 @@ class ExternalMemory(object):
        return interpolated_weight
    def _get_addressing_weight(self, head_name, key_vector):
        """Get final addressing weights for read/write heads, including content
        addressing and interpolation.
        """
        # current content-based addressing
@@ -139,8 +135,7 @@ class ExternalMemory(object):
        return addressing_weight
    def write(self, write_key):
        """Write onto the external memory.

        It cannot be called if "readonly" is set to True.

        :param write_key: Key vector for write heads to generate writing
@@ -183,8 +178,7 @@ class ExternalMemory(object):
            name=self.name)
    def read(self, read_key):
        """Read from the external memory.

        :param read_key: Key vector for read head to generate addressing
                         signals.
mt_with_external_memory/model.py
@@ -20,8 +20,15 @@ from external_memory import ExternalMemory
def bidirectional_gru_encoder(input, size, word_vec_dim):
    """Bidirectional GRU encoder.

    :param input: Sequence of word ids.
    :type input: LayerOutput
    :param size: Hidden cell number in encoder rnn.
    :type size: int
    :param word_vec_dim: Word embedding size.
    :type word_vec_dim: int
    :return: Tuple of 1. concatenated forward and backward hidden sequence.
             2. last state of backward rnn.
    :rtype: tuple of LayerOutput
    """
    # token embedding
    embeddings = paddle.layer.embedding(input=input, size=word_vec_dim)
@@ -38,8 +45,7 @@ def bidirectional_gru_encoder(input, size, word_vec_dim):
def memory_enhanced_decoder(input, target, initial_state, source_context, size,
                            word_vec_dim, dict_size, is_generating, beam_size):
    """GRU sequence decoder enhanced with external memory.

    The "external memory" refers to two types of memories.
    - Unbounded memory: i.e. attention mechanism in Seq2Seq.
@@ -77,6 +83,30 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
    correlated to this paper, but with minor differences (e.g. put "write"
    before "read" to bypass a potential bug in V2 APIs. See
    `issue <https://github.com/PaddlePaddle/Paddle/issues/2061>`_).
    :param input: Decoder input.
    :type input: LayerOutput
    :param target: Decoder target.
    :type target: LayerOutput
    :param initial_state: Initial hidden state.
    :type initial_state: LayerOutput
    :param source_context: Group of context hidden states for each token in the
                           source sentence, for the attention mechanism.
    :type source_context: LayerOutput
    :param size: Hidden cell number in decoder rnn.
    :type size: int
    :param word_vec_dim: Word embedding size.
    :type word_vec_dim: int
    :param dict_size: Vocabulary size.
    :type dict_size: int
    :param is_generating: Whether for beam search inference (True) or
                          for training (False).
    :type is_generating: bool
    :param beam_size: Beam search width.
    :type beam_size: int
    :return: Cost layer if is_generating=False; beam search layer if
             is_generating=True.
    :rtype: LayerOutput
    """
    # prepare initial bounded and unbounded memory
    bounded_memory_slot_init = paddle.layer.fc(
@@ -172,8 +202,7 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
def memory_enhanced_seq2seq(encoder_input, decoder_input, decoder_target,
                            hidden_size, word_vec_dim, dict_size, is_generating,
                            beam_size):
    """Seq2Seq Model enhanced with external memory.

    The "external memory" refers to two types of memories.
    - Unbounded memory: i.e. attention mechanism in Seq2Seq.
@@ -189,6 +218,27 @@ def memory_enhanced_seq2seq(encoder_input, decoder_input, decoder_target,
    For more details about this memory-enhanced Seq2Seq, please
    refer to `Memory-enhanced Decoder for Neural Machine Translation
    <https://arxiv.org/abs/1606.02003>`_.
    :param encoder_input: Encoder input.
    :type encoder_input: LayerOutput
    :param decoder_input: Decoder input.
    :type decoder_input: LayerOutput
    :param decoder_target: Decoder target.
    :type decoder_target: LayerOutput
    :param hidden_size: Hidden cell number, both in encoder and decoder rnn.
    :type hidden_size: int
    :param word_vec_dim: Word embedding size.
    :type word_vec_dim: int
    :param dict_size: Vocabulary size.
    :type dict_size: int
    :param is_generating: Whether for beam search inference (True) or
                          for training (False).
    :type is_generating: bool
    :param beam_size: Beam search width.
    :type beam_size: int
    :return: Cost layer if is_generating=False; beam search layer if
             is_generating=True.
    :rtype: LayerOutput
    """
    # encoder
    context_encodings, sequence_encoding = bidirectional_gru_encoder(