Unverified · Commit 4dba6afa authored by WangXi, committed by GitHub

Fleet support pure_fp16 in bert benchmark (#5255)

* bert benchmark fleet support pure_fp16

* fix use_pure_fp16
Parent 53208f52
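The change, in brief: with use_pure_fp16 the benchmark runs forward/backward with fp16 parameters, rather than fp32 weights plus selective fp16 ops. The flag is threaded through three places in the hunks below: fleet's amp_configs, the single-worker AMP path (which also migrates from the old paddle.fluid.contrib.mixed_precision alias to paddle.static.amp), and a new optimizer.amp_init(place) call that casts the restored fp32 parameters to fp16 before training starts. AdamW additionally gets multi_precision=args.use_pure_fp16 so it keeps fp32 master weights for the fp16 parameters.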
@@ -231,9 +231,13 @@ def dist_optimizer(args, optimizer):
         dist_strategy.fuse_grad_size_in_MB = 16
     if args.use_amp:
         dist_strategy.amp = True
+        custom_black_list = ['lookup_table', 'lookup_table_v2'] if args.use_pure_fp16 else None
         dist_strategy.amp_configs = {
             'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
             'init_loss_scaling': args.scale_loss,
+            'custom_black_list': custom_black_list,
+            'use_pure_fp16': args.use_pure_fp16
         }
     if args.gradient_merge_steps > 1:
         dist_strategy.gradient_merge = True
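For context, a minimal sketch of how these amp_configs reach the program: fleet's distributed optimizer applies the AMP pass when minimize runs. The names use_pure_fp16 and the literal 32768.0 are illustrative stand-ins for the script's argparse flags, not the benchmark code, and a collective run would normally be started with paddle.distributed.launch.

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)  # falls back to a single trainer when not launched distributed

    use_pure_fp16 = True  # stands in for args.use_pure_fp16

    strategy = fleet.DistributedStrategy()
    strategy.amp = True
    strategy.amp_configs = {
        'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
        'init_loss_scaling': 32768.0,  # stands in for args.scale_loss
        # keep embedding lookups in fp32 even when the rest runs pure fp16
        'custom_black_list': (['lookup_table', 'lookup_table_v2']
                              if use_pure_fp16 else None),
        'use_pure_fp16': use_pure_fp16,
    }

    optimizer = paddle.optimizer.AdamW(learning_rate=1e-4,
                                       multi_precision=use_pure_fp16)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    # optimizer.minimize(loss) would then rewrite the program for AMP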
@@ -320,15 +324,20 @@ def do_train(args):
         apply_decay_param_fun=lambda x: x in [
             p.name for n, p in model.named_parameters()
             if not any(nd in n for nd in ["bias", "norm"])
-        ])
+        ],
+        multi_precision=args.use_pure_fp16)
     if worker_num == 1 and args.use_amp:
-        amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
-            custom_white_list=['softmax', 'layer_norm', 'gelu'])
-        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
+        custom_black_list = (['lookup_table', 'lookup_table_v2']
+                             if args.use_pure_fp16 else None)
+        amp_list = paddle.static.amp.AutoMixedPrecisionLists(
+            custom_white_list=['softmax', 'layer_norm', 'gelu'],
+            custom_black_list=custom_black_list)
+        optimizer = paddle.static.amp.decorate(
             optimizer,
             amp_list,
             init_loss_scaling=args.scale_loss,
-            use_dynamic_loss_scaling=True)
+            use_dynamic_loss_scaling=True,
+            use_pure_fp16=args.use_pure_fp16)
     if worker_num > 1:
         # Use the fleet api to compile the distributed optimizer
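A self-contained sketch of the single-worker path above, with a toy fc network standing in for BERT (use_pure_fp16 and scale_loss are hypothetical stand-ins for the argparse flags; the white/black lists are the ones from the diff):

    import paddle

    paddle.enable_static()

    use_pure_fp16 = True   # stands in for args.use_pure_fp16
    scale_loss = 32768.0   # stands in for args.scale_loss

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
        loss = paddle.mean(paddle.static.nn.fc(x, size=1))

        optimizer = paddle.optimizer.AdamW(
            learning_rate=1e-4,
            multi_precision=use_pure_fp16)  # fp32 master weights for fp16 params

        # ops in the black list stay in fp32 even under pure fp16
        custom_black_list = (['lookup_table', 'lookup_table_v2']
                             if use_pure_fp16 else None)
        amp_list = paddle.static.amp.AutoMixedPrecisionLists(
            custom_white_list=['softmax', 'layer_norm', 'gelu'],
            custom_black_list=custom_black_list)
        optimizer = paddle.static.amp.decorate(
            optimizer,
            amp_list,
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=True,
            use_pure_fp16=use_pure_fp16)
        # with use_pure_fp16=True, minimize also casts the model to fp16
        optimizer.minimize(loss)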
@@ -343,6 +352,8 @@ def do_train(args):
     # Use the state dict to update the parameter
     reset_state_dict = reset_program_state_dict(model, state_dict)
     paddle.static.set_program_state(main_program, reset_state_dict)
+    if args.use_amp:
+        optimizer.amp_init(place)
     if worker_num == 1:
         # Construct the compiled program
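The ordering this last hunk establishes is the point: parameters are initialized and restored in fp32 first, and amp_init then casts them to fp16 in place (the multi_precision optimizer keeps fp32 master copies). A compact sketch of that sequence, assuming a CUDA device since pure fp16 targets GPU; the default AMP op lists are used here for brevity:

    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
        loss = paddle.mean(paddle.static.nn.fc(x, size=1))
        optimizer = paddle.optimizer.AdamW(learning_rate=1e-4,
                                           multi_precision=True)
        optimizer = paddle.static.amp.decorate(
            optimizer, init_loss_scaling=32768.0, use_pure_fp16=True)
        optimizer.minimize(loss)

    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)  # parameters are created and initialized in fp32
    # (the benchmark restores pretrained fp32 weights with set_program_state here)
    optimizer.amp_init(place)  # then casts them to fp16 for pure-fp16 execution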