Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
book
提交
885a8152
B
book
项目概览
PaddlePaddle
/
book
通知
16
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
40
列表
看板
标记
里程碑
合并请求
37
Wiki
5
Wiki
分析
仓库
DevOps
项目成员
Pages
B
book
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
40
Issue
40
列表
看板
标记
里程碑
合并请求
37
合并请求
37
Pages
分析
分析
仓库分析
DevOps
Wiki
5
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
885a8152
编写于
3月 06, 2017
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix some format problem
上级
28935ac4
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
85 addition
and
83 deletion
+85
-83
machine_translation/README.md
machine_translation/README.md
+85
-83
未找到文件。
machine_translation/README.md
浏览文件 @
885a8152
...
@@ -270,57 +270,60 @@ pre-wmt14
...
@@ -270,57 +270,60 @@ pre-wmt14
### 提供数据给PaddlePaddle
### 提供数据给PaddlePaddle
1.
生成词典
1.
生成词典
根据用户配置的词典长度,将数据预处理阶段生成的词典文件(src.dict, trg.dict)load到内存中生成一个map,结构为 {"word": index}.
```
python
根据用户配置的词典长度,将数据预处理阶段生成的词典文件(src.dict, trg.dict)load到内存中生成一个map,结构为 {"word": index}.
def
__read_to_dict__
(
dict_path
,
count
):
with
open
(
dict_path
,
"r"
)
as
fin
:
```python
out_dict
=
dict
()
def __read_to_dict__(dict_path, count):
for
line_count
,
line
in
enumerate
(
fin
):
with open(dict_path, "r") as fin:
if
line_count
<=
count
:
out_dict = dict()
out_dict
[
line
.
strip
()]
=
line_count
for line_count, line in enumerate(fin):
else
:
if line_count <= count:
break
out_dict[line.strip()] = line_count
return
out_dict
else:
break
return out_dict
```
```
2.
读取训练数据(如: pre-wmt14/train/train),通过词典将word转换为对应的index.并且:
2. 读取训练数据
在源语言序列的每句话前面补上开始符号
`<s>`
、末尾补上结束符号
`<e>`
,得到“source_language_word”;
在目标语言序列的每句话前面补上
`<s>`
,得到“target_language_word”;
在目标语言序列的每句话末尾补上
`<e>`
,作为目标语言的下一个词序列(“target_language_next_word”)
通过yield返回给trainer.
读取预处理之后的数据文件,如pre-wmt14/train/train, 通过词典将word转换为对应的index。注意:
```
python
def
__reader__
(
file_name
,
src_dict
,
trg_dict
):
- 在源语言序列的每句话前面补上开始符号`<s>`、末尾补上结束符号`<e>`,得到“source_language_word”;
with
open
(
file_name
,
'r'
)
as
f
:
- 在目标语言序列的每句话前面补上`<s>`,得到“target_language_word”;
for
line_count
,
line
in
enumerate
(
f
):
- 在目标语言序列的每句话末尾补上`<e>`,作为目标语言的下一个词序列(“target_language_next_word”)
line_split
=
line
.
strip
().
split
(
'
\t
'
)
if
len
(
line_split
)
!=
2
:
然后通过yield返回给trainer.
continue
```python
src_seq
=
line_split
[
0
]
# one source sequence
def __reader__(file_name, src_dict, trg_dict):
src_words
=
src_seq
.
split
()
with open(file_name, 'r') as f:
src_ids
=
[
for line_count, line in enumerate(f):
src_dict
.
get
(
w
,
UNK_IDX
)
for
w
in
[
START
]
+
src_words
+
[
END
]
line_split = line.strip().split('\t')
]
if len(line_split) != 2:
continue
trg_seq
=
line_split
[
1
]
# one target sequence
src_seq = line_split[0] # one source sequence
trg_words
=
trg_seq
.
split
()
src_words = src_seq.split()
trg_ids
=
[
trg_dict
.
get
(
w
,
UNK_IDX
)
for
w
in
trg_words
]
src_ids = [
src_dict.get(w, UNK_IDX) for w in [START] + src_words + [END]
# remove sequence whose length > 80 in training mode
]
if
len
(
src_ids
)
>
80
or
len
(
trg_ids
)
>
80
:
continue
trg_seq = line_split[1] # one target sequence
trg_ids_next
=
trg_ids
+
[
trg_dict
[
END
]]
trg_words = trg_seq.split()
trg_ids
=
[
trg_dict
[
START
]]
+
trg_ids
trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
yield
src_ids
,
trg_ids
,
trg_ids_next
# remove sequence whose length > 80 in training mode
```
if len(src_ids) > 80 or len(trg_ids) > 80:
continue
trg_ids_next = trg_ids + [trg_dict[END]]
trg_ids = [trg_dict[START]] + trg_ids
yield src_ids, trg_ids, trg_ids_next
```
## 模型配置说明
## 模型配置说明
### 数据定义
### 数据定义
1. 首先要定义词典大小,数据生成和网络配置都需要用到。
1,首先要定义词典大小,数据生成和网络配置都需要用到。
```python
```python
# source and target dict dim.
# source and target dict dim.
...
@@ -473,38 +476,36 @@ def __reader__(file_name, src_dict, trg_dict):
...
@@ -473,38 +476,36 @@ def __reader__(file_name, src_dict, trg_dict):
```
```
注意:我们提供的配置在Bahdanau的论文\[[4](#参考文献)\]上做了一些简化,可参考[issue #1133](https://github.com/PaddlePaddle/Paddle/issues/1133)。
注意:我们提供的配置在Bahdanau的论文\[[4](#参考文献)\]上做了一些简化,可参考[issue #1133](https://github.com/PaddlePaddle/Paddle/issues/1133)。
### 参数定义
首先依据模型配置的`cost`定义模型参数。
```python
# create parameters
parameters = paddle.parameters.create(cost)
```
可以打印参数名字,如果在网络配置中没有指定名字,则默认生成。
```python
for param in parameters.keys():
print param
```
### 定义参数
### 训练模型
首先依据模型配置的
`cost`
定义模型参数。
```
python
# create parameters
parameters
=
paddle
.
parameters
.
create
(
cost
)
```
可以打印参数名字,如果在网络配置中没有指定名字,则默认生成。
```
python
for
param
in
parameters
.
keys
():
print
param
```
## 训练模型
1. 构造trainer
1. 构造trainer
根据优化目标cost,网络拓扑结构和模型参数来构造出trainer用来训练,在构造时还需指定优化方法,这里使用最基本的SGD方法。
```
python
根据优化目标cost,网络拓扑结构和模型参数来构造出trainer用来训练,在构造时还需指定优化方法,这里使用最基本的SGD方法。
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
1e-4
)
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
```python
parameters
=
parameters
,
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
update_equation
=
optimizer
)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer)
```
```
2.
构造数据reader
2.
构造数据reader
reader负责读取数据变转换为paddle需要的格式。reader_dict指定字段在模型中的顺序。
```
python
reader负责读取数据变转换为paddle需要的格式。reader_dict指定字段在模型中的顺序。
```python
wmt14_reader = paddle.reader.batched(
wmt14_reader = paddle.reader.batched(
paddle.reader.shuffle(
paddle.reader.shuffle(
train_reader("data/pre-wmt14/train/train"), buf_size=8192),
train_reader("data/pre-wmt14/train/train"), buf_size=8192),
...
@@ -515,31 +516,32 @@ reader负责读取数据变转换为paddle需要的格式。reader_dict指定字
...
@@ -515,31 +516,32 @@ reader负责读取数据变转换为paddle需要的格式。reader_dict指定字
'target_language_word': 1,
'target_language_word': 1,
'target_language_next_word': 2
'target_language_next_word': 2
}
}
```
```
3.
开始训练
3.
开始训练
可以通过自定义回调函数来评估训练过程中的各种述职,比如错误率等。下面的代码通过event.batch_id % 10 == 0
指定没10个batch打印一次日志,包含cost等信息。
可以通过自定义回调函数来评估训练过程中的各种述职,比如错误率等。下面的代码通过event.batch_id % 10 == 0
```
python
指定没10个batch打印一次日志,包含cost等信息。
```python
def event_handler(event):
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 10 == 0:
if event.batch_id % 10 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
event.pass_id, event.batch_id, event.cost, event.metrics)
```
```
4.
启动训练:
启动训练:
```python
```
python
trainer.train(
trainer.train(
reader=wmt14_reader,
reader=wmt14_reader,
event_handler=event_handler,
event_handler=event_handler,
num_passes=10000,
num_passes=10000,
reader_dict=reader_dict)
reader_dict=reader_dict)
```
```
训练开始后,可以观察到event_handler输出的日志如下:
训练开始后,可以观察到event_handler输出的日志如下:
```
text
```text
Pass 0, Batch 0, Cost 247.408008, {'classification_error_evaluator': 1.0}
...
Pass 0, Batch 10, Cost 212.058789, {'classification_error_evaluator': 0.8737863898277283}
...
```
```
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录