From b761325faaefcffb96d6f99bcf8b13005e98d19b Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Sat, 2 Apr 2022 07:35:01 +0000 Subject: [PATCH] fix: fp32 eval by default when enable amp If you want to eval by fp16 when enable amp, please set Amp.use_fp16_test=True, False by default. --- ppcls/engine/evaluation/classification.py | 2 +- ppcls/static/train.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 994eeb5e..e9836fcb 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -58,7 +58,7 @@ def classification_eval(engine, epoch_id=0): batch[1] = batch[1].reshape([-1, 1]).astype("int64") # image input - if engine.amp: + if engine.amp and engine.config["AMP"].get("use_fp16_test", False): amp_level = engine.config['AMP'].get("level", "O1").upper() with paddle.amp.auto_cast( custom_black_list={ diff --git a/ppcls/static/train.py b/ppcls/static/train.py index dd16cdb4..45c93762 100644 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -161,12 +161,13 @@ def main(args): # load pretrained models or checkpoints init_model(global_config, train_prog, exe) - if 'AMP' in config and config.AMP.get("level", "O1") == "O2": + if 'AMP' in config: optimizer.amp_init( device, scope=paddle.static.global_scope(), test_program=eval_prog - if global_config["eval_during_train"] else None) + if global_config["eval_during_train"] else None, + use_fp16_test=config["AMP"].get("use_fp16_test", False)) if not global_config.get("is_distributed", True): compiled_train_prog = program.compile( @@ -182,7 +183,7 @@ def main(args): program.run(train_dataloader, exe, compiled_train_prog, train_feeds, train_fetchs, epoch_id, 'train', config, vdl_writer, lr_scheduler, args.profiler_options) - # 2. evaate with eval dataset + # 2. evaluate with eval dataset if global_config["eval_during_train"] and epoch_id % global_config[ "eval_interval"] == 0: top1_acc = program.run(eval_dataloader, exe, compiled_eval_prog, -- GitLab