diff --git a/PaddleNLP/benchmark/bert/run_pretrain.py b/PaddleNLP/benchmark/bert/run_pretrain.py index efbec1814d4d42c0b0ce4d0858a2c8ec888007b2..b5adea2a04ea7d7ad2d0f47fb15a6cea498c5e61 100644 --- a/PaddleNLP/benchmark/bert/run_pretrain.py +++ b/PaddleNLP/benchmark/bert/run_pretrain.py @@ -231,9 +231,13 @@ def dist_optimizer(args, optimizer): dist_strategy.fuse_grad_size_in_MB = 16 if args.use_amp: dist_strategy.amp = True + + custom_black_list = ['lookup_table', 'lookup_table_v2'] if args.use_pure_fp16 else None dist_strategy.amp_configs = { 'custom_white_list': ['softmax', 'layer_norm', 'gelu'], 'init_loss_scaling': args.scale_loss, + 'custom_black_list': custom_black_list, + 'use_pure_fp16': args.use_pure_fp16 } if args.gradient_merge_steps > 1: dist_strategy.gradient_merge = True @@ -320,15 +324,20 @@ def do_train(args): apply_decay_param_fun=lambda x: x in [ p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) - ]) + ], + multi_precision=args.use_pure_fp16) if worker_num == 1 and args.use_amp: - amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists( - custom_white_list=['softmax', 'layer_norm', 'gelu']) - optimizer = paddle.fluid.contrib.mixed_precision.decorate( + custom_black_list=(['lookup_table', 'lookup_table_v2'] + if args.use_pure_fp16 else None) + amp_list = paddle.static.amp.AutoMixedPrecisionLists( + custom_white_list=['softmax', 'layer_norm', 'gelu'], + custom_black_list=custom_black_list) + optimizer = paddle.static.amp.decorate( optimizer, amp_list, init_loss_scaling=args.scale_loss, - use_dynamic_loss_scaling=True) + use_dynamic_loss_scaling=True, + use_pure_fp16=args.use_pure_fp16) if worker_num > 1: # Use the fleet api to compile the distributed optimizer @@ -343,6 +352,8 @@ def do_train(args): # Use the state dict to update the parameter reset_state_dict = reset_program_state_dict(model, state_dict) paddle.static.set_program_state(main_program, reset_state_dict) + if args.use_amp: + optimizer.amp_init(place) if worker_num == 1: # Construct the compiled program