提交 bdfa1feb 编写于 作者: G gaotingquan 提交者: Tingquan Gao

update for amp config refactoring

上级 09817fe8
...@@ -95,7 +95,9 @@ def main(args): ...@@ -95,7 +95,9 @@ def main(args):
device = paddle.set_device(global_config["device"]) device = paddle.set_device(global_config["device"])
# amp related config # amp related config
if 'AMP' in config: amp_config = config.get("AMP", None)
use_amp = True if amp_config and amp_config.get("use_amp", True) else False
if use_amp:
AMP_RELATED_FLAGS_SETTING = { AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_exhaustive_search': 1, 'FLAGS_cudnn_exhaustive_search': 1,
'FLAGS_conv_workspace_size_limit': 1500, 'FLAGS_conv_workspace_size_limit': 1500,
...@@ -159,15 +161,15 @@ def main(args): ...@@ -159,15 +161,15 @@ def main(args):
# load pretrained models or checkpoints # load pretrained models or checkpoints
init_model(global_config, train_prog, exe) init_model(global_config, train_prog, exe)
if 'AMP' in config: if use_amp:
# for AMP O2
if config["AMP"].get("level", "O1").upper() == "O2": if config["AMP"].get("level", "O1").upper() == "O2":
use_fp16_test = True use_fp16_test = True
msg = "Only support FP16 evaluation when AMP O2 is enabled." msg = "Only support FP16 evaluation when AMP O2 is enabled."
logger.warning(msg) logger.warning(msg)
elif "use_fp16_test" in config["AMP"]: # for AMP O1
use_fp16_test = config["AMP"].get["use_fp16_test"]
else: else:
use_fp16_test = False use_fp16_test = config["AMP"].get("use_fp16_test", False)
optimizer.amp_init( optimizer.amp_init(
device, device,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册