From 607a07cb28c37541d5f84e30f7f77bb2755dae16 Mon Sep 17 00:00:00 2001 From: Tingquan Gao Date: Fri, 28 Jul 2023 17:48:52 +0800 Subject: [PATCH] compatible with the AMP.use_amp field in config (#2889) --- ppcls/static/program.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ppcls/static/program.py b/ppcls/static/program.py index 505a6765..fa8141cd 100644 --- a/ppcls/static/program.py +++ b/ppcls/static/program.py @@ -154,7 +154,8 @@ def create_strategy(config): """ build_strategy = paddle.static.BuildStrategy() - fuse_op = True if 'AMP' in config else False + fuse_op = True if 'AMP' in config and config['AMP'].get('use_amp', + True) else False fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op) fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op) @@ -194,7 +195,7 @@ def dist_optimizer(config, optimizer): def mixed_precision_optimizer(config, optimizer): - if 'AMP' in config: + if 'AMP' in config and config['AMP'].get('use_amp', True): amp_cfg = config.AMP if config.AMP else dict() scale_loss = amp_cfg.get('scale_loss', 1.0) use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling', @@ -243,7 +244,8 @@ def build(config, use_mix = "batch_transform_ops" in config["DataLoader"][mode][ "dataset"] data_dtype = "float32" - if 'AMP' in config and config["AMP"]["level"] == 'O2': + if 'AMP' in config and config['AMP'].get( + 'use_amp', True) and config["AMP"]["level"] == 'O2': data_dtype = "float16" feeds = create_feeds( config["Global"]["image_shape"], -- GitLab