diff --git a/ppcls/static/program.py b/ppcls/static/program.py index 505a6765ad896e20272381c54f634d9dcb5bf08a..fa8141cd9febbc52ac3a2daa967d1d64b17f7812 100644 --- a/ppcls/static/program.py +++ b/ppcls/static/program.py @@ -154,7 +154,8 @@ def create_strategy(config): """ build_strategy = paddle.static.BuildStrategy() - fuse_op = True if 'AMP' in config else False + fuse_op = True if 'AMP' in config and config['AMP'].get('use_amp', + True) else False fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op) fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op) @@ -194,7 +195,7 @@ def dist_optimizer(config, optimizer): def mixed_precision_optimizer(config, optimizer): - if 'AMP' in config: + if 'AMP' in config and config['AMP'].get('use_amp', True): amp_cfg = config.AMP if config.AMP else dict() scale_loss = amp_cfg.get('scale_loss', 1.0) use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling', @@ -243,7 +244,8 @@ def build(config, use_mix = "batch_transform_ops" in config["DataLoader"][mode][ "dataset"] data_dtype = "float32" - if 'AMP' in config and config["AMP"]["level"] == 'O2': + if 'AMP' in config and config['AMP'].get( + 'use_amp', True) and config["AMP"]["level"] == 'O2': data_dtype = "float16" feeds = create_feeds( config["Global"]["image_shape"],