From 09104d0221635b2cc6b2284a07bfe848e366bb00 Mon Sep 17 00:00:00 2001 From: hlygit66666 <32728786+hlygit66666@users.noreply.github.com> Date: Tue, 25 Jan 2022 19:39:18 +0800 Subject: [PATCH] Add FuseBatchNormActPass and unittest. (#39176) * add fuse_relu_depthwise_conv_pass unittest * fix atol and rtol * fix according to review * Add fuse_bn_act_pass unittest * rm others * add fuse_bn_act_pass --- python/paddle/distributed/passes/cpp_pass.py | 13 +++ .../test_dist_fuse_bn_act_pass.py | 96 +++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_bn_act_pass.py diff --git a/python/paddle/distributed/passes/cpp_pass.py b/python/paddle/distributed/passes/cpp_pass.py index fe6ef74bd85..6b8ea30f3ba 100644 --- a/python/paddle/distributed/passes/cpp_pass.py +++ b/python/paddle/distributed/passes/cpp_pass.py @@ -26,3 +26,16 @@ class FuseElementwiseAddActPass(CPPPassWrapper): def _type(self): return PassType.FUSION_OPT + + +@register_pass("fuse_bn_act") +class FuseBatchNormActPass(CPPPassWrapper): + def __init__(self): + super(FuseBatchNormActPass, self).__init__() + + @property + def cpp_name(self): + return "fuse_bn_act_pass" + + def _type(self): + return PassType.FUSION_OPT diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_bn_act_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_bn_act_pass.py new file mode 100644 index 00000000000..a7147724fbc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_dist_fuse_bn_act_pass.py @@ -0,0 +1,96 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import paddle +import paddle.distributed.fleet as fleet +import numpy as np +import paddle.nn as nn +from paddle.distributed.passes import new_pass, PassManager +import unittest +from dist_pass_test_base import DistPassTestBase + + +class BatchNormActNet(nn.Layer): + def __init__(self): + super(BatchNormActNet, self).__init__() + + self.conv1 = nn.Conv2D(3, 8, (3, 3), data_format="NHWC") + self.bn1 = nn.BatchNorm2D(8, data_format="NHWC") + self.relu = nn.ReLU() + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = paddle.flatten(out, 1) + return out + + +class TestFuseBatchNormActPass(DistPassTestBase): + def init(self): + self.atol = 1e-4 + self.rtol = 1e-4 + + def get_model(self, place, batch_size=32, image_shape=[224, 224, 3]): + image = paddle.static.data( + shape=[batch_size] + image_shape, dtype='float32', name='image') + + model = BatchNormActNet() + pred_out = model(image) + loss = paddle.mean(pred_out) + optimizer = paddle.optimizer.Adam(learning_rate=1e-3) + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.fuse_all_reduce_ops = False + dist_strategy.without_graph_optimization = True + dist_strategy.amp = True + dist_strategy.amp_configs = { + "init_loss_scaling": 32768, + "use_dynamic_loss_scaling": True, + } + fleet.init(is_collective=True, strategy=dist_strategy) + optimizer = fleet.distributed_optimizer(optimizer) + optimizer.minimize(loss) + + rank = paddle.distributed.get_rank() + + def reader(): + seed = int(os.environ.get("SEED", 0)) + np.random.seed(seed + rank) + for _ in range(10): + image_np = np.random.random(size=image.shape).astype('float32') + yield image_np, + + main_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + return main_program, startup_program, [image], [loss], reader + + def apply_passes(self, main_prog, startup_prog): + pass_manager = PassManager([new_pass("fuse_bn_act")]) + pass_manager.apply([main_prog], [startup_prog]) + print(pass_manager.names) + + op_type = [] + for op in main_prog.global_block().ops: + op_type.append(op.type) + self.assertTrue("fused_batch_norm_act" in op_type) + self.assertTrue("fused_batch_norm_act_grad" in op_type) + + def test_fuse_bn_act(self): + self.check_main() + + +if __name__ == "__main__": + unittest.main() -- GitLab