未验证 提交 fea4ea64 编写于 作者: W WangXi 提交者: GitHub

NLP BERT benchmark: add gradient merge (#5139)

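NLP BERT benchmark: add gradient merge. Adds a `--gradient_merge_steps` argument (int, default 1); when it is greater than 1, `dist_optimizer` enables `gradient_merge` on the distributed strategy with `k_steps` set to that value.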
上级 241712a4
......@@ -141,6 +141,13 @@ def parse_args():
type=str,
default="gpu",
help="Device for selecting for the training.")
parser.add_argument(
"--gradient_merge_steps",
type=int,
default=1,
help="Number of merge steps before gradient update."
"global_batch_size = gradient_merge_steps * batch_size."
)
args = parser.parse_args()
return args
......@@ -224,6 +231,11 @@ def dist_optimizer(args, optimizer):
'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
'init_loss_scaling': args.scale_loss,
}
if args.gradient_merge_steps > 1:
dist_strategy.gradient_merge = True
dist_strategy.gradient_merge_configs = {
'k_steps': args.gradient_merge_steps
}
optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
return optimizer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册