From 913317fe0ee37f87c09a120a8eb2efa986497ffb Mon Sep 17 00:00:00 2001 From: ceci3 Date: Mon, 26 Apr 2021 17:04:48 +0800 Subject: [PATCH] fix bn docs (#32492) * fix bn docs * fix unittest --- .../tests/unittests/test_imperative_layers.py | 6 ++-- python/paddle/nn/layer/norm.py | 28 ++++++++++++++++++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layers.py b/python/paddle/fluid/tests/unittests/test_imperative_layers.py index 214339c50d6..dc15566f854 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_layers.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_layers.py @@ -210,7 +210,8 @@ class TestLayerPrint(unittest.TestCase): module = nn.BatchNorm1D(1) self.assertEqual( str(module), - 'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05)') + 'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCL)' + ) module = nn.BatchNorm2D(1) self.assertEqual( @@ -220,7 +221,8 @@ class TestLayerPrint(unittest.TestCase): module = nn.BatchNorm3D(1) self.assertEqual( str(module), - 'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05)') + 'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCDHW)' + ) module = nn.SyncBatchNorm(2) self.assertEqual( diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index a1cc41f3912..0b0b2bf7b9b 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -745,6 +745,19 @@ class BatchNorm1D(_BatchNormBase): print(batch_norm_out) """ + def __init__(self, + num_features, + momentum=0.9, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_format='NCL', + use_global_stats=None, + name=None): + super(BatchNorm1D, + self).__init__(num_features, momentum, epsilon, weight_attr, + bias_attr, data_format, use_global_stats, name) + def _check_data_format(self, input): if input == 'NCHW' or input == 'NC' or input == 'NCL': self._data_format = 'NCHW' @@ -924,6 +937,19 @@ class BatchNorm3D(_BatchNormBase): print(batch_norm_out) """ + def __init__(self, + num_features, + momentum=0.9, + epsilon=1e-05, + weight_attr=None, + bias_attr=None, + data_format='NCDHW', + use_global_stats=None, + name=None): + super(BatchNorm3D, + self).__init__(num_features, momentum, epsilon, weight_attr, + bias_attr, data_format, use_global_stats, name) + def _check_data_format(self, input): if input == 'NCHW' or input == 'NCDHW': self._data_format = 'NCHW' @@ -1036,7 +1062,7 @@ class SyncBatchNorm(_BatchNormBase): name=None): super(SyncBatchNorm, self).__init__(num_features, momentum, epsilon, weight_attr, - bias_attr, data_format, name) + bias_attr, data_format, None, name) def forward(self, x): # create output -- GitLab