From f24158de7f752265822cead74c9a1898bb19a6e5 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Tue, 27 Oct 2020 15:25:55 +0800 Subject: [PATCH] use fuse_bn_add_act pass (#4915) --- PaddleCV/image_classification/utils/utility.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/PaddleCV/image_classification/utils/utility.py b/PaddleCV/image_classification/utils/utility.py index bef22147..537004c0 100644 --- a/PaddleCV/image_classification/utils/utility.py +++ b/PaddleCV/image_classification/utils/utility.py @@ -145,7 +145,7 @@ def parse_args(): add_arg('data_format', str, "NCHW", "Tensor data format when training.") add_arg('fuse_elewise_add_act_ops', bool, False, "Whether to use elementwise_act fusion.") add_arg('fuse_bn_act_ops', bool, False, "Whether to use batch_norm and act fusion.") - add_arg('fuse_bn_add_act_ops', bool, False, "Whether to use batch_norm, elementwise_add and act fusion. This is only used for AMP training.") + add_arg('fuse_bn_add_act_ops', bool, True, "Whether to use batch_norm, elementwise_add and act fusion. This is only used for AMP training.") add_arg('enable_addto', bool, False, "Whether to enable the addto strategy for gradient accumulation or not. This is only used for AMP training.") add_arg('use_label_smoothing', bool, False, "Whether to use label_smoothing") @@ -537,15 +537,19 @@ def best_strategy_compiled(args, "PaddlePaddle version 1.7.0 or higher is " "required when you want to fuse batch_norm and activation_op.") build_strategy.fuse_elewise_add_act_ops = args.fuse_elewise_add_act_ops - + + try: + build_strategy.fuse_bn_add_act_ops = args.fuse_bn_add_act_ops + except Exception as e: + logger.info( + "PaddlePaddle 2.0-rc or higher is " + "required when you want to enable fuse_bn_add_act_ops strategy.") try: build_strategy.enable_addto = args.enable_addto except Exception as e: logger.info( "PaddlePaddle 2.0-rc or higher is " "required when you want to enable addto strategy.") - build_strategy.enable_addto = args.enable_addto - exec_strategy = fluid.ExecutionStrategy() -- GitLab