提交 cd9835a5 编写于 作者: Q qiaolongfei

update code

上级 c79eaa79
...@@ -150,7 +150,7 @@ e_{ij}&=align(z_i,h_j)\\\\ ...@@ -150,7 +150,7 @@ e_{ij}&=align(z_i,h_j)\\\\
注意:$z_{i+1}$和$p_{i+1}$的计算公式同[解码器](#解码器)中的一样。且由于生成时的每一步都是通过贪心法实现的,因此并不能保证得到全局最优解。 注意:$z_{i+1}$和$p_{i+1}$的计算公式同[解码器](#解码器)中的一样。且由于生成时的每一步都是通过贪心法实现的,因此并不能保证得到全局最优解。
## 数据准备 ## 数据准备(默认已提供测试数据,可跳过)
### 下载与解压缩 ### 下载与解压缩
...@@ -315,6 +315,7 @@ pre-wmt14 ...@@ -315,6 +315,7 @@ pre-wmt14
- 在目标语言序列的每句话末尾补上`<e>`,作为目标语言的下一个词序列(“target_language_next_word”) - 在目标语言序列的每句话末尾补上`<e>`,作为目标语言的下一个词序列(“target_language_next_word”)
然后通过yield返回给trainer. 然后通过yield返回给trainer.
```python ```python
def reader_creator(tar_file, file_name, dict_size): def reader_creator(tar_file, file_name, dict_size):
def reader(): def reader():
...@@ -351,25 +352,26 @@ pre-wmt14 ...@@ -351,25 +352,26 @@ pre-wmt14
return reader return reader
``` ```
## 模型配置说明 ## 训练流程说明
### 数据定义 ### 数据定义
1. 首先要定义词典大小,数据生成和网络配置都需要用到。
```python 首先要定义词典大小,数据生成和网络配置都需要用到。然后获取wmt14的dataset reader。
# source and target dict dim.
dict_size = 30000
reader_dict = { ```python
'source_language_word': 0, # source and target dict dim.
'target_language_word': 1, dict_size = 30000
'target_language_next_word': 2
} feeding = {
wmt14_reader = paddle.batch( 'source_language_word': 0,
paddle.reader.shuffle( 'target_language_word': 1,
paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192), 'target_language_next_word': 2
batch_size=5) }
``` wmt14_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt14.train(dict_size=dict_size), buf_size=8192),
batch_size=5)
```
### 模型结构 ### 模型结构
1. 首先,定义了一些全局变量。 1. 首先,定义了一些全局变量。
...@@ -384,7 +386,8 @@ pre-wmt14 ...@@ -384,7 +386,8 @@ pre-wmt14
2. 其次,实现编码器框架。分为三步: 2. 其次,实现编码器框架。分为三步:
2.1 传入已经在data_set中转换成one-hot vector表示的源语言序列$\mathbf{w}$,其类型为integer_value_sequence。 2.1 将在dataset reader中生成的用每个单词在字典中的索引表示的源语言序列
转换成one-hot vector表示的源语言序列$\mathbf{w}$,其类型为integer_value_sequence。
```python ```python
src_word_id = paddle.layer.data( src_word_id = paddle.layer.data(
...@@ -543,7 +546,7 @@ pre-wmt14 ...@@ -543,7 +546,7 @@ pre-wmt14
train_reader("data/pre-wmt14/train/train"), buf_size=8192), train_reader("data/pre-wmt14/train/train"), buf_size=8192),
batch_size=5) batch_size=5)
reader_dict = { feeding = {
'source_language_word': 0, 'source_language_word': 0,
'target_language_word': 1, 'target_language_word': 1,
'target_language_next_word': 2 'target_language_next_word': 2
...@@ -567,7 +570,7 @@ pre-wmt14 ...@@ -567,7 +570,7 @@ pre-wmt14
reader=wmt14_reader, reader=wmt14_reader,
event_handler=event_handler, event_handler=event_handler,
num_passes=10000, num_passes=10000,
reader_dict=reader_dict) feeding=feeding)
``` ```
训练开始后,可以观察到event_handler输出的日志如下: 训练开始后,可以观察到event_handler输出的日志如下:
```text ```text
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册