diff --git a/08.machine_translation/README.cn.md b/08.machine_translation/README.cn.md
index 1244f94f66449cdf3c78470810c3f736f8ae8e03..d0186286fd7f08ca9da9c4aa04cd94642a2bcd74 100644
--- a/08.machine_translation/README.cn.md
+++ b/08.machine_translation/README.cn.md
@@ -54,9 +54,9 @@
### 编码器-解码器框架
编码器-解码器(Encoder-Decoder)\[[2](#参考文献)\]框架用于解决由一个任意长度的源序列到另一个任意长度的目标序列的变换问题。即编码阶段将整个源序列编码成一个向量,解码阶段通过最大化预测序列概率,从中解码出整个目标序列。编码和解码的过程通常都使用RNN实现。
-![encoder_decoder](./image/encoder_decoder.png)
+
-
+
图3. 编码器-解码器框架
@@ -82,9 +82,9 @@
机器翻译任务的训练过程中,解码阶段的目标是最大化下一个正确的目标语言词的概率。思路是:
1. 每一个时刻,根据源语言句子的编码信息(又叫上下文向量,context vector)$c$、真实目标语言序列的第$i$个词$u_i$和$i$时刻RNN的隐层状态$z_i$,计算出下一个隐层状态$z_{i+1}$。计算公式如下:
$$z_{i+1}=\phi_{\theta '} \left ( c,u_i,z_i \right )$$
-其中$\phi _{\theta '}$是一个非线性激活函数;$c=q\mathbf{h}$是源语言句子的上下文向量,在不使用注意力机制时,如果[编码器](#编码器)的输出是源语言句子编码后的最后一个元素,则可以定义$c=h_T$;$u_i$是目标语言序列的第$i$个单词,$u_0$是目标语言序列的开始标记``,表示解码开始;$z_i$是$i$时刻解码RNN的隐层状态,$z_0$是一个全零的向量。
+其中$\phi _{\theta '}$是一个非线性激活函数;$c$是源语言句子的上下文向量,在不使用注意力机制时,如果[编码器](#编码器)的输出是源语言句子编码后的最后一个元素,则可以定义$c=h_T$;$u_i$是目标语言序列的第$i$个单词,$u_0$是目标语言序列的开始标记``,表示解码开始;$z_i$是$i$时刻解码RNN的隐层状态,$z_0$是一个全零的向量。
-2. 将$z_{i+1}$通过`softmax`归一化,得到目标语言序列的第$i+1$个单词的概率分布$p_{i+1}$。概率分布公式如下:
+1. 将$z_{i+1}$通过`softmax`归一化,得到目标语言序列的第$i+1$个单词的概率分布$p_{i+1}$。概率分布公式如下:
$$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$
其中$W_sz_{i+1}+b_z$是对每个可能的输出单词进行打分,再用softmax归一化就可以得到第$i+1$个词的概率$p_{i+1}$。
@@ -136,185 +136,215 @@ $$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$
## 模型配置说明
-下面我们开始根据输入数据的形式配置模型。首先引入所需的库函数以及定义全局变量。
+下面我们开始根据输入数据的形式配置模型。
+
+首先定义用到的全局变量:
```python
-from __future__ import print_function
-import contextlib
-
-import numpy as np
-import paddle
-import paddle.fluid as fluid
-import paddle.fluid.framework as framework
-import paddle.fluid.layers as pd
-from paddle.fluid.executor import Executor
-from functools import partial
-import os
-try:
- from paddle.fluid.contrib.trainer import *
- from paddle.fluid.contrib.inferencer import *
-except ImportError:
- print(
- "In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
- file=sys.stderr)
- from paddle.fluid.trainer import *
- from paddle.fluid.inferencer import *
-
-dict_size = 30000
-source_dict_dim = target_dict_dim = dict_size
-hidden_dim = 32
-word_dim = 16
-batch_size = 2
-max_length = 8
-topk_size = 50
-beam_size = 2
-
-decoder_size = hidden_dim
+dict_size = 30000 # 字典维度
+source_dict_dim = target_dict_dim = dict_size # 源/目标语言字典维度
+word_dim = 16 # 词向量维度
+hidden_dim = 32 # 编码器中的GRU隐层大小
+decoder_size = hidden_dim # 解码器中的GRU隐层大小
+max_length = 8 # 生成句子的最大长度
+beam_size = 2 # 柱宽度
```
-然后如下实现编码器框架:
+然后如下实现编码器框架,包括以下内容:
- ```python
- def encoder(is_sparse):
- src_word_id = pd.data(
+- 定义源语言id序列的输入数据
+
+```python
+src_word_id = pd.data(
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
- src_embedding = pd.embedding(
- input=src_word_id,
- size=[dict_size, word_dim],
- dtype='float32',
- is_sparse=is_sparse,
- param_attr=fluid.ParamAttr(name='vemb'))
-
- fc1 = pd.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
- lstm_hidden0, lstm_0 = pd.dynamic_lstm(input=fc1, size=hidden_dim * 4)
- encoder_out = pd.sequence_last_step(input=lstm_hidden0)
- return encoder_out
- ```
-
-再实现训练模式下的解码器:
+```
+- 将上述编码映射到低维语言空间的词向量
+```python
+src_embedding = pd.embedding(
+ input=src_word_id,
+ size=[source_dict_dim, word_dim],
+ dtype='float32',
+ is_sparse=is_sparse)
+```
+- 用双向GRU编码源语言序列,拼接两个GRU的编码结果得到$\mathbf{h}$
+
```python
- def train_decoder(context, is_sparse):
+fc_forward = pd.fc(
+ input=src_embedding, size=hidden_dim * 3, bias_attr=False)
+src_forward = pd.dynamic_gru(input=fc_forward, size=hidden_dim)
+fc_backward = pd.fc(
+ input=src_embedding, size=hidden_dim * 3, bias_attr=False)
+src_backward = pd.dynamic_gru(
+ input=fc_backward, size=hidden_dim, is_reverse=True)
+encoded_vector = pd.concat(input=[src_forward, src_backward], axis=1)
+```
+
+接着定义解码器框架,这里以不使用注意力机制的解码器为例,对训练模式和生成模式的解码器分别说明。
+- 训练模式下的解码器
+
+1. 取源语言序列编码后的最后一个状态,并过一个前馈神经网络得到其映射
+ ```python
+ encoder_last = pd.sequence_last_step(input=encoder_out)
+ encoder_last_projected = pd.fc(
+ input=encoder_last, size=decoder_size, act='tanh')
+ ```
+2. 定义目标语言id序列的输入数据,并映射到低维语言空间的词向量
+ ```python
trg_language_word = pd.data(
- name="target_language_word", shape=[1], dtype='int64', lod_level=1)
+ name="trg_word_id", shape=[1], dtype='int64', lod_level=1)
trg_embedding = pd.embedding(
input=trg_language_word,
- size=[dict_size, word_dim],
+ size=[target_dict_dim, word_dim],
dtype='float32',
- is_sparse=is_sparse,
- param_attr=fluid.ParamAttr(name='vemb'))
-
+ is_sparse=is_sparse)
+ ```
+3. 使用 `DynamicRNN` 定义每一步的计算,包括以下内容:
+ - 获取当前步目标语言输入的词向量
+ - 获取源语言句子的上下文向量
+ - 获取隐层状态
+ - 定义GRU计算单元
+ - 计算归一化的单词预测概率
+ - 更新RNN的隐层状态
+ - 输出预测概率
+ ```python
rnn = pd.DynamicRNN()
with rnn.block():
+ # 当前步目标语言输入的词向量
current_word = rnn.step_input(trg_embedding)
- pre_state = rnn.memory(init=context)
- current_state = pd.fc(input=[current_word, pre_state],
- size=decoder_size,
- act='tanh')
-
- current_score = pd.fc(input=current_state,
- size=target_dict_dim,
- act='softmax')
+ # 源语言句子的上下文向量
+ context = rnn.static_input(encoder_last)
+ # 隐层状态,初始为 encoder_last_projected
+ pre_state = rnn.memory(
+ init=encoder_last_projected, size=decoder_size, need_reorder=True)
+ # gru计算单元:fc + gru_unit
+ decoder_inputs = pd.fc(
+ input=[current_word, context],
+ size=decoder_size * 3,
+ bias_attr=False)
+ current_state = pd.gru_unit(
+ input=decoder_inputs, hidden=pre_state, size=decoder_size)
+ # 计算归一化的单词预测概率
+ current_score = pd.fc(
+ input=current_state, size=target_dict_dim, act='softmax')
+ # 更新RNN的隐层状态
rnn.update_memory(pre_state, current_state)
+ # 输出预测概率
rnn.output(current_score)
return rnn()
-```
-
-实现推测模式下的解码器:
-
-```python
-def decode(context, is_sparse):
- init_state = context
+ ```
+
+- 生成模式下的解码器:
+
+1. 取源语言序列编码后的最后一个状态,并过一个前馈神经网络得到其映射
+ ```python
+ encoder_last = pd.sequence_last_step(input=encoder_out)
+ encoder_last_projected = pd.fc(
+ input=encoder_last, size=decoder_size, act='tanh')
+ ```
+2. 定义解码过程循环计数变量
+ ```python
array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length)
counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True)
-
- # fill the first element with init_state
- state_array = pd.create_array('float32')
- pd.array_write(init_state, array=state_array, i=counter)
-
- # ids, scores as memory
+ ```
+3. 定义 tensor array 用以保存各个时间步的内容,并写入初始id和score
+ ```python
+ # 用以保存每一步beam search结果id和对应score的数组
ids_array = pd.create_array('int64')
- scores_array = pd.create_array('float32')
-
- init_ids = pd.data(name="init_ids", shape=[1], dtype="int64", lod_level=2)
- init_scores = pd.data(
- name="init_scores", shape=[1], dtype="float32", lod_level=2)
-
pd.array_write(init_ids, array=ids_array, i=counter)
+ scores_array = pd.create_array('float32')
pd.array_write(init_scores, array=scores_array, i=counter)
- cond = pd.less_than(x=counter, y=array_len)
-
+ # 用以保存每一步states和context的数组
+ state_array = pd.create_array('float32')
+ pd.array_write(encoder_last_projected, array=state_array, i=counter)
+ context_array = pd.create_array('float32')
+ pd.array_write(encoder_last, array=state_array, i=counter)
+ ```
+4. 定义 `while_op` 和循环终止条件变量
+ ```python
+ # 循环终止条件变量
+ cond = pd.less_than(x=counter, y=max_len)
while_op = pd.While(cond=cond)
- with while_op.block():
- pre_ids = pd.array_read(array=ids_array, i=counter)
- pre_state = pd.array_read(array=state_array, i=counter)
- pre_score = pd.array_read(array=scores_array, i=counter)
-
- # expand the lod of pre_state to be the same with pre_score
- pre_state_expanded = pd.sequence_expand(pre_state, pre_score)
-
- pre_ids_emb = pd.embedding(
- input=pre_ids,
- size=[dict_size, word_dim],
- dtype='float32',
- is_sparse=is_sparse)
-
- # use rnn unit to update rnn
- current_state = pd.fc(input=[pre_state_expanded, pre_ids_emb],
- size=decoder_size,
- act='tanh')
- current_state_with_lod = pd.lod_reset(x=current_state, y=pre_score)
- # use score to do beam search
- current_score = pd.fc(input=current_state_with_lod,
- size=target_dict_dim,
- act='softmax')
- topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
- # calculate accumulated scores after topk to reduce computation cost
- accu_scores = pd.elementwise_add(
- x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
- selected_ids, selected_scores = pd.beam_search(
- pre_ids,
- pre_score,
- topk_indices,
- accu_scores,
- beam_size,
- end_id=10,
- level=0)
-
- pd.increment(x=counter, value=1, in_place=True)
-
- # update the memories
- pd.array_write(current_state, array=state_array, i=counter)
- pd.array_write(selected_ids, array=ids_array, i=counter)
- pd.array_write(selected_scores, array=scores_array, i=counter)
-
- # update the break condition: up to the max length or all candidates of
- # source sentences have ended.
- length_cond = pd.less_than(x=counter, y=array_len)
- finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
- pd.logical_and(x=length_cond, y=finish_cond, out=cond)
-
+ ```
+5. 在 `while_op.block()` 下定义每一步的计算,循环执行至终止条件变量变为 `False`,循环体包括以下内容:
+ - 获取解码器在当前步的输入,包括上一步选择的id及其对应的得分,隐层状态和源语言上下文
+ ```python
+ pre_ids = pd.array_read(array=ids_array, i=counter)
+ pre_score = pd.array_read(array=scores_array, i=counter)
+ pre_state = pd.array_read(array=state_array, i=counter)
+ pre_context = pd.array_read(array=context_array, i=counter)
+ ```
+ - 定义从词id到下一词预测概率的计算,同训练模式下解码器中的计算逻辑,包括获取输入向量,GRU计算单元和归一化单词预测概率的计算
+ ```python
+ pre_ids_emb = pd.embedding(
+ input=pre_ids,
+ size=[target_dict_dim, word_dim],
+ dtype='float32',
+ is_sparse=is_sparse)
+ decoder_inputs = pd.fc(
+ input=[pre_ids_emb, pre_context],
+ size=decoder_size * 3,
+ bias_attr=False)
+ current_state = pd.gru_unit(
+ input=decoder_inputs, hidden=pre_state, size=decoder_size)
+ current_state_with_lod = pd.lod_reset(x=current_state, y=pre_score)
+ current_score = pd.fc(
+ input=current_state, size=target_dict_dim, act='softmax')
+ ```
+ - 计算累计得分,进行beam search
+ ```python
+ topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
+ accu_scores = pd.elementwise_add(
+ x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
+ selected_ids, selected_scores = pd.beam_search(
+ pre_ids,
+ pre_score,
+ topk_indices,
+ accu_scores,
+ beam_size,
+ end_id=10,
+ level=0)
+ ```
+ - 更新循环计数,收集search结果对应的隐层状态和源语言上下文并写入 tensor array 中
+ ```python
+ pd.increment(x=counter, value=1, in_place=True)
+
+ pd.array_write(selected_ids, array=ids_array, i=counter)
+ pd.array_write(selected_scores, array=scores_array, i=counter)
+ # 使用sequence_expand收集search结果对应的隐层状态和源语言上下文
+ current_state = pd.sequence_expand(current_state, selected_ids)
+ current_context = pd.sequence_expand(pre_context, selected_ids)
+ pd.array_write(current_state, array=state_array, i=counter)
+ pd.array_write(current_context, array=context_array, i=counter)
+ ```
+ - 更新循环终止条件
+ ```python
+ length_cond = pd.less_than(x=counter, y=array_len)
+ finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
+ pd.logical_and(x=length_cond, y=finish_cond, out=cond)
+ ```
+6. 从保存所有时间步预测结果的 tensor array 中获得完整结果
+ ```python
translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)
+ ```
- return translation_ids, translation_scores
-```
-
-进而,我们定义一个`train_program`来使用`inference_program`计算出的结果,在标记数据的帮助下来计算误差。我们还定义了一个`optimizer_func`来定义优化器。
+对于训练模式,还需要定义损失函数,使用交叉熵作为损失函数,train_program 的定义如下:
```python
-def train_program(is_sparse):
- context = encoder(is_sparse)
- rnn_out = train_decoder(context, is_sparse)
+def train_program():
+ encoder_out = encoder()
+ rnn_out = train_decoder(encoder_out)
label = pd.data(
- name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
+ name="trg_next_word_id", shape=[1], dtype='int64', lod_level=1)
cost = pd.cross_entropy(input=rnn_out, label=label)
avg_cost = pd.mean(cost)
return avg_cost
-
-
+```
+此外,还需定义要使用的优化方法,如下:
+```python
def optimizer_func():
return fluid.optimizer.Adagrad(
learning_rate=1e-4,
diff --git a/08.machine_translation/infer.py b/08.machine_translation/infer.py
index 290d8e8f3ef461911634511a166a64e96817efc1..f4ed5729c6f2b57a3e6031c47c49d944e8ab7bb0 100644
--- a/08.machine_translation/infer.py
+++ b/08.machine_translation/infer.py
@@ -23,15 +23,14 @@ import os
dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
-hidden_dim = 32
word_dim = 32
-batch_size = 2
+hidden_dim = 32
+decoder_size = hidden_dim
max_length = 8
-topk_size = 50
beam_size = 2
+batch_size = 2
is_sparse = True
-decoder_size = hidden_dim
model_save_dir = "machine_translation.inference.model"
@@ -40,66 +39,71 @@ def encoder():
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
src_embedding = pd.embedding(
input=src_word_id,
- size=[dict_size, word_dim],
+ size=[source_dict_dim, word_dim],
dtype='float32',
- is_sparse=is_sparse,
- param_attr=fluid.ParamAttr(name='vemb'))
+ is_sparse=is_sparse)
- fc1 = pd.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
- lstm_hidden0, lstm_0 = pd.dynamic_lstm(input=fc1, size=hidden_dim * 4)
- encoder_out = pd.sequence_last_step(input=lstm_hidden0)
- return encoder_out
+ fc_forward = pd.fc(
+ input=src_embedding, size=hidden_dim * 3, bias_attr=False)
+ src_forward = pd.dynamic_gru(input=fc_forward, size=hidden_dim)
+ fc_backward = pd.fc(
+ input=src_embedding, size=hidden_dim * 3, bias_attr=False)
+ src_backward = pd.dynamic_gru(
+ input=fc_backward, size=hidden_dim, is_reverse=True)
+ encoded_vector = pd.concat(input=[src_forward, src_backward], axis=1)
+ return encoded_vector
-def decode(context):
- init_state = context
- array_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length)
- counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True)
-
- # fill the first element with init_state
- state_array = pd.create_array('float32')
- pd.array_write(init_state, array=state_array, i=counter)
+def decode(encoder_out):
+ encoder_last = pd.sequence_last_step(input=encoder_out)
+ encoder_last_projected = pd.fc(
+ input=encoder_last, size=decoder_size, act='tanh')
- # ids, scores as memory
- ids_array = pd.create_array('int64')
- scores_array = pd.create_array('float32')
+ max_len = pd.fill_constant(shape=[1], dtype='int64', value=max_length)
+ counter = pd.zeros(shape=[1], dtype='int64', force_cpu=True)
init_ids = pd.data(name="init_ids", shape=[1], dtype="int64", lod_level=2)
init_scores = pd.data(
name="init_scores", shape=[1], dtype="float32", lod_level=2)
+ # arrays to save selected ids and corresponding scores for each step
+ ids_array = pd.create_array('int64')
pd.array_write(init_ids, array=ids_array, i=counter)
+ scores_array = pd.create_array('float32')
pd.array_write(init_scores, array=scores_array, i=counter)
- cond = pd.less_than(x=counter, y=array_len)
+ # arrays to save states and context for each step
+ state_array = pd.create_array('float32')
+ pd.array_write(encoder_last_projected, array=state_array, i=counter)
+ context_array = pd.create_array('float32')
+ pd.array_write(encoder_last, array=state_array, i=counter)
+ cond = pd.less_than(x=counter, y=max_len)
while_op = pd.While(cond=cond)
with while_op.block():
pre_ids = pd.array_read(array=ids_array, i=counter)
- pre_state = pd.array_read(array=state_array, i=counter)
pre_score = pd.array_read(array=scores_array, i=counter)
+ pre_state = pd.array_read(array=state_array, i=counter)
+ pre_context = pd.array_read(array=context_array, i=counter)
- # expand the lod of pre_state to be the same with pre_score
- pre_state_expanded = pd.sequence_expand(pre_state, pre_score)
-
+ # cell calculations
pre_ids_emb = pd.embedding(
input=pre_ids,
- size=[dict_size, word_dim],
+ size=[target_dict_dim, word_dim],
dtype='float32',
- is_sparse=is_sparse,
- param_attr=fluid.ParamAttr(name='vemb'))
-
- # use rnn unit to update rnn
- current_state = pd.fc(
- input=[pre_state_expanded, pre_ids_emb],
- size=decoder_size,
- act='tanh')
+ is_sparse=is_sparse)
+ decoder_inputs = pd.fc(
+ input=[pre_ids_emb, pre_context],
+ size=decoder_size * 3,
+ bias_attr=False)
+ current_state = pd.gru_unit(
+ input=decoder_inputs, hidden=pre_state, size=decoder_size)
current_state_with_lod = pd.lod_reset(x=current_state, y=pre_score)
- # use score to do beam search
current_score = pd.fc(
- input=current_state_with_lod, size=target_dict_dim, act='softmax')
+ input=current_state, size=target_dict_dim, act='softmax')
+
+ # beam search
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
- # calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
@@ -111,23 +115,20 @@ def decode(context):
end_id=10,
level=0)
- with pd.Switch() as switch:
- with switch.case(pd.is_empty(selected_ids)):
- pd.fill_constant(
- shape=[1], value=0, dtype='bool', force_cpu=True, out=cond)
- with switch.default():
- pd.increment(x=counter, value=1, in_place=True)
-
- # update the memories
- pd.array_write(current_state, array=state_array, i=counter)
- pd.array_write(selected_ids, array=ids_array, i=counter)
- pd.array_write(selected_scores, array=scores_array, i=counter)
-
- # update the break condition: up to the max length or all candidates of
- # source sentences have ended.
- length_cond = pd.less_than(x=counter, y=array_len)
- finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
- pd.logical_and(x=length_cond, y=finish_cond, out=cond)
+ pd.increment(x=counter, value=1, in_place=True)
+ # update states
+ pd.array_write(selected_ids, array=ids_array, i=counter)
+ pd.array_write(selected_scores, array=scores_array, i=counter)
+ # update rnn state by sequence_expand acting as gather
+ current_state = pd.sequence_expand(current_state, selected_ids)
+ current_context = pd.sequence_expand(pre_context, selected_ids)
+ pd.array_write(current_state, array=state_array, i=counter)
+ pd.array_write(current_context, array=context_array, i=counter)
+
+ # update conditional variable
+ length_cond = pd.less_than(x=counter, y=array_len)
+ finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
+ pd.logical_and(x=length_cond, y=finish_cond, out=cond)
translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)
@@ -143,8 +144,8 @@ def decode_main(use_cuda):
exe = Executor(place)
exe.run(framework.default_startup_program())
- context = encoder()
- translation_ids, translation_scores = decode(context)
+ encoder_out = encoder()
+ translation_ids, translation_scores = decode(encoder_out)
fluid.io.load_persistables(executor=exe, dirname=model_save_dir)
init_ids_data = np.array([1 for _ in range(batch_size)], dtype='int64')
diff --git a/08.machine_translation/train.py b/08.machine_translation/train.py
index 589e4dc4a49ab6833ddb3b8946b13d41b1933a21..0d015f6c25779a67704b996eb64b6a9274041612 100644
--- a/08.machine_translation/train.py
+++ b/08.machine_translation/train.py
@@ -29,15 +29,14 @@ except ImportError:
dict_size = 30000
source_dict_dim = target_dict_dim = dict_size
-hidden_dim = 32
word_dim = 32
-batch_size = 2
+hidden_dim = 32
+decoder_size = hidden_dim
max_length = 8
-topk_size = 50
beam_size = 2
+batch_size = 2
is_sparse = True
-decoder_size = hidden_dim
model_save_dir = "machine_translation.inference.model"
@@ -46,36 +45,50 @@ def encoder():
name="src_word_id", shape=[1], dtype='int64', lod_level=1)
src_embedding = pd.embedding(
input=src_word_id,
- size=[dict_size, word_dim],
+ size=[source_dict_dim, word_dim],
dtype='float32',
- is_sparse=is_sparse,
- param_attr=fluid.ParamAttr(name='vemb'))
-
- fc1 = pd.fc(input=src_embedding, size=hidden_dim * 4, act='tanh')
- lstm_hidden0, lstm_0 = pd.dynamic_lstm(input=fc1, size=hidden_dim * 4)
- encoder_out = pd.sequence_last_step(input=lstm_hidden0)
- return encoder_out
-
-
-def train_decoder(context):
+ is_sparse=is_sparse)
+
+ fc_forward = pd.fc(
+ input=src_embedding, size=hidden_dim * 3, bias_attr=False)
+ src_forward = pd.dynamic_gru(input=fc_forward, size=hidden_dim)
+ fc_backward = pd.fc(
+ input=src_embedding, size=hidden_dim * 3, bias_attr=False)
+ src_backward = pd.dynamic_gru(
+ input=fc_backward, size=hidden_dim, is_reverse=True)
+ encoded_vector = pd.concat(input=[src_forward, src_backward], axis=1)
+ return encoded_vector
+
+
+def train_decoder(encoder_out):
+ encoder_last = pd.sequence_last_step(input=encoder_out)
+ encoder_last_projected = pd.fc(
+ input=encoder_last, size=decoder_size, act='tanh')
trg_language_word = pd.data(
- name="target_language_word", shape=[1], dtype='int64', lod_level=1)
+ name="trg_word_id", shape=[1], dtype='int64', lod_level=1)
trg_embedding = pd.embedding(
input=trg_language_word,
- size=[dict_size, word_dim],
+ size=[target_dict_dim, word_dim],
dtype='float32',
- is_sparse=is_sparse,
- param_attr=fluid.ParamAttr(name='vemb'))
+ is_sparse=is_sparse)
rnn = pd.DynamicRNN()
with rnn.block():
current_word = rnn.step_input(trg_embedding)
- pre_state = rnn.memory(init=context, need_reorder=True)
- current_state = pd.fc(
- input=[current_word, pre_state], size=decoder_size, act='tanh')
+ context = rnn.static_input(encoder_last)
+ pre_state = rnn.memory(
+ init=encoder_last_projected, size=decoder_size, need_reorder=True)
+
+ decoder_inputs = pd.fc(
+ input=[current_word, context],
+ size=decoder_size * 3,
+ bias_attr=False)
+ current_state = pd.gru_unit(
+ input=decoder_inputs, hidden=pre_state, size=decoder_size)
current_score = pd.fc(
input=current_state, size=target_dict_dim, act='softmax')
+
rnn.update_memory(pre_state, current_state)
rnn.output(current_score)
@@ -83,10 +96,10 @@ def train_decoder(context):
def train_program():
- context = encoder()
- rnn_out = train_decoder(context)
+ encoder_out = encoder()
+ rnn_out = train_decoder(encoder_out)
label = pd.data(
- name="target_language_next_word", shape=[1], dtype='int64', lod_level=1)
+ name="trg_next_word_id", shape=[1], dtype='int64', lod_level=1)
cost = pd.cross_entropy(input=rnn_out, label=label)
avg_cost = pd.mean(cost)
return avg_cost