From a3faa520ecdaad17c353730675358d7761b359f3 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Wed, 25 Nov 2020 19:54:50 +0800 Subject: [PATCH] Fix syncbn (#29013) * fix syncbn * add unittest --- .../fluid/tests/unittests/test_sync_batch_norm_op.py | 10 ++++++++-- python/paddle/nn/layer/norm.py | 7 +++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py index 9a380c886e9..4fa64bef32f 100644 --- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -228,9 +228,15 @@ class TestConvertSyncBatchNorm(unittest.TestCase): with program_guard(Program(), Program()): compare_model = paddle.nn.Sequential( - paddle.nn.Conv2D(3, 5, 3), paddle.nn.BatchNorm2D(5)) + paddle.nn.Conv2D(3, 5, 3), + paddle.nn.BatchNorm2D(5), paddle.nn.BatchNorm2D(5)) model = paddle.nn.Sequential( - paddle.nn.Conv2D(3, 5, 3), paddle.nn.BatchNorm2D(5)) + paddle.nn.Conv2D(3, 5, 3), + paddle.nn.BatchNorm2D(5), + paddle.nn.BatchNorm2D( + 5, + weight_attr=fluid.ParamAttr(name='bn.scale'), + bias_attr=fluid.ParamAttr(name='bn.bias'))) model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) for idx, sublayer in enumerate(compare_model.sublayers()): if isinstance(sublayer, paddle.nn.BatchNorm2D): diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 7f416749c8a..7bff2d64a65 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -1103,6 +1103,13 @@ class SyncBatchNorm(_BatchNormBase): """ layer_output = layer if isinstance(layer, _BatchNormBase): + if layer._weight_attr != None and not isinstance(layer._weight_attr, + bool): + layer._weight_attr.name = layer._weight_attr.name + '_sync' + if layer._bias_attr != None and not isinstance(layer._weight_attr, + bool): + layer._bias_attr.name = layer._bias_attr.name + '_sync' + layer_output = SyncBatchNorm(layer._num_features, layer._momentum, layer._epsilon, layer._weight_attr, layer._bias_attr, layer._data_format, -- GitLab