Unverified commit f17d6430, authored by ceci3, committed by GitHub

Fix syncbn (#32989) (#33321)

* fix syncbn
Parent: c42ccf14
@@ -248,7 +248,7 @@ class TestConvertSyncBatchNorm(unittest.TestCase):
                        isinstance(model[idx], paddle.nn.SyncBatchNorm), True)


-class TestConvertSyncBatchNormCase2(unittest.TestCase):
+class TestConvertSyncBatchNormCast1(unittest.TestCase):
    def test_convert(self):
        if not core.is_compiled_with_cuda():
            return
@@ -277,5 +277,70 @@ class TestConvertSyncBatchNormCase2(unittest.TestCase):
        self.assertEqual(len(compare_model.sublayers()), len(model.sublayers()))
+
+
+class TestConvertSyncBatchNormCase2(unittest.TestCase):
+    def test_convert(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        with fluid.dygraph.guard(fluid.CUDAPlace(0)):
+
+            class SyBNNet(paddle.nn.Layer):
+                def __init__(self, in_ch=3, out_ch=3, dirate=1):
+                    super(SyBNNet, self).__init__()
+                    self.bn_s1 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch,
+                            weight_attr=paddle.ParamAttr(
+                                regularizer=paddle.regularizer.L2Decay(0.))))
+                    self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch, data_format='NDHWC'))
+
+                def forward(self, x):
+                    x = self.bn_s1(x)
+                    out = paddle.sum(paddle.abs(self.bn_s2(x)))
+                    return out
+
+            class BNNet(paddle.nn.Layer):
+                def __init__(self, in_ch=3, out_ch=3, dirate=1):
+                    super(BNNet, self).__init__()
+                    self.bn_s1 = paddle.nn.BatchNorm3D(
+                        out_ch,
+                        weight_attr=paddle.ParamAttr(
+                            regularizer=paddle.regularizer.L2Decay(0.)))
+                    self.bn_s2 = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
+                        paddle.nn.BatchNorm3D(
+                            out_ch, data_format='NDHWC'))
+
+                def forward(self, x):
+                    x = self.bn_s1(x)
+                    out = paddle.sum(paddle.abs(self.bn_s2(x)))
+                    return out
+
+            bn_model = BNNet()
+            sybn_model = SyBNNet()
+            np.random.seed(10)
+            data = np.random.random([3, 3, 3, 3, 3]).astype('float32')
+            x = paddle.to_tensor(data)
+            bn_out = bn_model(x)
+            sybn_out = sybn_model(x)
+            self.assertTrue(
+                np.allclose(bn_out.numpy(), sybn_out.numpy()),
+                "Output has diff. \n" + "\nBN " + str(bn_out.numpy()) + "\n" +
+                "Sync BN " + str(sybn_out.numpy()))
+
+
+class TestDygraphSyncBatchNormDataFormatError(unittest.TestCase):
+    def test_errors(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        with fluid.dygraph.guard(fluid.CUDAPlace(0)):
+            my_sync_batch_norm = paddle.nn.SyncBatchNorm(10, data_format='CN')
+            data = np.random.random([3, 3, 3]).astype('float32')
+            x = paddle.to_tensor(data)
+            self.assertRaises(ValueError, my_sync_batch_norm, x)
+
if __name__ == '__main__':
    unittest.main()
@@ -1057,7 +1057,18 @@ class SyncBatchNorm(_BatchNormBase):
              self).__init__(num_features, momentum, epsilon, weight_attr,
                             bias_attr, data_format, None, name)
+
+    def _check_data_format(self):
+        if self._data_format in ['NCHW', 'NCDHW', 'NC', 'NCL']:
+            self._data_format = 'NCHW'
+        elif self._data_format in ["NHWC", "NDHWC", 'NLC']:
+            self._data_format = 'NHWC'
+        else:
+            raise ValueError(
+                'expected \'NCDHW\', \'NDHWC\', \'NCL\', \'NLC\', \'NC\', \'NCHW\', \'NHWC\' for data_format'
+            )
+
    def forward(self, x):
+        self._check_data_format()
        # create output
        # mean and mean_out share the same memory
        mean_out = self._mean
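
The new _check_data_format above normalizes the layout string once at the start of forward: 'NCHW', 'NCDHW', 'NC' and 'NCL' are mapped to the channels-first path, 'NHWC', 'NDHWC' and 'NLC' to the channels-last path, and anything else raises a ValueError. A minimal usage sketch of that behaviour (assuming a CUDA build of Paddle, which the accompanying tests also require; shapes and values here are illustrative only):

import numpy as np
import paddle

# 5-D channels-last input: 'NDHWC' is now accepted and normalized
# internally before the sync batch norm kernel runs.
sync_bn = paddle.nn.SyncBatchNorm(3, data_format='NDHWC')
x = paddle.to_tensor(np.random.random([2, 4, 4, 4, 3]).astype('float32'))
y = sync_bn(x)

# An unrecognized layout such as 'CN' is rejected at forward time,
# which is what TestDygraphSyncBatchNormDataFormatError checks.
bad_bn = paddle.nn.SyncBatchNorm(10, data_format='CN')
try:
    bad_bn(paddle.to_tensor(np.random.random([3, 3, 3]).astype('float32')))
except ValueError as err:
    print('rejected:', err)
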
@@ -1142,11 +1153,12 @@ class SyncBatchNorm(_BatchNormBase):
        """
        layer_output = layer
        if isinstance(layer, _BatchNormBase):
-            if layer._weight_attr != None and not isinstance(layer._weight_attr,
-                                                              bool):
+            if layer._weight_attr != None and not isinstance(
+                    layer._weight_attr,
+                    bool) and layer._weight_attr.name != None:
                layer._weight_attr.name = layer._weight_attr.name + '_sync'
-            if layer._bias_attr != None and not isinstance(layer._weight_attr,
-                                                            bool):
+            if layer._bias_attr != None and not isinstance(
+                    layer._bias_attr, bool) and layer._bias_attr.name != None:
                layer._bias_attr.name = layer._bias_attr.name + '_sync'
            layer_output = SyncBatchNorm(layer._num_features, layer._momentum,
......
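
The convert_sync_batchnorm change above only appends '_sync' to a weight or bias ParamAttr name when an explicit name was actually set, and the bias branch now inspects layer._bias_attr instead of layer._weight_attr. A small sketch of the conversion path that the new TestConvertSyncBatchNormCase2 exercises (assuming a CUDA build of Paddle; the layer sizes and the 'bn_w' name below are illustrative, not taken from the patch):

import paddle

# ParamAttr left with its default name=None: conversion no longer tries to
# append '_sync' to a missing name.
bn = paddle.nn.BatchNorm3D(
    3,
    weight_attr=paddle.ParamAttr(regularizer=paddle.regularizer.L2Decay(0.)))
sybn = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(bn)
print(isinstance(sybn, paddle.nn.SyncBatchNorm))  # True

# ParamAttr with an explicit name: the converted layer's parameter gets a
# '_sync' suffix so it does not clash with the original parameter name.
named_bn = paddle.nn.BatchNorm2D(3, weight_attr=paddle.ParamAttr(name='bn_w'))
named_sybn = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(named_bn)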