未验证 提交 f3bd2194 编写于 作者: L liu zhengxi 提交者: GitHub

fix models/dygraph/transformer (#5135)

* fix dygraph

* update decription
上级 1dbc2855
......@@ -28,7 +28,8 @@
1. paddle安装
本项目依赖于 PaddlePaddle 1.8及以上版本或适当的develop版本,请参考 [安装指南](https://www.paddlepaddle.org.cn/install/quick) 进行安装
本项目依赖于 PaddlePaddle 2.0rc 及以上版本或适当的develop版本,请参考 [安装指南](https://www.paddlepaddle.org.cn/install/quick) 进行安装。
若使用的是 PaddlePaddle 1.8 版本,请选择 release/1.8 分支。
2. 下载代码
......
......@@ -172,19 +172,17 @@ def do_train(args):
sum_cost, avg_cost, token_num = criterion(logits, lbl_word,
lbl_weight)
if trainer_count > 1:
avg_cost = transformer.scale_loss(avg_cost)
avg_cost.backward()
transformer.apply_collective_grads()
else:
avg_cost.backward()
# NOTE: When using PaddlePaddle 2.0, it's not necessary to call
# scale_loss() and apply_collective_grads(). However, they are both
# necessary for PaddlePaddle 1.8. Please check PaddlePaddle version.
avg_cost.backward()
optimizer.minimize(avg_cost)
transformer.clear_gradients()
interval_word_num += np.prod(src_word.shape)
if step_idx % args.print_step == 0:
total_avg_cost = avg_cost.numpy() * trainer_count
total_avg_cost = avg_cost.numpy()
if step_idx == 0:
logger.info(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册