未验证 提交 913317fe 编写于 作者: C ceci3 提交者: GitHub

fix bn docs (#32492)

* fix bn docs

* fix unittest
上级 7f162b5e
......@@ -210,7 +210,8 @@ class TestLayerPrint(unittest.TestCase):
module = nn.BatchNorm1D(1)
self.assertEqual(
str(module),
'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05)')
'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCL)'
)
module = nn.BatchNorm2D(1)
self.assertEqual(
......@@ -220,7 +221,8 @@ class TestLayerPrint(unittest.TestCase):
module = nn.BatchNorm3D(1)
self.assertEqual(
str(module),
'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05)')
'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCDHW)'
)
module = nn.SyncBatchNorm(2)
self.assertEqual(
......
......@@ -745,6 +745,19 @@ class BatchNorm1D(_BatchNormBase):
print(batch_norm_out)
"""
def __init__(self,
num_features,
momentum=0.9,
epsilon=1e-05,
weight_attr=None,
bias_attr=None,
data_format='NCL',
use_global_stats=None,
name=None):
super(BatchNorm1D,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, use_global_stats, name)
def _check_data_format(self, input):
if input == 'NCHW' or input == 'NC' or input == 'NCL':
self._data_format = 'NCHW'
......@@ -924,6 +937,19 @@ class BatchNorm3D(_BatchNormBase):
print(batch_norm_out)
"""
def __init__(self,
num_features,
momentum=0.9,
epsilon=1e-05,
weight_attr=None,
bias_attr=None,
data_format='NCDHW',
use_global_stats=None,
name=None):
super(BatchNorm3D,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, use_global_stats, name)
def _check_data_format(self, input):
if input == 'NCHW' or input == 'NCDHW':
self._data_format = 'NCHW'
......@@ -1036,7 +1062,7 @@ class SyncBatchNorm(_BatchNormBase):
name=None):
super(SyncBatchNorm,
self).__init__(num_features, momentum, epsilon, weight_attr,
bias_attr, data_format, name)
bias_attr, data_format, None, name)
def forward(self, x):
# create output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册