From 1b93e92ccbd2db562e2cbac16edbc821d92b19a8 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Fri, 19 Aug 2022 13:26:02 +0800 Subject: [PATCH] fix amp bug --- tools/program.py | 18 ++++++++++-------- tools/train.py | 8 +++++--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tools/program.py b/tools/program.py index b450bc5a..3043cd77 100755 --- a/tools/program.py +++ b/tools/program.py @@ -162,18 +162,18 @@ def to_float32(preds): for k in preds: if isinstance(preds[k], dict) or isinstance(preds[k], list): preds[k] = to_float32(preds[k]) - else: - preds[k] = paddle.to_tensor(preds[k], dtype='float32') + elif isinstance(preds[k], paddle.Tensor): + preds[k] = preds[k].astype(paddle.float32) elif isinstance(preds, list): for k in range(len(preds)): if isinstance(preds[k], dict): preds[k] = to_float32(preds[k]) elif isinstance(preds[k], list): preds[k] = to_float32(preds[k]) - else: - preds[k] = paddle.to_tensor(preds[k], dtype='float32') - else: - preds = paddle.to_tensor(preds, dtype='float32') + elif isinstance(preds[k], paddle.Tensor): + preds[k] = preds[k].astype(paddle.float32) + elif isinstance(preds, paddle.Tensor): + preds = preds.astype(paddle.float32) return preds @@ -190,7 +190,8 @@ def train(config, pre_best_model_dict, logger, log_writer=None, - scaler=None): + scaler=None, + amp_level='O2'): cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1) @@ -276,7 +277,7 @@ def train(config, model_average = True # use amp if scaler: - with paddle.amp.auto_cast(level='O2'): + with paddle.amp.auto_cast(level=amp_level): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie", 'vqa']: @@ -514,6 +515,7 @@ def eval(model, sum_images, i), fm_lr) else: preds = model(images) + preds = to_float32(preds) else: if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) diff --git 
a/tools/train.py b/tools/train.py index 0c881eca..5f310938 100755 --- a/tools/train.py +++ b/tools/train.py @@ -147,6 +147,7 @@ def main(config, device, logger, vdl_writer): len(valid_dataloader))) use_amp = config["Global"].get("use_amp", False) + amp_level = config["Global"].get("amp_level", 'O2') if use_amp: AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, @@ -159,8 +160,9 @@ def main(config, device, logger, vdl_writer): scaler = paddle.amp.GradScaler( init_loss_scaling=scale_loss, use_dynamic_loss_scaling=use_dynamic_loss_scaling) - model, optimizer = paddle.amp.decorate( - models=model, optimizers=optimizer, level='O2', master_weight=True) + if amp_level == "O2": + model, optimizer = paddle.amp.decorate( + models=model, optimizers=optimizer, level=amp_level, master_weight=True) else: scaler = None @@ -169,7 +171,7 @@ def main(config, device, logger, vdl_writer): # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer, scaler, amp_level) def test_reader(config, device, logger): -- GitLab