未验证 提交 a3faa520 编写于 作者: C ceci3 提交者: GitHub

Fix syncbn (#29013)

* fix syncbn

* add unittest
上级 582c0a04
...@@ -228,9 +228,15 @@ class TestConvertSyncBatchNorm(unittest.TestCase): ...@@ -228,9 +228,15 @@ class TestConvertSyncBatchNorm(unittest.TestCase):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
compare_model = paddle.nn.Sequential( compare_model = paddle.nn.Sequential(
paddle.nn.Conv2D(3, 5, 3), paddle.nn.BatchNorm2D(5)) paddle.nn.Conv2D(3, 5, 3),
paddle.nn.BatchNorm2D(5), paddle.nn.BatchNorm2D(5))
model = paddle.nn.Sequential( model = paddle.nn.Sequential(
paddle.nn.Conv2D(3, 5, 3), paddle.nn.BatchNorm2D(5)) paddle.nn.Conv2D(3, 5, 3),
paddle.nn.BatchNorm2D(5),
paddle.nn.BatchNorm2D(
5,
weight_attr=fluid.ParamAttr(name='bn.scale'),
bias_attr=fluid.ParamAttr(name='bn.bias')))
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
for idx, sublayer in enumerate(compare_model.sublayers()): for idx, sublayer in enumerate(compare_model.sublayers()):
if isinstance(sublayer, paddle.nn.BatchNorm2D): if isinstance(sublayer, paddle.nn.BatchNorm2D):
......
...@@ -1103,6 +1103,13 @@ class SyncBatchNorm(_BatchNormBase): ...@@ -1103,6 +1103,13 @@ class SyncBatchNorm(_BatchNormBase):
""" """
layer_output = layer layer_output = layer
if isinstance(layer, _BatchNormBase): if isinstance(layer, _BatchNormBase):
if layer._weight_attr != None and not isinstance(layer._weight_attr,
bool):
layer._weight_attr.name = layer._weight_attr.name + '_sync'
if layer._bias_attr != None and not isinstance(layer._weight_attr,
bool):
layer._bias_attr.name = layer._bias_attr.name + '_sync'
layer_output = SyncBatchNorm(layer._num_features, layer._momentum, layer_output = SyncBatchNorm(layer._num_features, layer._momentum,
layer._epsilon, layer._weight_attr, layer._epsilon, layer._weight_attr,
layer._bias_attr, layer._data_format, layer._bias_attr, layer._data_format,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册