未验证 提交 8d95a109 编写于 作者: Z zhongpu 提交者: GitHub

fix if logic in dygraph, test=develop (#23512)

上级 316ea549
...@@ -2610,9 +2610,9 @@ class GroupNorm(layers.Layer): ...@@ -2610,9 +2610,9 @@ class GroupNorm(layers.Layer):
def forward(self, input): def forward(self, input):
inputs = {'X': input} inputs = {'X': input}
if self.bias: if self.bias is not None:
inputs['Bias'] = self.bias inputs['Bias'] = self.bias
if self.weight: if self.weight is not None:
inputs['Scale'] = self.weight inputs['Scale'] = self.weight
# create output # create output
......
...@@ -1114,7 +1114,12 @@ class TestLayer(LayerTest): ...@@ -1114,7 +1114,12 @@ class TestLayer(LayerTest):
dtype='float32', dtype='float32',
lod_level=1, lod_level=1,
append_batch_size=False) append_batch_size=False)
ret = layers.group_norm(input=X, groups=2) ret = layers.group_norm(
input=X,
groups=2,
param_attr=fluid.initializer.Uniform(
low=-0.5, high=0.5),
bias_attr=fluid.initializer.ConstantInitializer(value=1))
static_ret = self.get_static_graph_result( static_ret = self.get_static_graph_result(
feed={ feed={
'X': fluid.create_lod_tensor( 'X': fluid.create_lod_tensor(
...@@ -1130,7 +1135,12 @@ class TestLayer(LayerTest): ...@@ -1130,7 +1135,12 @@ class TestLayer(LayerTest):
dtype='float32', dtype='float32',
lod_level=1, lod_level=1,
append_batch_size=False) append_batch_size=False)
groupNorm = nn.GroupNorm(channels=shape[1], groups=2) groupNorm = nn.GroupNorm(
channels=shape[1],
groups=2,
param_attr=fluid.initializer.Uniform(
low=-0.5, high=0.5),
bias_attr=fluid.initializer.ConstantInitializer(value=1))
ret = groupNorm(X) ret = groupNorm(X)
static_ret2 = self.get_static_graph_result( static_ret2 = self.get_static_graph_result(
feed={ feed={
...@@ -1141,7 +1151,12 @@ class TestLayer(LayerTest): ...@@ -1141,7 +1151,12 @@ class TestLayer(LayerTest):
with_lod=True)[0] with_lod=True)[0]
with self.dynamic_graph(): with self.dynamic_graph():
groupNorm = nn.GroupNorm(channels=shape[1], groups=2) groupNorm = nn.GroupNorm(
channels=shape[1],
groups=2,
param_attr=fluid.initializer.Uniform(
low=-0.5, high=0.5),
bias_attr=fluid.initializer.ConstantInitializer(value=1))
dy_ret = groupNorm(base.to_variable(input)) dy_ret = groupNorm(base.to_variable(input))
dy_rlt_value = dy_ret.numpy() dy_rlt_value = dy_ret.numpy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册