未验证 提交 25742f1d 编写于 作者: W wuhuachaocoding 提交者: GitHub

add fleet pure_fp16 into train.py (#5268)

上级 cdb8e50a
...@@ -39,10 +39,9 @@ def do_train(args): ...@@ -39,10 +39,9 @@ def do_train(args):
paddle.enable_static() paddle.enable_static()
if args.is_distributed: if args.is_distributed:
fleet.init(is_collective=True) fleet.init(is_collective=True)
gpu_id = int(os.getenv("FLAGS_selected_gpus", "0")) places = [paddle.set_device("gpu")] if \
places = paddle.CUDAPlace( args.use_gpu else paddle.static.cpu_places()
gpu_id) if args.use_gpu else paddle.static.cpu_places() trainer_count = len(places)
trainer_count = 1 if args.use_gpu else len(places)
else: else:
if args.use_gpu: if args.use_gpu:
places = paddle.static.cuda_places() places = paddle.static.cuda_places()
...@@ -111,8 +110,10 @@ def do_train(args): ...@@ -111,8 +110,10 @@ def do_train(args):
if args.use_amp: if args.use_amp:
dist_strategy.amp = True dist_strategy.amp = True
dist_strategy.amp_configs = { dist_strategy.amp_configs = {
'custom_white_list': ['softmax', 'layer_norm', 'gelu'], 'custom_white_list': ['softmax', 'layer_norm'],
'init_loss_scaling': args.scale_loss, 'init_loss_scaling': args.scale_loss,
'custom_black_list': ['lookup_table_v2'],
'use_pure_fp16': args.use_pure_fp16
} }
optimizer = fleet.distributed_optimizer( optimizer = fleet.distributed_optimizer(
...@@ -131,7 +132,7 @@ def do_train(args): ...@@ -131,7 +132,7 @@ def do_train(args):
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
if args.is_distributed: if args.is_distributed:
exe = paddle.static.Executor(places) exe = paddle.static.Executor(places[0])
else: else:
exe = paddle.static.Executor() exe = paddle.static.Executor()
build_strategy = paddle.static.BuildStrategy() build_strategy = paddle.static.BuildStrategy()
...@@ -144,7 +145,7 @@ def do_train(args): ...@@ -144,7 +145,7 @@ def do_train(args):
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
exe.run(startup_program) exe.run(startup_program)
if not args.is_distributed and args.use_amp: if args.use_amp:
optimizer.amp_init(places[0]) optimizer.amp_init(places[0])
# the best cross-entropy value with label smoothing # the best cross-entropy value with label smoothing
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册