未验证 提交 913317fe 编写于 作者: C ceci3 提交者: GitHub

fix bn docs (#32492)

* fix bn docs

* fix unittest
上级 7f162b5e
...@@ -210,7 +210,8 @@ class TestLayerPrint(unittest.TestCase): ...@@ -210,7 +210,8 @@ class TestLayerPrint(unittest.TestCase):
module = nn.BatchNorm1D(1) module = nn.BatchNorm1D(1)
self.assertEqual( self.assertEqual(
str(module), str(module),
'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05)') 'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCL)'
)
module = nn.BatchNorm2D(1) module = nn.BatchNorm2D(1)
self.assertEqual( self.assertEqual(
...@@ -220,7 +221,8 @@ class TestLayerPrint(unittest.TestCase): ...@@ -220,7 +221,8 @@ class TestLayerPrint(unittest.TestCase):
module = nn.BatchNorm3D(1) module = nn.BatchNorm3D(1)
self.assertEqual( self.assertEqual(
str(module), str(module),
'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05)') 'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCDHW)'
)
module = nn.SyncBatchNorm(2) module = nn.SyncBatchNorm(2)
self.assertEqual( self.assertEqual(
......
...@@ -745,6 +745,19 @@ class BatchNorm1D(_BatchNormBase): ...@@ -745,6 +745,19 @@ class BatchNorm1D(_BatchNormBase):
print(batch_norm_out) print(batch_norm_out)
""" """
def __init__(self,
num_features,
momentum=0.9,
epsilon=1e-05,
weight_attr=None,
bias_attr=None,
data_format='NCL',
use_global_stats=None,
name=None):
super(BatchNorm1D,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, use_global_stats, name)
def _check_data_format(self, input): def _check_data_format(self, input):
if input == 'NCHW' or input == 'NC' or input == 'NCL': if input == 'NCHW' or input == 'NC' or input == 'NCL':
self._data_format = 'NCHW' self._data_format = 'NCHW'
...@@ -924,6 +937,19 @@ class BatchNorm3D(_BatchNormBase): ...@@ -924,6 +937,19 @@ class BatchNorm3D(_BatchNormBase):
print(batch_norm_out) print(batch_norm_out)
""" """
def __init__(self,
num_features,
momentum=0.9,
epsilon=1e-05,
weight_attr=None,
bias_attr=None,
data_format='NCDHW',
use_global_stats=None,
name=None):
super(BatchNorm3D,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, use_global_stats, name)
def _check_data_format(self, input): def _check_data_format(self, input):
if input == 'NCHW' or input == 'NCDHW': if input == 'NCHW' or input == 'NCDHW':
self._data_format = 'NCHW' self._data_format = 'NCHW'
...@@ -1036,7 +1062,7 @@ class SyncBatchNorm(_BatchNormBase): ...@@ -1036,7 +1062,7 @@ class SyncBatchNorm(_BatchNormBase):
name=None): name=None):
super(SyncBatchNorm, super(SyncBatchNorm,
self).__init__(num_features, momentum, epsilon, weight_attr, self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, name) bias_attr, data_format, None, name)
def forward(self, x): def forward(self, x):
# create output # create output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册