diff --git a/ppcls/static/train.py b/ppcls/static/train.py index e86e2c8d75e3b30c25fb914c95680eebca4d2f50..98d51de5c63a2d2fe917087090210ebd7f9a08b5 100755 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -95,7 +95,9 @@ def main(args): device = paddle.set_device(global_config["device"]) # amp related config - if 'AMP' in config: + amp_config = config.get("AMP", None) + use_amp = True if amp_config and amp_config.get("use_amp", True) else False + if use_amp: AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_exhaustive_search': 1, 'FLAGS_conv_workspace_size_limit': 1500, @@ -159,15 +161,15 @@ def main(args): # load pretrained models or checkpoints init_model(global_config, train_prog, exe) - if 'AMP' in config: + if use_amp: + # for AMP O2 if config["AMP"].get("level", "O1").upper() == "O2": use_fp16_test = True msg = "Only support FP16 evaluation when AMP O2 is enabled." logger.warning(msg) - elif "use_fp16_test" in config["AMP"]: - use_fp16_test = config["AMP"].get["use_fp16_test"] + # for AMP O1 else: - use_fp16_test = False + use_fp16_test = config["AMP"].get("use_fp16_test", False) optimizer.amp_init( device,