Unverified commit 4dba6afa, authored by WangXi, committed by GitHub

Fleet support pure_fp16 in bert benchmark (#5255)

* bert benchmark fleet support pure_fp16

* fix use_pure_fp16
Parent 53208f52
@@ -231,9 +231,13 @@ def dist_optimizer(args, optimizer):
     dist_strategy.fuse_grad_size_in_MB = 16
     if args.use_amp:
         dist_strategy.amp = True
+        custom_black_list = ['lookup_table',
+                             'lookup_table_v2'] if args.use_pure_fp16 else None
         dist_strategy.amp_configs = {
             'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
             'init_loss_scaling': args.scale_loss,
+            'custom_black_list': custom_black_list,
+            'use_pure_fp16': args.use_pure_fp16
         }
     if args.gradient_merge_steps > 1:
         dist_strategy.gradient_merge = True
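
For readers unfamiliar with the fleet API, the sketch below shows how the pure-fp16 options added above sit on a DistributedStrategy in isolation. It is a minimal illustration, not part of the patch: the loss-scaling value is a placeholder for args.scale_loss, and the decorated optimizer would still come from fleet.distributed_optimizer as in the rest of the script.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

# Same AMP options as the hunk above, hard-coded instead of read from args.
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = True
dist_strategy.amp_configs = {
    'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
    # keep embedding lookups in fp32 when running in pure fp16 mode
    'custom_black_list': ['lookup_table', 'lookup_table_v2'],
    'init_loss_scaling': 32768,  # placeholder for args.scale_loss
    'use_pure_fp16': True,
}
# optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
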
@@ -320,15 +324,20 @@ def do_train(args):
         apply_decay_param_fun=lambda x: x in [
             p.name for n, p in model.named_parameters()
             if not any(nd in n for nd in ["bias", "norm"])
-        ])
+        ],
+        multi_precision=args.use_pure_fp16)
     if worker_num == 1 and args.use_amp:
-        amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
-            custom_white_list=['softmax', 'layer_norm', 'gelu'])
-        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
+        custom_black_list=(['lookup_table', 'lookup_table_v2']
+                           if args.use_pure_fp16 else None)
+        amp_list = paddle.static.amp.AutoMixedPrecisionLists(
+            custom_white_list=['softmax', 'layer_norm', 'gelu'],
+            custom_black_list=custom_black_list)
+        optimizer = paddle.static.amp.decorate(
             optimizer,
             amp_list,
             init_loss_scaling=args.scale_loss,
-            use_dynamic_loss_scaling=True)
+            use_dynamic_loss_scaling=True,
+            use_pure_fp16=args.use_pure_fp16)

     if worker_num > 1:
         # Use the fleet api to compile the distributed optimizer
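
The single-worker branch above is easier to read outside the diff. Below is a rough, self-contained sketch of the same paddle.static.amp flow with use_pure_fp16; the tiny fc network, the AdamW settings and the loss-scaling value are stand-ins for the BERT program and the args used in this script.

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))

    # multi_precision keeps fp32 master weights alongside the fp16 params
    optimizer = paddle.optimizer.AdamW(learning_rate=1e-4,
                                       multi_precision=True)

    amp_list = paddle.static.amp.AutoMixedPrecisionLists(
        custom_white_list=['softmax', 'layer_norm', 'gelu'],
        custom_black_list=['lookup_table', 'lookup_table_v2'])
    optimizer = paddle.static.amp.decorate(
        optimizer,
        amp_list,
        init_loss_scaling=32768,
        use_dynamic_loss_scaling=True,
        use_pure_fp16=True)
    optimizer.minimize(loss)

With use_pure_fp16=True the forward/backward program is cast to fp16 apart from the black-listed ops, which is why an extra amp_init call is needed after parameter initialization (next hunk).
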
@@ -343,6 +352,8 @@ def do_train(args):
     # Use the state dict to update the parameter
     reset_state_dict = reset_program_state_dict(model, state_dict)
     paddle.static.set_program_state(main_program, reset_state_dict)
+    if args.use_amp:
+        optimizer.amp_init(place)

     if worker_num == 1:
         # Construct the compiled program
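
Finally, the amp_init call added in the last hunk: under pure fp16 the startup program still creates fp32 parameters, and amp_init casts them on the device once initialization and the state-dict restore are done. Continuing the sketch above (same startup_prog and decorated optimizer; a CUDA device is assumed):

place = paddle.CUDAPlace(0)  # pure fp16 assumes a GPU
exe = paddle.static.Executor(place)
exe.run(startup_prog)

# In the benchmark script this runs right after set_program_state has
# written the (re-initialized) state dict back into main_program.
optimizer.amp_init(place)
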