未验证 提交 f24158de 编写于 作者: Z Zhang Ting 提交者: GitHub

use fuse_bn_add_act pass (#4915)

上级 b480de5d
...@@ -145,7 +145,7 @@ def parse_args(): ...@@ -145,7 +145,7 @@ def parse_args():
add_arg('data_format', str, "NCHW", "Tensor data format when training.") add_arg('data_format', str, "NCHW", "Tensor data format when training.")
add_arg('fuse_elewise_add_act_ops', bool, False, "Whether to use elementwise_act fusion.") add_arg('fuse_elewise_add_act_ops', bool, False, "Whether to use elementwise_act fusion.")
add_arg('fuse_bn_act_ops', bool, False, "Whether to use batch_norm and act fusion.") add_arg('fuse_bn_act_ops', bool, False, "Whether to use batch_norm and act fusion.")
add_arg('fuse_bn_add_act_ops', bool, False, "Whether to use batch_norm, elementwise_add and act fusion. This is only used for AMP training.") add_arg('fuse_bn_add_act_ops', bool, True, "Whether to use batch_norm, elementwise_add and act fusion. This is only used for AMP training.")
add_arg('enable_addto', bool, False, "Whether to enable the addto strategy for gradient accumulation or not. This is only used for AMP training.") add_arg('enable_addto', bool, False, "Whether to enable the addto strategy for gradient accumulation or not. This is only used for AMP training.")
add_arg('use_label_smoothing', bool, False, "Whether to use label_smoothing") add_arg('use_label_smoothing', bool, False, "Whether to use label_smoothing")
...@@ -537,15 +537,19 @@ def best_strategy_compiled(args, ...@@ -537,15 +537,19 @@ def best_strategy_compiled(args,
"PaddlePaddle version 1.7.0 or higher is " "PaddlePaddle version 1.7.0 or higher is "
"required when you want to fuse batch_norm and activation_op.") "required when you want to fuse batch_norm and activation_op.")
build_strategy.fuse_elewise_add_act_ops = args.fuse_elewise_add_act_ops build_strategy.fuse_elewise_add_act_ops = args.fuse_elewise_add_act_ops
try:
build_strategy.fuse_bn_add_act_ops = args.fuse_bn_add_act_ops
except Exception as e:
logger.info(
"PaddlePaddle 2.0-rc or higher is "
"required when you want to enable fuse_bn_add_act_ops strategy.")
try: try:
build_strategy.enable_addto = args.enable_addto build_strategy.enable_addto = args.enable_addto
except Exception as e: except Exception as e:
logger.info( logger.info(
"PaddlePaddle 2.0-rc or higher is " "PaddlePaddle 2.0-rc or higher is "
"required when you want to enable addto strategy.") "required when you want to enable addto strategy.")
build_strategy.enable_addto = args.enable_addto
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册