From bdfa1feb2f310a36006e5cdf542cc330a477fffa Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Mon, 29 May 2023 09:48:36 +0000
Subject: [PATCH] update for amp config refactoring

---
 ppcls/static/train.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/ppcls/static/train.py b/ppcls/static/train.py
index e86e2c8d..98d51de5 100755
--- a/ppcls/static/train.py
+++ b/ppcls/static/train.py
@@ -95,7 +95,9 @@ def main(args):
     device = paddle.set_device(global_config["device"])
 
     # amp related config
-    if 'AMP' in config:
+    amp_config = config.get("AMP", None)
+    use_amp = True if amp_config and amp_config.get("use_amp", True) else False
+    if use_amp:
         AMP_RELATED_FLAGS_SETTING = {
             'FLAGS_cudnn_exhaustive_search': 1,
             'FLAGS_conv_workspace_size_limit': 1500,
@@ -159,15 +161,15 @@ def main(args):
     # load pretrained models or checkpoints
     init_model(global_config, train_prog, exe)
 
-    if 'AMP' in config:
+    if use_amp:
+        # for AMP O2
         if config["AMP"].get("level", "O1").upper() == "O2":
             use_fp16_test = True
             msg = "Only support FP16 evaluation when AMP O2 is enabled."
             logger.warning(msg)
-        elif "use_fp16_test" in config["AMP"]:
-            use_fp16_test = config["AMP"].get["use_fp16_test"]
+        # for AMP O1
         else:
-            use_fp16_test = False
+            use_fp16_test = config["AMP"].get("use_fp16_test", False)
 
         optimizer.amp_init(
             device,
--
GitLab
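
Note (not part of the patch): the sketch below is a minimal standalone illustration of the config-resolution behavior this diff introduces. AMP is treated as enabled when an `AMP` section is present and its `use_amp` key is absent or true; FP16 evaluation is then forced on under AMP O2 but opt-in under O1. The `resolve_amp` helper and the sample config dicts are hypothetical names invented here for illustration, not ppcls API.

# Hypothetical helper mirroring the resolution logic added by this patch;
# `resolve_amp` and the sample dicts below are illustrative only.

def resolve_amp(config):
    # AMP is enabled when a non-empty AMP section exists and its
    # `use_amp` key is absent or truthy (the patch defaults it to True).
    amp_config = config.get("AMP", None)
    use_amp = bool(amp_config and amp_config.get("use_amp", True))

    use_fp16_test = False
    if use_amp:
        if config["AMP"].get("level", "O1").upper() == "O2":
            # AMP O2: FP16 evaluation is always on (the patch logs a warning).
            use_fp16_test = True
        else:
            # AMP O1 (the default level): FP16 evaluation is opt-in,
            # defaulting to False.
            use_fp16_test = config["AMP"].get("use_fp16_test", False)
    return use_amp, use_fp16_test


if __name__ == "__main__":
    print(resolve_amp({}))                                # (False, False)
    print(resolve_amp({"AMP": {"level": "O1"}}))          # (True, False)
    print(resolve_amp({"AMP": {"use_amp": False}}))       # (False, False)
    print(resolve_amp({"AMP": {"level": "O2"}}))          # (True, True)
    print(resolve_amp({"AMP": {"use_fp16_test": True}}))  # (True, True)

One consequence of the `bool(amp_config and ...)` form worth noting: an AMP section that exists but is empty evaluates as falsy, so it disables AMP just like an absent section.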