Unverified commit d9f55e01, author: Guo Sheng, committer: GitHub

Merge pull request #1404 from guoshengCS/fix-reshape-inplace-transformer-infer

Fix the inplace reshape in inference of Transformer and refine README
......@@ -93,9 +93,13 @@ python -u train.py \
python train.py --help
```
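-More training-related parameters are defined in `ModelHyperParams` and `TrainTaskConfig` in `config.py`: `ModelHyperParams` defines model hyperparameters such as the embedding dimension, and `TrainTaskConfig` defines training parameters such as the number of warmup steps. By default these parameters use the base model configuration from the Transformer paper and can be changed in that script. They can also be set on the command line when running the training script; the values passed in are merged with, and override, those in `config.py`. For example, the following command trains the big model from the Transformer paper (if GPU memory is insufficient, reduce the batch size):
+More training-related parameters are defined in `ModelHyperParams` and `TrainTaskConfig` in `config.py`: `ModelHyperParams` defines model hyperparameters such as the embedding dimension, and `TrainTaskConfig` defines training parameters such as the number of warmup steps. By default these parameters use the base model configuration from the Transformer paper and can be changed in that script. They can also be set on the command line when running the training script; the values passed in are merged with, and override, those in `config.py`. For example, the following command trains the big model from the Transformer paper (if GPU memory is insufficient, reduce the batch size, set `max_length 200` to filter out overly long sentences, or change the values of the memory-related environment variables):
```sh
+# Fraction of GPU memory to use; increase it if memory is insufficient (maximum 1)
+export FLAGS_fraction_of_gpu_memory_to_use=1.0
+# Threshold for GPU memory cleanup; decrease it if memory is insufficient (minimum 0; disabled when negative)
+export FLAGS_eager_delete_tensor_gb=0.8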
python -u train.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
......@@ -115,18 +119,17 @@ python -u train.py \
```
For more details on these parameters, see the comments in `config.py`.
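The merge-and-override behaviour of command-line settings can be sketched as follows. This is a minimal illustration under stated assumptions, not the repository's implementation: the `merge_overrides` helper and the attribute values are hypothetical, and only the class names `ModelHyperParams` and `TrainTaskConfig` come from the README.

```python
import ast


class ModelHyperParams:
    # Illustrative values only (assumed, not copied from config.py).
    d_model = 512
    n_head = 8


class TrainTaskConfig:
    # Illustrative values only (assumed, not copied from config.py).
    warmup_steps = 4000
    model_dir = "trained_models"


def merge_overrides(kv_list, configs):
    """Apply a flat [key, value, key, value, ...] list onto config classes.

    Hypothetical helper: each key must already exist on at least one config
    class, and its value replaces the default defined there.
    """
    if len(kv_list) % 2 != 0:
        raise ValueError("overrides must be given as key/value pairs")
    for key, raw in zip(kv_list[0::2], kv_list[1::2]):
        owners = [cfg for cfg in configs if hasattr(cfg, key)]
        if not owners:
            raise KeyError("unknown config key: " + key)
        try:
            # Turn "1024" into 1024, "0.3" into 0.3, and so on.
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            # Anything that is not a Python literal (e.g. a path) stays a string.
            value = raw
        for cfg in owners:
            setattr(cfg, key, value)


# Override two model hyperparameters, as trailing command-line pairs would.
merge_overrides(["d_model", "1024", "n_head", "16"],
                [ModelHyperParams, TrainTaskConfig])
print(ModelHyperParams.d_model, ModelHyperParams.n_head)
```

Rejecting unknown keys is a design choice in this sketch: a misspelled option fails loudly instead of being silently ignored.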
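-By default, training uses all GPUs; the `CUDA_VISIBLE_DEVICES` environment variable controls how many GPUs are used. Training on CPU only is also possible (set with `--device CPU`), although it is comparatively slow. During training, the model is saved to the directory given by `model_dir` every fixed number of iterations (set with `save_freq`, default 10000), a checkpoint is also saved to the directory given by `ckpt_dir` at the end of each epoch, and every iteration prints a log like the following to standard output:
+By default, training uses all GPUs; the `CUDA_VISIBLE_DEVICES` environment variable controls how many GPUs are used. Training on CPU only is also possible (set with `--device CPU`), although it is comparatively slow. During training, the model is saved to the directory given by `model_dir` every fixed number of iterations (set with `save_freq`, default 10000), a checkpoint is also saved to the directory given by `ckpt_dir` at the end of each epoch, and a log like the following is printed to standard output every fixed number of iterations (set with `--fetch_steps`, default 100):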
```txt
-step_idx: 0, epoch: 0, batch: 0, avg loss: 11.059394, normalized loss: 9.682427, ppl: 63538.027344
-step_idx: 1, epoch: 0, batch: 1, avg loss: 11.053112, normalized loss: 9.676146, ppl: 63140.144531
-step_idx: 2, epoch: 0, batch: 2, avg loss: 11.054576, normalized loss: 9.677609, ppl: 63232.640625
-step_idx: 3, epoch: 0, batch: 3, avg loss: 11.046638, normalized loss: 9.669671, ppl: 62732.664062
-step_idx: 4, epoch: 0, batch: 4, avg loss: 11.030095, normalized loss: 9.653129, ppl: 61703.449219
-step_idx: 5, epoch: 0, batch: 5, avg loss: 11.047491, normalized loss: 9.670525, ppl: 62786.230469
-step_idx: 6, epoch: 0, batch: 6, avg loss: 11.044509, normalized loss: 9.667542, ppl: 62599.273438
-step_idx: 7, epoch: 0, batch: 7, avg loss: 11.011090, normalized loss: 9.634124, ppl: 60541.859375
-step_idx: 8, epoch: 0, batch: 8, avg loss: 10.985243, normalized loss: 9.608276, ppl: 58997.058594
-step_idx: 9, epoch: 0, batch: 9, avg loss: 10.993434, normalized loss: 9.616467, ppl: 59482.292969
+[2018-10-26 00:49:24,705 INFO train.py:536] step_idx: 0, epoch: 0, batch: 0, avg loss: 10.999878, normalized loss: 9.624138, ppl: 59866.832031
+[2018-10-26 00:50:08,717 INFO train.py:545] step_idx: 100, epoch: 0, batch: 100, avg loss: 9.454134, normalized loss: 8.078394, ppl: 12760.809570, speed: 2.27 step/s
+[2018-10-26 00:50:52,655 INFO train.py:545] step_idx: 200, epoch: 0, batch: 200, avg loss: 8.643907, normalized loss: 7.268166, ppl: 5675.458496, speed: 2.28 step/s
+[2018-10-26 00:51:36,529 INFO train.py:545] step_idx: 300, epoch: 0, batch: 300, avg loss: 7.916654, normalized loss: 6.540914, ppl: 2742.579346, speed: 2.28 step/s
+[2018-10-26 00:52:20,692 INFO train.py:545] step_idx: 400, epoch: 0, batch: 400, avg loss: 7.902879, normalized loss: 6.527138, ppl: 2705.058350, speed: 2.26 step/s
+[2018-10-26 00:53:04,537 INFO train.py:545] step_idx: 500, epoch: 0, batch: 500, avg loss: 7.818271, normalized loss: 6.442531, ppl: 2485.604492, speed: 2.28 step/s
+[2018-10-26 00:53:48,580 INFO train.py:545] step_idx: 600, epoch: 0, batch: 600, avg loss: 7.554341, normalized loss: 6.178601, ppl: 1909.012451, speed: 2.27 step/s
+[2018-10-26 00:54:32,878 INFO train.py:545] step_idx: 700, epoch: 0, batch: 700, avg loss: 7.177765, normalized loss: 5.802025, ppl: 1309.977661, speed: 2.26 step/s
+[2018-10-26 00:55:17,108 INFO train.py:545] step_idx: 800, epoch: 0, batch: 800, avg loss: 7.005494, normalized loss: 5.629754, ppl: 1102.674805, speed: 2.26 step/s
```
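The log lines above can be parsed mechanically, for example to plot the loss curve. The sketch below is an illustration only: `LOG_PATTERN` and `parse_log_line` are hypothetical names, and the field layout is inferred from the sample lines rather than from `train.py`. In these samples the logged `ppl` is consistent with `exp(avg loss)`, which the last lines check for one record.

```python
import math
import re

# Field layout inferred from the sample log lines; the "speed" field is
# optional because the very first logged step does not report it.
LOG_PATTERN = re.compile(
    r"step_idx: (?P<step_idx>\d+), epoch: (?P<epoch>\d+), "
    r"batch: (?P<batch>\d+), avg loss: (?P<avg_loss>[\d.]+), "
    r"normalized loss: (?P<normalized_loss>[\d.]+), "
    r"ppl: (?P<ppl>[\d.]+)(?:, speed: (?P<speed>[\d.]+) step/s)?"
)


def parse_log_line(line):
    """Return the numeric fields of one training log line, or None."""
    match = LOG_PATTERN.search(line)  # search() skips any logging prefix
    if match is None:
        return None
    fields = match.groupdict()
    record = {name: int(fields[name])
              for name in ("step_idx", "epoch", "batch")}
    for name in ("avg_loss", "normalized_loss", "ppl"):
        record[name] = float(fields[name])
    record["speed"] = float(fields["speed"]) if fields["speed"] else None
    return record


sample = ("[2018-10-26 00:50:08,717 INFO train.py:545] step_idx: 100, "
          "epoch: 0, batch: 100, avg loss: 9.454134, "
          "normalized loss: 8.078394, ppl: 12760.809570, speed: 2.27 step/s")
record = parse_log_line(sample)

# Relative gap between exp(avg loss) and the logged perplexity.
relative_gap = abs(math.exp(record["avg_loss"]) - record["ppl"]) / record["ppl"]
print(record["step_idx"], record["speed"], relative_gap < 1e-3)
```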
### Model inference
......@@ -138,10 +141,9 @@ python -u infer.py \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de \
--use_wordpiece False \
--token_delimiter ' ' \
--batch_size 32 \
-model_path trained_models/iter_199999.infer.model \
+model_path trained_models/iter_100000.infer.model \
beam_size 4 \
max_out_len 255
```
......@@ -164,7 +166,7 @@ BLEU = 33.08, 64.2/39.2/26.4/18.5 (BP=0.994, ratio=0.994, hyp_len=61971, ref_len
| Test set | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|
-| BLEU | 26.05 | 28.75 | 33.27 |
+| BLEU | 26.25 | 29.15 | 33.64 |
### Distributed training
......
......@@ -124,8 +124,15 @@ def multi_head_attention(queries,
     q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
     if cache is not None:  # use cache and concat time steps
-        k = cache["k"] = layers.concat([cache["k"], k], axis=1)
-        v = cache["v"] = layers.concat([cache["v"], v], axis=1)
+        # Since the inplace reshape in __split_heads changes the shape of k and
+        # v, which is the cache input for next time step, reshape the cache
+        # input from the previous time step first.
+        k = cache["k"] = layers.concat(
+            [layers.reshape(
+                cache["k"], shape=[0, 0, d_model]), k], axis=1)
+        v = cache["v"] = layers.concat(
+            [layers.reshape(
+                cache["v"], shape=[0, 0, d_model]), v], axis=1)
     q = __split_heads(q, n_head)
     k = __split_heads(k, n_head)
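To see why the cached tensors are reshaped before the concat, it helps to track shapes only. The sketch below is a plain-Python stand-in, not Paddle code: `resolve_reshape` and `concat_shape` are hypothetical helpers that mimic shape inference, where a `0` in the target shape copies the input dimension at the same index. It assumes the in-place reshape in `__split_heads` leaves the cache with shape `[batch, time, n_head, d_key]`; the concrete sizes are illustrative.

```python
from math import prod


def resolve_reshape(in_shape, target):
    """Shape after a reshape in which 0 copies the input dim at that index."""
    out = tuple(in_shape[i] if dim == 0 else dim
                for i, dim in enumerate(target))
    if prod(out) != prod(in_shape):
        raise ValueError("cannot reshape %s to %s" % (in_shape, out))
    return out


def concat_shape(shapes, axis):
    """Shape after concatenating tensors of the given shapes along axis."""
    first = shapes[0]
    for shape in shapes[1:]:
        if len(shape) != len(first):
            raise ValueError("rank mismatch: %s vs %s" % (first, shape))
        for i, (a, b) in enumerate(zip(first, shape)):
            if i != axis and a != b:
                raise ValueError("dim %d differs: %s vs %s" % (i, first, shape))
    out = list(first)
    out[axis] = sum(shape[axis] for shape in shapes)
    return tuple(out)


batch, n_head, d_key = 2, 8, 64          # illustrative sizes
d_model = n_head * d_key

# Assumed state of the cache after the previous step's in-place split.
cache_k_shape = (batch, 3, n_head, d_key)
# Key projection for the current time step, before splitting heads.
new_k_shape = (batch, 1, d_model)

try:
    concat_shape([cache_k_shape, new_k_shape], axis=1)
    old_code_works = True
except ValueError:
    old_code_works = False               # 4-D cache cannot join a 3-D tensor

restored = resolve_reshape(cache_k_shape, [0, 0, d_model])
fixed_shape = concat_shape([restored, new_k_shape], axis=1)
print(old_code_works, restored, fixed_shape)
```

With the cache restored to three dimensions, the time axis grows by one step per call, which is what the added `layers.reshape(..., shape=[0, 0, d_model])` calls achieve in the hunk.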
......