From 25742f1d5cce7c06f96b364a74e641b7927c49ff Mon Sep 17 00:00:00 2001
From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
Date: Thu, 4 Feb 2021 16:19:44 +0800
Subject: [PATCH] add fleet pure_fp16 into train.py (#5268)

---
 PaddleNLP/benchmark/transformer/static/train.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/PaddleNLP/benchmark/transformer/static/train.py b/PaddleNLP/benchmark/transformer/static/train.py
index b99c056d..38818666 100644
--- a/PaddleNLP/benchmark/transformer/static/train.py
+++ b/PaddleNLP/benchmark/transformer/static/train.py
@@ -39,10 +39,9 @@ def do_train(args):
     paddle.enable_static()
     if args.is_distributed:
         fleet.init(is_collective=True)
-        gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
-        places = paddle.CUDAPlace(
-            gpu_id) if args.use_gpu else paddle.static.cpu_places()
-        trainer_count = 1 if args.use_gpu else len(places)
+        places = [paddle.set_device("gpu")] if \
+            args.use_gpu else paddle.static.cpu_places()
+        trainer_count = len(places)
     else:
         if args.use_gpu:
             places = paddle.static.cuda_places()
@@ -111,8 +110,10 @@ def do_train(args):
         if args.use_amp:
             dist_strategy.amp = True
             dist_strategy.amp_configs = {
-                'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
+                'custom_white_list': ['softmax', 'layer_norm'],
                 'init_loss_scaling': args.scale_loss,
+                'custom_black_list': ['lookup_table_v2'],
+                'use_pure_fp16': args.use_pure_fp16
             }
 
         optimizer = fleet.distributed_optimizer(
@@ -131,7 +132,7 @@ def do_train(args):
         optimizer.minimize(avg_cost)
 
     if args.is_distributed:
-        exe = paddle.static.Executor(places)
+        exe = paddle.static.Executor(places[0])
     else:
         exe = paddle.static.Executor()
         build_strategy = paddle.static.BuildStrategy()
@@ -144,9 +145,9 @@ def do_train(args):
             exec_strategy=exec_strategy)
 
     exe.run(startup_program)
-    if not args.is_distributed and args.use_amp:
+    if args.use_amp:
         optimizer.amp_init(places[0])
-
+    
     # the best cross-entropy value with label smoothing
     loss_normalizer = -(
         (1. - args.label_smooth_eps) * np.log(
-- 
GitLab
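
For context, the following is a minimal, self-contained sketch (not part of the patch) of the fleet pure-fp16 AMP setup this commit wires into train.py. The amp_configs keys mirror the hunks above; build_model and the args fields used here are hypothetical placeholders standing in for the benchmark's own model-building code and CLI options.

import paddle
import paddle.distributed.fleet as fleet

def train_pure_fp16(args):
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Fleet strategy with AMP enabled; the keys mirror the patch above.
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.amp = True
    dist_strategy.amp_configs = {
        'custom_white_list': ['softmax', 'layer_norm'],  # force fp16 for these ops
        'init_loss_scaling': args.scale_loss,
        'custom_black_list': ['lookup_table_v2'],  # keep embedding lookups in fp32
        'use_pure_fp16': args.use_pure_fp16,  # run the whole graph in fp16, not just listed ops
    }

    avg_cost = build_model(args)  # hypothetical: builds the graph, returns the loss
    optimizer = paddle.optimizer.Adam(learning_rate=args.learning_rate)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
    optimizer.minimize(avg_cost)

    # One device place per trainer, as in the patched code.
    place = paddle.set_device("gpu")
    exe = paddle.static.Executor(place)
    exe.run(paddle.static.default_startup_program())

    # Pure fp16 requires casting the initialized fp32 parameters to fp16
    # (keeping fp32 master weights) after the startup program has run;
    # this is why the patch now calls amp_init on the distributed path too.
    if args.use_amp:
        optimizer.amp_init(place)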