未验证 提交 82eb486e 编写于 作者: Z zhang wenhui 提交者: GitHub

fix test_group_norm, test=develop (#27929)

上级 766b3515
......@@ -31,24 +31,24 @@ class TestDygraphGroupNormv2(unittest.TestCase):
if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0))
for p in places:
shape = [2, 6, 2, 2]
shape = [2, 2, 2, 2]
def compute_v1(x):
with fluid.dygraph.guard(p):
gn = fluid.dygraph.GroupNorm(channels=6, groups=2)
gn = fluid.dygraph.GroupNorm(channels=2, groups=2)
y = gn(fluid.dygraph.to_variable(x))
return y.numpy()
def compute_v2(x):
with fluid.dygraph.guard(p):
gn = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
gn = paddle.nn.GroupNorm(num_channels=2, num_groups=2)
y = gn(fluid.dygraph.to_variable(x))
return y.numpy()
def test_weight_bias_false():
with fluid.dygraph.guard(p):
gn = paddle.nn.GroupNorm(
num_channels=6,
num_channels=2,
num_groups=2,
weight_attr=False,
bias_attr=False)
......@@ -56,7 +56,7 @@ class TestDygraphGroupNormv2(unittest.TestCase):
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
result = np.allclose(y1, y2)
result = np.allclose(y1, y2, atol=1e-5)
if not result:
print("y1:", y1, "\ty2:", y2)
self.assertTrue(result)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册