From da951788342f7ad0dd3298f6d611a8cda0cf6403 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Tue, 8 Sep 2020 13:55:01 +0800 Subject: [PATCH] [2.0API] fix weight_norm support negative dim and unittest in convert_syncbn (#27108) (#27157) * fix 2.0api, test=develop * fix, test=develop --- .../fluid/tests/unittests/test_dygraph_weight_norm.py | 10 ++++++++++ .../fluid/tests/unittests/test_sync_batch_norm_op.py | 9 +++++---- python/paddle/nn/layer/norm.py | 8 ++++---- python/paddle/nn/utils/weight_norm_hook.py | 8 ++++++++ 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py b/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py index 466226c53fa..a963c2ece09 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py @@ -121,6 +121,9 @@ class TestDygraphWeightNorm(unittest.TestCase): before_weight = linear.weight.numpy() if self.dim == None: self.dim = -1 + + if self.dim != -1: + self.dim = (self.dim + len(before_weight)) % len(before_weight) wn = weight_norm(linear, dim=self.dim) outputs = [] for name, data in self.data.items(): @@ -158,6 +161,13 @@ class TestDygraphWeightNormCase3(TestDygraphWeightNorm): self.dim = 3 +class TestDygraphWeightNormCase4(TestDygraphWeightNorm): + def init_test_case(self): + self.batch_size = 3 + self.data_desc = (['x', [2, 3, 3]], ) + self.dim = -3 + + class TestDygraphRemoveWeightNorm(unittest.TestCase): def setUp(self): self.init_test_case() diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py index 09cd40d9cc5..1c11e831b0a 100644 --- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -227,14 +227,15 @@ class TestConvertSyncBatchNorm(unittest.TestCase): return with program_guard(Program(), Program()): + compare_model = paddle.nn.Sequential( + paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5)) model = paddle.nn.Sequential( paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5)) - sync_model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) - for idx, sublayer in enumerate(model.sublayers()): + model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) + for idx, sublayer in enumerate(compare_model.sublayers()): if isinstance(sublayer, paddle.nn.BatchNorm2d): self.assertEqual( - isinstance(sync_model[idx], paddle.nn.SyncBatchNorm), - True) + isinstance(model[idx], paddle.nn.SyncBatchNorm), True) if __name__ == '__main__': diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index d13bf66ba5b..2000fbf388f 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -1130,10 +1130,10 @@ class SyncBatchNorm(_BatchNormBase): """ layer_output = layer if isinstance(layer, _BatchNormBase): - layer_output = SyncBatchNorm(layer._num_features, layer._epsilon, - layer._momentum, layer._weight_attr, - layer._bias_attr, layer._data_format, - layer._name) + layer_output = SyncBatchNorm( + layer._num_features, layer._momentum, layer._epsilon, + layer._weight_attr, layer._bias_attr, layer._data_format, + layer._track_running_stats, layer._name) if layer._weight_attr != False and layer._bias_attr != False: with no_grad(): diff --git a/python/paddle/nn/utils/weight_norm_hook.py b/python/paddle/nn/utils/weight_norm_hook.py index ad53bf39466..7a21e7661d4 100644 --- a/python/paddle/nn/utils/weight_norm_hook.py +++ b/python/paddle/nn/utils/weight_norm_hook.py @@ -112,6 +112,14 @@ class WeightNorm(object): if dim is None: dim = -1 + # support dim is negative numeber, (dim = -1) == (dim = None) + weight_dim = len(layer._parameters[name].shape) + assert ( + dim < weight_dim and dim >= -1 * weight_dim + ), "dim must set between [-R, R), R means the dimension of weight." + if dim != -1: + dim = (dim + weight_dim) % weight_dim + fn = WeightNorm(name, dim) w = getattr(layer, name) -- GitLab