Unverified commit 4c32181d, authored by Guo Sheng, committed by GitHub

Merge pull request #931 from guoshengCS/fix-transformer-executor-usage

Make Transformer adapt to the latest api of ParallelExecutor
......@@ -291,11 +291,15 @@ def train(args):
             clip_last_batch=False)
         train_data = read_multiple(reader=train_data.batch_generator)
+        build_strategy = fluid.BuildStrategy()
+        # Since the token number differs among devices, customize the gradient
+        # scale to use the token-averaged cost across devices; the gradient
+        # scale is `1 / token_number` for the average cost.
+        build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
         train_exe = fluid.ParallelExecutor(
             use_cuda=TrainTaskConfig.use_gpu,
             loss_name=sum_cost.name,
-            use_default_grad_scale=False)
+            build_strategy=build_strategy)
         def test_context():
             # Context to do validation.
......
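The motivation for the customized gradient scale can be illustrated without Paddle: when the number of tokens differs across devices, averaging the per-device mean costs (the default per-device scaling) is not the same as the true token-level average that `1 / token_number` produces. The sketch below is plain Python with made-up per-device numbers, purely to show the arithmetic difference; the function names and values are hypothetical, not part of the PR.

```python
# Sketch: token-level averaging vs. default per-device averaging.
# `sum_costs[i]` is the summed cost on device i; `token_nums[i]` is its
# token count. The two strategies agree only when token counts are equal.

def token_average_cost(sum_costs, token_nums):
    # True average cost per token across all devices
    # (what scaling gradients by 1 / token_number achieves).
    return sum(sum_costs) / sum(token_nums)

def default_average_cost(sum_costs, token_nums):
    # Default-style strategy: mean of the per-device mean costs,
    # which weights every device equally regardless of token count.
    per_device_means = [c / n for c, n in zip(sum_costs, token_nums)]
    return sum(per_device_means) / len(per_device_means)

# Hypothetical: device 0 saw 40 tokens, device 1 saw 20.
sum_costs = [120.0, 30.0]
token_nums = [40, 20]

print(token_average_cost(sum_costs, token_nums))    # 150 / 60 = 2.5
print(default_average_cost(sum_costs, token_nums))  # (3.0 + 1.5) / 2 = 2.25
```

The 2.5 vs. 2.25 gap is exactly the bias the `Customized` gradient scale strategy removes: with uneven batches, the default scheme over-weights devices that happened to receive fewer tokens.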