提交 885a8152 编写于 作者: Q qiaolongfei

fix some format problem

上级 28935ac4
......@@ -270,57 +270,60 @@ pre-wmt14
### 提供数据给PaddlePaddle
1. 生成词典
根据用户配置的词典长度,将数据预处理阶段生成的词典文件(src.dict, trg.dict)load到内存中生成一个map,结构为 {"word": index}.
```python
def __read_to_dict__(dict_path, count):
with open(dict_path, "r") as fin:
out_dict = dict()
for line_count, line in enumerate(fin):
if line_count <= count:
out_dict[line.strip()] = line_count
else:
break
return out_dict
根据用户配置的词典长度,将数据预处理阶段生成的词典文件(src.dict, trg.dict)load到内存中生成一个map,结构为 {"word": index}.
```python
def __read_to_dict__(dict_path, count):
with open(dict_path, "r") as fin:
out_dict = dict()
for line_count, line in enumerate(fin):
if line_count <= count:
out_dict[line.strip()] = line_count
else:
break
return out_dict
```
2. 读取训练数据(如: pre-wmt14/train/train),通过词典将word转换为对应的index.并且:
在源语言序列的每句话前面补上开始符号`<s>`、末尾补上结束符号`<e>`,得到“source_language_word”;
在目标语言序列的每句话前面补上`<s>`,得到“target_language_word”;
在目标语言序列的每句话末尾补上`<e>`,作为目标语言的下一个词序列(“target_language_next_word”)
2. 读取训练数据
通过yield返回给trainer.
```python
def __reader__(file_name, src_dict, trg_dict):
with open(file_name, 'r') as f:
for line_count, line in enumerate(f):
line_split = line.strip().split('\t')
if len(line_split) != 2:
continue
src_seq = line_split[0] # one source sequence
src_words = src_seq.split()
src_ids = [
src_dict.get(w, UNK_IDX) for w in [START] + src_words + [END]
]
trg_seq = line_split[1] # one target sequence
trg_words = trg_seq.split()
trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
# remove sequence whose length > 80 in training mode
if len(src_ids) > 80 or len(trg_ids) > 80:
continue
trg_ids_next = trg_ids + [trg_dict[END]]
trg_ids = [trg_dict[START]] + trg_ids
yield src_ids, trg_ids, trg_ids_next
```
读取预处理之后的数据文件,如pre-wmt14/train/train, 通过词典将word转换为对应的index。注意:
- 在源语言序列的每句话前面补上开始符号`<s>`、末尾补上结束符号`<e>`,得到“source_language_word”;
- 在目标语言序列的每句话前面补上`<s>`,得到“target_language_word”;
- 在目标语言序列的每句话末尾补上`<e>`,作为目标语言的下一个词序列(“target_language_next_word”)
然后通过yield返回给trainer.
```python
def __reader__(file_name, src_dict, trg_dict):
with open(file_name, 'r') as f:
for line_count, line in enumerate(f):
line_split = line.strip().split('\t')
if len(line_split) != 2:
continue
src_seq = line_split[0] # one source sequence
src_words = src_seq.split()
src_ids = [
src_dict.get(w, UNK_IDX) for w in [START] + src_words + [END]
]
trg_seq = line_split[1] # one target sequence
trg_words = trg_seq.split()
trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]
# remove sequence whose length > 80 in training mode
if len(src_ids) > 80 or len(trg_ids) > 80:
continue
trg_ids_next = trg_ids + [trg_dict[END]]
trg_ids = [trg_dict[START]] + trg_ids
yield src_ids, trg_ids, trg_ids_next
```
## 模型配置说明
### 数据定义
1,首先要定义词典大小,数据生成和网络配置都需要用到。
1. 首先要定义词典大小,数据生成和网络配置都需要用到。
```python
# source and target dict dim.
......@@ -473,38 +476,36 @@ def __reader__(file_name, src_dict, trg_dict):
```
注意:我们提供的配置在Bahdanau的论文\[[4](#参考文献)\]上做了一些简化,可参考[issue #1133](https://github.com/PaddlePaddle/Paddle/issues/1133)。
### 参数定义
首先依据模型配置的`cost`定义模型参数。
```python
# create parameters
parameters = paddle.parameters.create(cost)
```
可以打印参数名字,如果在网络配置中没有指定名字,则默认生成。
```python
for param in parameters.keys():
print param
```
### 定义参数
首先依据模型配置的`cost`定义模型参数。
```python
# create parameters
parameters = paddle.parameters.create(cost)
```
可以打印参数名字,如果在网络配置中没有指定名字,则默认生成。
```python
for param in parameters.keys():
print param
```
## 训练模型
### 训练模型
1. 构造trainer
根据优化目标cost,网络拓扑结构和模型参数来构造出trainer用来训练,在构造时还需指定优化方法,这里使用最基本的SGD方法。
```python
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer)
根据优化目标cost,网络拓扑结构和模型参数来构造出trainer用来训练,在构造时还需指定优化方法,这里使用最基本的SGD方法。
```python
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer)
```
2. 构造数据reader
reader负责读取数据变转换为paddle需要的格式。reader_dict指定字段在模型中的顺序。
```python
reader负责读取数据变转换为paddle需要的格式。reader_dict指定字段在模型中的顺序。
```python
wmt14_reader = paddle.reader.batched(
paddle.reader.shuffle(
train_reader("data/pre-wmt14/train/train"), buf_size=8192),
......@@ -515,31 +516,32 @@ reader负责读取数据变转换为paddle需要的格式。reader_dict指定字
'target_language_word': 1,
'target_language_next_word': 2
}
```
```
3. 开始训练
可以通过自定义回调函数来评估训练过程中的各种述职,比如错误率等。下面的代码通过event.batch_id % 10 == 0
指定没10个batch打印一次日志,包含cost等信息。
```python
可以通过自定义回调函数来评估训练过程中的各种述职,比如错误率等。下面的代码通过event.batch_id % 10 == 0
指定没10个batch打印一次日志,包含cost等信息。
```python
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 10 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
```
```
4. 启动训练:
启动训练:
```python
```python
trainer.train(
reader=wmt14_reader,
event_handler=event_handler,
num_passes=10000,
reader_dict=reader_dict)
```
训练开始后,可以观察到event_handler输出的日志如下:
```text
...
```
训练开始后,可以观察到event_handler输出的日志如下:
```text
Pass 0, Batch 0, Cost 247.408008, {'classification_error_evaluator': 1.0}
Pass 0, Batch 10, Cost 212.058789, {'classification_error_evaluator': 0.8737863898277283}
...
```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册