Commit 056f8c48, author: guosheng

Make Transformer adapt to the latest api of ParallelExecutor

Parent 1b2641aa
@@ -291,11 +291,15 @@ def train(args):
         clip_last_batch=False)
     train_data = read_multiple(reader=train_data.batch_generator)
+    build_strategy = fluid.BuildStrategy()
+    # Since the token number differs among devices, customize gradient scale to
+    # use token average cost among multi-devices. and the gradient scale is
+    # `1 / token_number` for average cost.
+    build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
     train_exe = fluid.ParallelExecutor(
         use_cuda=TrainTaskConfig.use_gpu,
         loss_name=sum_cost.name,
-        use_default_grad_scale=False)
+        build_strategy=build_strategy)
     def test_context():
         # Context to do validation.
......
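
The comment added in this hunk says that a gradient scale of `1 / token_number` yields the token-averaged cost when devices hold different numbers of tokens. The arithmetic behind that can be sketched in plain Python without any Paddle calls. Everything here is a hypothetical toy: the scalar weight `w`, the per-token loss `0.5 * (w * x - y) ** 2`, the helper names, and the assumption that `token_number` means the total token count across all devices.

```python
# Toy illustration only: no Paddle API is used, and all names are hypothetical.
# Per-token loss is 0.5 * (w * x - y) ** 2 with a single scalar weight w.


def sum_cost_grad(w, tokens):
    """Gradient w.r.t. w of the summed per-token cost on one device."""
    return sum((w * x - y) * x for x, y in tokens)


def combined_grad(w, device_batches, scales):
    """Add up per-device gradients, each multiplied by its loss-gradient scale."""
    return sum(s * sum_cost_grad(w, b) for s, b in zip(scales, device_batches))


w = 0.5
device_batches = [
    [(1.0, 2.0), (2.0, 1.0), (3.0, 0.0)],  # device 0 holds 3 tokens
    [(4.0, 1.0)],                          # device 1 holds 1 token
]
num_devices = len(device_batches)
# Assumption: token_number is the total token count over all devices.
total_tokens = sum(len(b) for b in device_batches)

# Customized scale: every device seeds its sum-cost gradient with 1 / token_number.
token_avg = combined_grad(w, device_batches, [1.0 / total_tokens] * num_devices)

# Reference: gradient of the cost averaged over all tokens pooled on one device.
pooled = [t for b in device_batches for t in b]
reference = sum_cost_grad(w, pooled) / len(pooled)

# Contrast: average each device's mean cost with equal weight per device.
per_device_mean = combined_grad(
    w, device_batches, [1.0 / (num_devices * len(b)) for b in device_batches]
)

print(token_avg, reference, per_device_mean)
```

In this toy case the customized scale reproduces the pooled token average, while the equal-weight average of per-device means gives more weight to the single token on device 1 and produces a different gradient.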