diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 8b9fea3fc934f8b2469c25c517692826ae10893c..8d6eae60be46ea1e06763f9a337e61f8b7edfe1f 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -2610,9 +2610,9 @@ class GroupNorm(layers.Layer): def forward(self, input): inputs = {'X': input} - if self.bias: + if self.bias is not None: inputs['Bias'] = self.bias - if self.weight: + if self.weight is not None: inputs['Scale'] = self.weight # create output diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 666e7c86bb297c44fb8caeb646a0abcc6da2f3d4..7c0e020948e3e87ac339d86c2c3c767b5280a80f 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1114,7 +1114,12 @@ class TestLayer(LayerTest): dtype='float32', lod_level=1, append_batch_size=False) - ret = layers.group_norm(input=X, groups=2) + ret = layers.group_norm( + input=X, + groups=2, + param_attr=fluid.initializer.Uniform( + low=-0.5, high=0.5), + bias_attr=fluid.initializer.ConstantInitializer(value=1)) static_ret = self.get_static_graph_result( feed={ 'X': fluid.create_lod_tensor( @@ -1130,7 +1135,12 @@ class TestLayer(LayerTest): dtype='float32', lod_level=1, append_batch_size=False) - groupNorm = nn.GroupNorm(channels=shape[1], groups=2) + groupNorm = nn.GroupNorm( + channels=shape[1], + groups=2, + param_attr=fluid.initializer.Uniform( + low=-0.5, high=0.5), + bias_attr=fluid.initializer.ConstantInitializer(value=1)) ret = groupNorm(X) static_ret2 = self.get_static_graph_result( feed={ @@ -1141,7 +1151,12 @@ class TestLayer(LayerTest): with_lod=True)[0] with self.dynamic_graph(): - groupNorm = nn.GroupNorm(channels=shape[1], groups=2) + groupNorm = nn.GroupNorm( + channels=shape[1], + groups=2, + param_attr=fluid.initializer.Uniform( + low=-0.5, high=0.5), + bias_attr=fluid.initializer.ConstantInitializer(value=1)) dy_ret = groupNorm(base.to_variable(input)) dy_rlt_value = dy_ret.numpy()