diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py b/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py index 466226c53fabbd315acd19c6421f210d0ca225c1..a963c2ece0958048b5f0c850184a0930022e6671 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py @@ -121,6 +121,9 @@ class TestDygraphWeightNorm(unittest.TestCase): before_weight = linear.weight.numpy() if self.dim == None: self.dim = -1 + + if self.dim != -1: + self.dim = (self.dim + len(before_weight)) % len(before_weight) wn = weight_norm(linear, dim=self.dim) outputs = [] for name, data in self.data.items(): @@ -158,6 +161,13 @@ class TestDygraphWeightNormCase3(TestDygraphWeightNorm): self.dim = 3 +class TestDygraphWeightNormCase4(TestDygraphWeightNorm): + def init_test_case(self): + self.batch_size = 3 + self.data_desc = (['x', [2, 3, 3]], ) + self.dim = -3 + + class TestDygraphRemoveWeightNorm(unittest.TestCase): def setUp(self): self.init_test_case() diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py index 09cd40d9cc59914c82cc343bb78b72fbc2b29e59..1c11e831b0ad31a3c450c70e7f7c258455409d05 100644 --- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -227,14 +227,15 @@ class TestConvertSyncBatchNorm(unittest.TestCase): return with program_guard(Program(), Program()): + compare_model = paddle.nn.Sequential( + paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5)) model = paddle.nn.Sequential( paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5)) - sync_model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) - for idx, sublayer in enumerate(model.sublayers()): + model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) + for idx, sublayer in enumerate(compare_model.sublayers()): if isinstance(sublayer, paddle.nn.BatchNorm2d): self.assertEqual( - isinstance(sync_model[idx], paddle.nn.SyncBatchNorm), - True) + isinstance(model[idx], paddle.nn.SyncBatchNorm), True) if __name__ == '__main__': diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index d13bf66ba5bfe483284e78dbcd2a42f8f3397210..2000fbf388f88d1da7119402104706a433cebf06 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -1130,10 +1130,10 @@ class SyncBatchNorm(_BatchNormBase): """ layer_output = layer if isinstance(layer, _BatchNormBase): - layer_output = SyncBatchNorm(layer._num_features, layer._epsilon, - layer._momentum, layer._weight_attr, - layer._bias_attr, layer._data_format, - layer._name) + layer_output = SyncBatchNorm( + layer._num_features, layer._momentum, layer._epsilon, + layer._weight_attr, layer._bias_attr, layer._data_format, + layer._track_running_stats, layer._name) if layer._weight_attr != False and layer._bias_attr != False: with no_grad(): diff --git a/python/paddle/nn/utils/weight_norm_hook.py b/python/paddle/nn/utils/weight_norm_hook.py index ad53bf394660f3a7e0e48fdbd5eb530abd0852bb..7a21e7661d4e78d0004996ee67c80ddc35006bc3 100644 --- a/python/paddle/nn/utils/weight_norm_hook.py +++ b/python/paddle/nn/utils/weight_norm_hook.py @@ -112,6 +112,14 @@ class WeightNorm(object): if dim is None: dim = -1 + # support dim is negative numeber, (dim = -1) == (dim = None) + weight_dim = len(layer._parameters[name].shape) + assert ( + dim < weight_dim and dim >= -1 * weight_dim + ), "dim must set between [-R, R), R means the dimension of weight." + if dim != -1: + dim = (dim + weight_dim) % weight_dim + fn = WeightNorm(name, dim) w = getattr(layer, name)