Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
775cf990
M
models
项目概览
PaddlePaddle
/
models
接近 2 年 前同步成功
通知
230
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
775cf990
编写于
10月 29, 2018
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the inplace reshape in inference of Transformer and refine README
上级
58e9bc20
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
26 addition
and
17 deletion
+26
-17
fluid/PaddleNLP/neural_machine_translation/transformer/README_cn.md
...leNLP/neural_machine_translation/transformer/README_cn.md
+17
-15
fluid/PaddleNLP/neural_machine_translation/transformer/model.py
...PaddleNLP/neural_machine_translation/transformer/model.py
+9
-2
未找到文件。
fluid/PaddleNLP/neural_machine_translation/transformer/README_cn.md
浏览文件 @
775cf990
...
@@ -93,9 +93,13 @@ python -u train.py \
...
@@ -93,9 +93,13 @@ python -u train.py \
python train.py
--help
python train.py
--help
```
```
更多模型训练相关的参数则在
`config.py`
中的
`ModelHyperParams`
和
`TrainTaskConfig`
内定义;
`ModelHyperParams`
定义了 embedding 维度等模型超参数,
`TrainTaskConfig`
定义了 warmup 步数等训练需要的参数。这些参数默认使用了 Transformer 论文中 base model 的配置,如需调整可以在该脚本中进行修改。另外这些参数同样可在执行训练脚本的命令行中设置,传入的配置会合并并覆盖
`config.py`
中的配置,如可以通过以下命令来训练 Transformer 论文中的 big model (如显存不够可适当减小 batch size 的值):
更多模型训练相关的参数则在
`config.py`
中的
`ModelHyperParams`
和
`TrainTaskConfig`
内定义;
`ModelHyperParams`
定义了 embedding 维度等模型超参数,
`TrainTaskConfig`
定义了 warmup 步数等训练需要的参数。这些参数默认使用了 Transformer 论文中 base model 的配置,如需调整可以在该脚本中进行修改。另外这些参数同样可在执行训练脚本的命令行中设置,传入的配置会合并并覆盖
`config.py`
中的配置,如可以通过以下命令来训练 Transformer 论文中的 big model (如显存不够可适当减小 batch size 的值
,或设置
`max_length 200`
过滤过长的句子,或修改某些显存使用相关环境变量的值
):
```
sh
```
sh
# 显存使用的比例,显存不足可适当增大,最大为1
export
FLAGS_fraction_of_gpu_memory_to_use
=
1.0
# 显存清理的阈值,显存不足可适当减小,最小为0,为负数时不启用
export
FLAGS_eager_delete_tensor_gb
=
0.8
python
-u
train.py
\
python
-u
train.py
\
--src_vocab_fpath
gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000
\
--src_vocab_fpath
gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000
\
--trg_vocab_fpath
gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000
\
--trg_vocab_fpath
gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000
\
...
@@ -115,18 +119,17 @@ python -u train.py \
...
@@ -115,18 +119,17 @@ python -u train.py \
```
```
有关这些参数更详细信息的请参考
`config.py`
中的注释说明。
有关这些参数更详细信息的请参考
`config.py`
中的注释说明。
训练时默认使用所有 GPU,可以通过
`CUDA_VISIBLE_DEVICES`
环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数
`--divice CPU`
设置),训练速度相对较慢。在训练过程中,每隔一定 iteration 后(通过参数
`save_freq`
设置,默认为10000)保存模型到参数
`model_dir`
指定的目录,每个 epoch 结束后也会保存 checkpiont 到
`ckpt_dir`
指定的目录,每
个 iteration
将打印如下的日志到标准输出:
训练时默认使用所有 GPU,可以通过
`CUDA_VISIBLE_DEVICES`
环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数
`--divice CPU`
设置),训练速度相对较慢。在训练过程中,每隔一定 iteration 后(通过参数
`save_freq`
设置,默认为10000)保存模型到参数
`model_dir`
指定的目录,每个 epoch 结束后也会保存 checkpiont 到
`ckpt_dir`
指定的目录,每
隔一定数目的 iteration (通过参数
`--fetch_steps`
设置,默认为100)
将打印如下的日志到标准输出:
```
txt
```
txt
step_idx: 0, epoch: 0, batch: 0, avg loss: 11.059394, normalized loss: 9.682427, ppl: 63538.027344
[2018-10-26 00:49:24,705 INFO train.py:536] step_idx: 0, epoch: 0, batch: 0, avg loss: 10.999878, normalized loss: 9.624138, ppl: 59866.832031
step_idx: 1, epoch: 0, batch: 1, avg loss: 11.053112, normalized loss: 9.676146, ppl: 63140.144531
[2018-10-26 00:50:08,717 INFO train.py:545] step_idx: 100, epoch: 0, batch: 100, avg loss: 9.454134, normalized loss: 8.078394, ppl: 12760.809570, speed: 2.27 step/s
step_idx: 2, epoch: 0, batch: 2, avg loss: 11.054576, normalized loss: 9.677609, ppl: 63232.640625
[2018-10-26 00:50:52,655 INFO train.py:545] step_idx: 200, epoch: 0, batch: 200, avg loss: 8.643907, normalized loss: 7.268166, ppl: 5675.458496, speed: 2.28 step/s
step_idx: 3, epoch: 0, batch: 3, avg loss: 11.046638, normalized loss: 9.669671, ppl: 62732.664062
[2018-10-26 00:51:36,529 INFO train.py:545] step_idx: 300, epoch: 0, batch: 300, avg loss: 7.916654, normalized loss: 6.540914, ppl: 2742.579346, speed: 2.28 step/s
step_idx: 4, epoch: 0, batch: 4, avg loss: 11.030095, normalized loss: 9.653129, ppl: 61703.449219
[2018-10-26 00:52:20,692 INFO train.py:545] step_idx: 400, epoch: 0, batch: 400, avg loss: 7.902879, normalized loss: 6.527138, ppl: 2705.058350, speed: 2.26 step/s
step_idx: 5, epoch: 0, batch: 5, avg loss: 11.047491, normalized loss: 9.670525, ppl: 62786.230469
[2018-10-26 00:53:04,537 INFO train.py:545] step_idx: 500, epoch: 0, batch: 500, avg loss: 7.818271, normalized loss: 6.442531, ppl: 2485.604492, speed: 2.28 step/s
step_idx: 6, epoch: 0, batch: 6, avg loss: 11.044509, normalized loss: 9.667542, ppl: 62599.273438
[2018-10-26 00:53:48,580 INFO train.py:545] step_idx: 600, epoch: 0, batch: 600, avg loss: 7.554341, normalized loss: 6.178601, ppl: 1909.012451, speed: 2.27 step/s
step_idx: 7, epoch: 0, batch: 7, avg loss: 11.011090, normalized loss: 9.634124, ppl: 60541.859375
[2018-10-26 00:54:32,878 INFO train.py:545] step_idx: 700, epoch: 0, batch: 700, avg loss: 7.177765, normalized loss: 5.802025, ppl: 1309.977661, speed: 2.26 step/s
step_idx: 8, epoch: 0, batch: 8, avg loss: 10.985243, normalized loss: 9.608276, ppl: 58997.058594
[2018-10-26 00:55:17,108 INFO train.py:545] step_idx: 800, epoch: 0, batch: 800, avg loss: 7.005494, normalized loss: 5.629754, ppl: 1102.674805, speed: 2.26 step/s
step_idx: 9, epoch: 0, batch: 9, avg loss: 10.993434, normalized loss: 9.616467, ppl: 59482.292969
```
```
### 模型预测
### 模型预测
...
@@ -138,10 +141,9 @@ python -u infer.py \
...
@@ -138,10 +141,9 @@ python -u infer.py \
--trg_vocab_fpath
gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000
\
--trg_vocab_fpath
gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000
\
--special_token
'<s>'
'<e>'
'<unk>'
\
--special_token
'<s>'
'<e>'
'<unk>'
\
--test_file_pattern
gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de
\
--test_file_pattern
gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de
\
--use_wordpiece
False
\
--token_delimiter
' '
\
--token_delimiter
' '
\
--batch_size
32
\
--batch_size
32
\
model_path trained_models/iter_1
99999
.infer.model
\
model_path trained_models/iter_1
00000
.infer.model
\
beam_size 4
\
beam_size 4
\
max_out_len 255
max_out_len 255
```
```
...
@@ -164,7 +166,7 @@ BLEU = 33.08, 64.2/39.2/26.4/18.5 (BP=0.994, ratio=0.994, hyp_len=61971, ref_len
...
@@ -164,7 +166,7 @@ BLEU = 33.08, 64.2/39.2/26.4/18.5 (BP=0.994, ratio=0.994, hyp_len=61971, ref_len
| 测试集 | newstest2014 | newstest2015 | newstest2016 |
| 测试集 | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|
|-|-|-|-|
| BLEU | 26.
05 | 28.75 | 33.27
|
| BLEU | 26.
25 | 29.15 | 33.64
|
### 分布式训练
### 分布式训练
...
...
fluid/PaddleNLP/neural_machine_translation/transformer/model.py
浏览文件 @
775cf990
...
@@ -124,8 +124,15 @@ def multi_head_attention(queries,
...
@@ -124,8 +124,15 @@ def multi_head_attention(queries,
q
,
k
,
v
=
__compute_qkv
(
queries
,
keys
,
values
,
n_head
,
d_key
,
d_value
)
q
,
k
,
v
=
__compute_qkv
(
queries
,
keys
,
values
,
n_head
,
d_key
,
d_value
)
if
cache
is
not
None
:
# use cache and concat time steps
if
cache
is
not
None
:
# use cache and concat time steps
k
=
cache
[
"k"
]
=
layers
.
concat
([
cache
[
"k"
],
k
],
axis
=
1
)
# Since the inplace reshape in __split_heads changes the shape of k and
v
=
cache
[
"v"
]
=
layers
.
concat
([
cache
[
"v"
],
v
],
axis
=
1
)
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k
=
cache
[
"k"
]
=
layers
.
concat
(
[
layers
.
reshape
(
cache
[
"k"
],
shape
=
[
0
,
0
,
d_model
]),
k
],
axis
=
1
)
v
=
cache
[
"v"
]
=
layers
.
concat
(
[
layers
.
reshape
(
cache
[
"v"
],
shape
=
[
0
,
0
,
d_model
]),
v
],
axis
=
1
)
q
=
__split_heads
(
q
,
n_head
)
q
=
__split_heads
(
q
,
n_head
)
k
=
__split_heads
(
k
,
n_head
)
k
=
__split_heads
(
k
,
n_head
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录