未验证 提交 da951788 编写于 作者: C ceci3 提交者: GitHub

[2.0API] fix weight_norm support negative dim and unittest in convert_syncbn (#27108) (#27157)

* fix 2.0api, test=develop

* fix, test=develop
上级 2118868f
......@@ -121,6 +121,9 @@ class TestDygraphWeightNorm(unittest.TestCase):
before_weight = linear.weight.numpy()
if self.dim == None:
self.dim = -1
if self.dim != -1:
self.dim = (self.dim + len(before_weight)) % len(before_weight)
wn = weight_norm(linear, dim=self.dim)
outputs = []
for name, data in self.data.items():
......@@ -158,6 +161,13 @@ class TestDygraphWeightNormCase3(TestDygraphWeightNorm):
self.dim = 3
class TestDygraphWeightNormCase4(TestDygraphWeightNorm):
def init_test_case(self):
self.batch_size = 3
self.data_desc = (['x', [2, 3, 3]], )
self.dim = -3
class TestDygraphRemoveWeightNorm(unittest.TestCase):
def setUp(self):
self.init_test_case()
......
......@@ -227,14 +227,15 @@ class TestConvertSyncBatchNorm(unittest.TestCase):
return
with program_guard(Program(), Program()):
compare_model = paddle.nn.Sequential(
paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5))
model = paddle.nn.Sequential(
paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5))
sync_model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
for idx, sublayer in enumerate(model.sublayers()):
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
for idx, sublayer in enumerate(compare_model.sublayers()):
if isinstance(sublayer, paddle.nn.BatchNorm2d):
self.assertEqual(
isinstance(sync_model[idx], paddle.nn.SyncBatchNorm),
True)
isinstance(model[idx], paddle.nn.SyncBatchNorm), True)
if __name__ == '__main__':
......
......@@ -1130,10 +1130,10 @@ class SyncBatchNorm(_BatchNormBase):
"""
layer_output = layer
if isinstance(layer, _BatchNormBase):
layer_output = SyncBatchNorm(layer._num_features, layer._epsilon,
layer._momentum, layer._weight_attr,
layer._bias_attr, layer._data_format,
layer._name)
layer_output = SyncBatchNorm(
layer._num_features, layer._momentum, layer._epsilon,
layer._weight_attr, layer._bias_attr, layer._data_format,
layer._track_running_stats, layer._name)
if layer._weight_attr != False and layer._bias_attr != False:
with no_grad():
......
......@@ -112,6 +112,14 @@ class WeightNorm(object):
if dim is None:
dim = -1
# support dim is negative numeber, (dim = -1) == (dim = None)
weight_dim = len(layer._parameters[name].shape)
assert (
dim < weight_dim and dim >= -1 * weight_dim
), "dim must set between [-R, R), R means the dimension of weight."
if dim != -1:
dim = (dim + weight_dim) % weight_dim
fn = WeightNorm(name, dim)
w = getattr(layer, name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册