From e7b160113d9c9337c608cee8cf44289f3bc3db5d Mon Sep 17 00:00:00 2001
From: WenmuZhou <572459439@qq.com>
Date: Tue, 2 Aug 2022 20:02:25 +0800
Subject: [PATCH] fix fp16 bug

---
 tools/train.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tools/train.py b/tools/train.py
index 309d4bb9..dc8cae8a 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -119,9 +119,6 @@ def main(config, device, logger, vdl_writer):
         config['Loss']['ignore_index'] = char_num - 1
 
     model = build_model(config['Architecture'])
-    if config['Global']['distributed']:
-        model = paddle.DataParallel(model)
-
     model = apply_to_static(model, config, logger)
 
     # build loss
@@ -157,10 +154,13 @@ def main(config, device, logger, vdl_writer):
         scaler = paddle.amp.GradScaler(
             init_loss_scaling=scale_loss,
             use_dynamic_loss_scaling=use_dynamic_loss_scaling)
-        model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True)
+        model, optimizer = paddle.amp.decorate(
+            models=model, optimizers=optimizer, level='O2', master_weight=True)
     else:
         scaler = None
 
+    if config['Global']['distributed']:
+        model = paddle.DataParallel(model)
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
-- 
GitLab
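
Note on the fix: with AMP level O2, paddle.amp.decorate casts the model's parameters to fp16 and attaches fp32 master weights to the optimizer, so it should receive the raw Layer; the likely bug here is that the model was wrapped in paddle.DataParallel before decoration, so decorate saw the wrapper instead of the underlying Layer. The patch moves the DataParallel wrapping to after the AMP decoration. Below is a minimal sketch of the resulting order; SimpleNet, the optimizer choice, and the hard-coded values are illustrative placeholders standing in for PaddleOCR's config, not code from tools/train.py.

# Minimal sketch of the ordering the patch enforces; SimpleNet and all
# values are illustrative placeholders, not code from tools/train.py.
import paddle


class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(16, 2)

    def forward(self, x):
        return self.fc(x)


model = SimpleNet()
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024,
                               use_dynamic_loss_scaling=True)

# 1) Decorate the raw Layer first: O2 casts its parameters to fp16 and
#    keeps fp32 master weights for the optimizer.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2', master_weight=True)

# 2) Only afterwards wrap for multi-GPU training, as the patch does.
if paddle.distributed.get_world_size() > 1:
    paddle.distributed.init_parallel_env()
    model = paddle.DataParallel(model)

# One fp16 training step with loss scaling.
with paddle.amp.auto_cast(level='O2'):
    loss = model(paddle.randn([4, 16])).mean()
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()

The single-process case runs as-is; launching with python -m paddle.distributed.launch exercises the DataParallel branch.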