Unverified commit 25742f1d, authored by wuhuachaocoding, committed by GitHub

add fleet pure_fp16 into train.py (#5268)

Parent: cdb8e50a
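Note: this commit lets the fleet AMP pass run in pure-fp16 mode. For orientation, a minimal sketch of the configuration it enables (the amp_configs keys and the use_pure_fp16 flag come from the diff below; the surrounding program and optimizer setup are illustrative, not the exact train.py code):

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.amp = True
    dist_strategy.amp_configs = {
        'custom_white_list': ['softmax', 'layer_norm'],
        'init_loss_scaling': 128.0,                # illustrative value
        'custom_black_list': ['lookup_table_v2'],
        'use_pure_fp16': True,                     # cast the model to fp16, not mixed precision
    }
    optimizer = fleet.distributed_optimizer(
        paddle.optimizer.Adam(learning_rate=1e-3), strategy=dist_strategy)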
@@ -39,10 +39,9 @@ def do_train(args):
     paddle.enable_static()
     if args.is_distributed:
         fleet.init(is_collective=True)
-        gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
-        places = paddle.CUDAPlace(
-            gpu_id) if args.use_gpu else paddle.static.cpu_places()
-        trainer_count = 1 if args.use_gpu else len(places)
+        places = [paddle.set_device("gpu")] if \
+            args.use_gpu else paddle.static.cpu_places()
+        trainer_count = len(places)
     else:
         if args.use_gpu:
             places = paddle.static.cuda_places()
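Note on the hunk above: paddle.set_device("gpu") selects the current process's GPU (the fleet launcher exports the device id into the environment) and returns the corresponding place, so the manual FLAGS_selected_gpus lookup and explicit paddle.CUDAPlace construction become redundant. A minimal sketch, assuming a fleet-launched worker:

    import paddle

    place = paddle.set_device("gpu")  # picks this worker's GPU, returns its Place
    print(place)                      # e.g. Place(gpu:0)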
@@ -111,8 +110,10 @@ def do_train(args):
         if args.use_amp:
             dist_strategy.amp = True
             dist_strategy.amp_configs = {
-                'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
+                'custom_white_list': ['softmax', 'layer_norm'],
                 'init_loss_scaling': args.scale_loss,
+                'custom_black_list': ['lookup_table_v2'],
+                'use_pure_fp16': args.use_pure_fp16
             }

         optimizer = fleet.distributed_optimizer(
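Two notes on this hunk. First, blacklisting lookup_table_v2 (the embedding lookup op) keeps the embedding computation in fp32 even under pure fp16, a common precision safeguard. Second, use_pure_fp16 is fed from args, so the script presumably grows a matching command-line flag in the part of the diff not shown here; a hedged sketch of what such a flag could look like (the exact name, type, and default in train.py are assumptions):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_pure_fp16",
        default=False,
        type=lambda s: str(s).lower() in ("true", "1"),
        help="Run the AMP pass in pure fp16 rather than mixed precision.")
    args = parser.parse_args()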
@@ -131,7 +132,7 @@ def do_train(args):
     optimizer.minimize(avg_cost)

     if args.is_distributed:
-        exe = paddle.static.Executor(places)
+        exe = paddle.static.Executor(places[0])
     else:
         exe = paddle.static.Executor()
     build_strategy = paddle.static.BuildStrategy()
@@ -144,7 +145,7 @@ def do_train(args):
         exec_strategy=exec_strategy)
     exe.run(startup_program)

-    if not args.is_distributed and args.use_amp:
+    if args.use_amp:
         optimizer.amp_init(places[0])

     # the best cross-entropy value with label smoothing
......
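The last two hunks matter for pure fp16. paddle.static.Executor takes a single place, hence places[0], and optimizer.amp_init(place) casts the fp32 parameters produced by the startup program to fp16 where required; dropping the "not args.is_distributed" condition means this cast now also runs on the distributed (fleet) path. The required ordering, sketched with the diff's own variable names:

    exe = paddle.static.Executor(places[0])   # one Place, not a list
    exe.run(startup_program)                  # parameters initialized in fp32
    if args.use_amp:
        optimizer.amp_init(places[0])         # cast params; needed when use_pure_fp16 is on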