提交 bdfa1feb 编写于 作者: G gaotingquan 提交者: Tingquan Gao

update for amp config refactoring

上级 09817fe8
......@@ -95,7 +95,9 @@ def main(args):
device = paddle.set_device(global_config["device"])
# amp related config
if 'AMP' in config:
amp_config = config.get("AMP", None)
use_amp = True if amp_config and amp_config.get("use_amp", True) else False
if use_amp:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_exhaustive_search': 1,
'FLAGS_conv_workspace_size_limit': 1500,
......@@ -159,15 +161,15 @@ def main(args):
# load pretrained models or checkpoints
init_model(global_config, train_prog, exe)
if 'AMP' in config:
if use_amp:
# for AMP O2
if config["AMP"].get("level", "O1").upper() == "O2":
use_fp16_test = True
msg = "Only support FP16 evaluation when AMP O2 is enabled."
logger.warning(msg)
elif "use_fp16_test" in config["AMP"]:
use_fp16_test = config["AMP"].get["use_fp16_test"]
# for AMP O1
else:
use_fp16_test = False
use_fp16_test = config["AMP"].get("use_fp16_test", False)
optimizer.amp_init(
device,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册