未验证 提交 da951788 编写于 作者: C ceci3 提交者: GitHub

[2.0API] fix weight_norm support negative dim and unittest in convert_syncbn (#27108) (#27157)

* fix 2.0api, test=develop

* fix, test=develop
上级 2118868f
...@@ -121,6 +121,9 @@ class TestDygraphWeightNorm(unittest.TestCase): ...@@ -121,6 +121,9 @@ class TestDygraphWeightNorm(unittest.TestCase):
before_weight = linear.weight.numpy() before_weight = linear.weight.numpy()
if self.dim == None: if self.dim == None:
self.dim = -1 self.dim = -1
if self.dim != -1:
self.dim = (self.dim + len(before_weight)) % len(before_weight)
wn = weight_norm(linear, dim=self.dim) wn = weight_norm(linear, dim=self.dim)
outputs = [] outputs = []
for name, data in self.data.items(): for name, data in self.data.items():
...@@ -158,6 +161,13 @@ class TestDygraphWeightNormCase3(TestDygraphWeightNorm): ...@@ -158,6 +161,13 @@ class TestDygraphWeightNormCase3(TestDygraphWeightNorm):
self.dim = 3 self.dim = 3
class TestDygraphWeightNormCase4(TestDygraphWeightNorm):
    """Weight-norm test case exercising a negative ``dim`` (-3)."""

    def init_test_case(self):
        # Same configuration as the base case, but with dim given as a
        # negative axis index to verify negative-dim normalization.
        self.dim = -3
        self.batch_size = 3
        self.data_desc = (['x', [2, 3, 3]], )
class TestDygraphRemoveWeightNorm(unittest.TestCase): class TestDygraphRemoveWeightNorm(unittest.TestCase):
def setUp(self): def setUp(self):
self.init_test_case() self.init_test_case()
......
...@@ -227,14 +227,15 @@ class TestConvertSyncBatchNorm(unittest.TestCase): ...@@ -227,14 +227,15 @@ class TestConvertSyncBatchNorm(unittest.TestCase):
return return
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
compare_model = paddle.nn.Sequential(
paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5))
model = paddle.nn.Sequential( model = paddle.nn.Sequential(
paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5)) paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5))
sync_model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
for idx, sublayer in enumerate(model.sublayers()): for idx, sublayer in enumerate(compare_model.sublayers()):
if isinstance(sublayer, paddle.nn.BatchNorm2d): if isinstance(sublayer, paddle.nn.BatchNorm2d):
self.assertEqual( self.assertEqual(
isinstance(sync_model[idx], paddle.nn.SyncBatchNorm), isinstance(model[idx], paddle.nn.SyncBatchNorm), True)
True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -1130,10 +1130,10 @@ class SyncBatchNorm(_BatchNormBase): ...@@ -1130,10 +1130,10 @@ class SyncBatchNorm(_BatchNormBase):
""" """
layer_output = layer layer_output = layer
if isinstance(layer, _BatchNormBase): if isinstance(layer, _BatchNormBase):
layer_output = SyncBatchNorm(layer._num_features, layer._epsilon, layer_output = SyncBatchNorm(
layer._momentum, layer._weight_attr, layer._num_features, layer._momentum, layer._epsilon,
layer._bias_attr, layer._data_format, layer._weight_attr, layer._bias_attr, layer._data_format,
layer._name) layer._track_running_stats, layer._name)
if layer._weight_attr != False and layer._bias_attr != False: if layer._weight_attr != False and layer._bias_attr != False:
with no_grad(): with no_grad():
......
...@@ -112,6 +112,14 @@ class WeightNorm(object): ...@@ -112,6 +112,14 @@ class WeightNorm(object):
if dim is None: if dim is None:
dim = -1 dim = -1
# support negative dim values; (dim = -1) is equivalent to (dim = None)
weight_dim = len(layer._parameters[name].shape)
assert (
dim < weight_dim and dim >= -1 * weight_dim
), "dim must set between [-R, R), R means the dimension of weight."
if dim != -1:
dim = (dim + weight_dim) % weight_dim
fn = WeightNorm(name, dim) fn = WeightNorm(name, dim)
w = getattr(layer, name) w = getattr(layer, name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册