未验证 提交 607a07cb 编写于 作者: T Tingquan Gao 提交者: GitHub

compatible with the AMP.use_amp field in config (#2889)

上级 267cd647
...@@ -154,7 +154,8 @@ def create_strategy(config): ...@@ -154,7 +154,8 @@ def create_strategy(config):
""" """
build_strategy = paddle.static.BuildStrategy() build_strategy = paddle.static.BuildStrategy()
fuse_op = True if 'AMP' in config else False fuse_op = True if 'AMP' in config and config['AMP'].get('use_amp',
True) else False
fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op) fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op) fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
...@@ -194,7 +195,7 @@ def dist_optimizer(config, optimizer): ...@@ -194,7 +195,7 @@ def dist_optimizer(config, optimizer):
def mixed_precision_optimizer(config, optimizer): def mixed_precision_optimizer(config, optimizer):
if 'AMP' in config: if 'AMP' in config and config['AMP'].get('use_amp', True):
amp_cfg = config.AMP if config.AMP else dict() amp_cfg = config.AMP if config.AMP else dict()
scale_loss = amp_cfg.get('scale_loss', 1.0) scale_loss = amp_cfg.get('scale_loss', 1.0)
use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling', use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling',
...@@ -243,7 +244,8 @@ def build(config, ...@@ -243,7 +244,8 @@ def build(config,
use_mix = "batch_transform_ops" in config["DataLoader"][mode][ use_mix = "batch_transform_ops" in config["DataLoader"][mode][
"dataset"] "dataset"]
data_dtype = "float32" data_dtype = "float32"
if 'AMP' in config and config["AMP"]["level"] == 'O2': if 'AMP' in config and config['AMP'].get(
'use_amp', True) and config["AMP"]["level"] == 'O2':
data_dtype = "float16" data_dtype = "float16"
feeds = create_feeds( feeds = create_feeds(
config["Global"]["image_shape"], config["Global"]["image_shape"],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册