提交 885a8152 编写于 作者: Q qiaolongfei

fix some format problem

上级 28935ac4
...@@ -270,57 +270,60 @@ pre-wmt14 ...@@ -270,57 +270,60 @@ pre-wmt14
### 提供数据给PaddlePaddle ### 提供数据给PaddlePaddle
1. 生成词典 1. 生成词典
根据用户配置的词典长度,将数据预处理阶段生成的词典文件(src.dict, trg.dict)load到内存中生成一个map,结构为 {"word": index}.
```python 根据用户配置的词典长度,将数据预处理阶段生成的词典文件(src.dict, trg.dict)load到内存中生成一个map,结构为 {"word": index}.
def __read_to_dict__(dict_path, count):
with open(dict_path, "r") as fin: ```python
out_dict = dict() def __read_to_dict__(dict_path, count):
for line_count, line in enumerate(fin): with open(dict_path, "r") as fin:
if line_count <= count: out_dict = dict()
out_dict[line.strip()] = line_count for line_count, line in enumerate(fin):
else: if line_count <= count:
break out_dict[line.strip()] = line_count
return out_dict else:
break
return out_dict
``` ```
2. 读取训练数据(如: pre-wmt14/train/train),通过词典将word转换为对应的index.并且: 2. 读取训练数据
在源语言序列的每句话前面补上开始符号`<s>`、末尾补上结束符号`<e>`,得到“source_language_word”;
在目标语言序列的每句话前面补上`<s>`,得到“target_language_word”;
在目标语言序列的每句话末尾补上`<e>`,作为目标语言的下一个词序列(“target_language_next_word”)
通过yield返回给trainer. 读取预处理之后的数据文件,如pre-wmt14/train/train, 通过词典将word转换为对应的index。注意:
```python
def __reader__(file_name, src_dict, trg_dict): - 在源语言序列的每句话前面补上开始符号`<s>`、末尾补上结束符号`<e>`,得到“source_language_word”;
with open(file_name, 'r') as f: - 在目标语言序列的每句话前面补上`<s>`,得到“target_language_word”;
for line_count, line in enumerate(f): - 在目标语言序列的每句话末尾补上`<e>`,作为目标语言的下一个词序列(“target_language_next_word”)
line_split = line.strip().split('\t')
if len(line_split) != 2: 然后通过yield返回给trainer.
continue ```python
src_seq = line_split[0] # one source sequence def __reader__(file_name, src_dict, trg_dict):
src_words = src_seq.split() with open(file_name, 'r') as f:
src_ids = [ for line_count, line in enumerate(f):
src_dict.get(w, UNK_IDX) for w in [START] + src_words + [END] line_split = line.strip().split('\t')
] if len(line_split) != 2:
continue
trg_seq = line_split[1] # one target sequence src_seq = line_split[0] # one source sequence
trg_words = trg_seq.split() src_words = src_seq.split()
trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words] src_ids = [
src_dict.get(w, UNK_IDX) for w in [START] + src_words + [END]
# remove sequence whose length > 80 in training mode ]
if len(src_ids) > 80 or len(trg_ids) > 80:
continue trg_seq = line_split[1] # one target sequence
trg_ids_next = trg_ids + [trg_dict[END]] trg_words = trg_seq.split()
trg_ids = [trg_dict[START]] + trg_ids trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
yield src_ids, trg_ids, trg_ids_next # remove sequence whose length > 80 in training mode
``` if len(src_ids) > 80 or len(trg_ids) > 80:
continue
trg_ids_next = trg_ids + [trg_dict[END]]
trg_ids = [trg_dict[START]] + trg_ids
yield src_ids, trg_ids, trg_ids_next
```
## 模型配置说明 ## 模型配置说明
### 数据定义 ### 数据定义
1. 首先要定义词典大小,数据生成和网络配置都需要用到。
1,首先要定义词典大小,数据生成和网络配置都需要用到。
```python ```python
# source and target dict dim. # source and target dict dim.
...@@ -473,38 +476,36 @@ def __reader__(file_name, src_dict, trg_dict): ...@@ -473,38 +476,36 @@ def __reader__(file_name, src_dict, trg_dict):
``` ```
注意:我们提供的配置在Bahdanau的论文\[[4](#参考文献)\]上做了一些简化,可参考[issue #1133](https://github.com/PaddlePaddle/Paddle/issues/1133)。 注意:我们提供的配置在Bahdanau的论文\[[4](#参考文献)\]上做了一些简化,可参考[issue #1133](https://github.com/PaddlePaddle/Paddle/issues/1133)。
### 参数定义
首先依据模型配置的`cost`定义模型参数。
```python
# create parameters
parameters = paddle.parameters.create(cost)
```
可以打印参数名字,如果在网络配置中没有指定名字,则默认生成。
```python
for param in parameters.keys():
print param
```
### 定义参数 ### 训练模型
首先依据模型配置的`cost`定义模型参数。
```python
# create parameters
parameters = paddle.parameters.create(cost)
```
可以打印参数名字,如果在网络配置中没有指定名字,则默认生成。
```python
for param in parameters.keys():
print param
```
## 训练模型
1. 构造trainer 1. 构造trainer
根据优化目标cost,网络拓扑结构和模型参数来构造出trainer用来训练,在构造时还需指定优化方法,这里使用最基本的SGD方法。
```python 根据优化目标cost,网络拓扑结构和模型参数来构造出trainer用来训练,在构造时还需指定优化方法,这里使用最基本的SGD方法。
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
trainer = paddle.trainer.SGD(cost=cost, ```python
parameters=parameters, optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
update_equation=optimizer) trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer)
``` ```
2. 构造数据reader 2. 构造数据reader
reader负责读取数据变转换为paddle需要的格式。reader_dict指定字段在模型中的顺序。
```python reader负责读取数据变转换为paddle需要的格式。reader_dict指定字段在模型中的顺序。
```python
wmt14_reader = paddle.reader.batched( wmt14_reader = paddle.reader.batched(
paddle.reader.shuffle( paddle.reader.shuffle(
train_reader("data/pre-wmt14/train/train"), buf_size=8192), train_reader("data/pre-wmt14/train/train"), buf_size=8192),
...@@ -515,31 +516,32 @@ reader负责读取数据变转换为paddle需要的格式。reader_dict指定字 ...@@ -515,31 +516,32 @@ reader负责读取数据变转换为paddle需要的格式。reader_dict指定字
'target_language_word': 1, 'target_language_word': 1,
'target_language_next_word': 2 'target_language_next_word': 2
} }
``` ```
3. 开始训练 3. 开始训练
可以通过自定义回调函数来评估训练过程中的各种述职,比如错误率等。下面的代码通过event.batch_id % 10 == 0
指定没10个batch打印一次日志,包含cost等信息。 可以通过自定义回调函数来评估训练过程中的各种述职,比如错误率等。下面的代码通过event.batch_id % 10 == 0
```python 指定没10个batch打印一次日志,包含cost等信息。
```python
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 10 == 0: if event.batch_id % 10 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % ( print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics) event.pass_id, event.batch_id, event.cost, event.metrics)
``` ```
4. 启动训练:
启动训练: ```python
```python
trainer.train( trainer.train(
reader=wmt14_reader, reader=wmt14_reader,
event_handler=event_handler, event_handler=event_handler,
num_passes=10000, num_passes=10000,
reader_dict=reader_dict) reader_dict=reader_dict)
``` ```
训练开始后,可以观察到event_handler输出的日志如下: 训练开始后,可以观察到event_handler输出的日志如下:
```text ```text
Pass 0, Batch 0, Cost 247.408008, {'classification_error_evaluator': 1.0}
... Pass 0, Batch 10, Cost 212.058789, {'classification_error_evaluator': 0.8737863898277283}
...
``` ```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册