未验证 提交 4ff8fca5 编写于 作者: U umiswing 提交者: GitHub

Fix test_sparse_norm_op failure. (#55405)

* Fix test failed on cudnn.

* Fix codestyle.
上级 cf76e7ae
...@@ -100,19 +100,23 @@ class TestSparseBatchNorm(unittest.TestCase): ...@@ -100,19 +100,23 @@ class TestSparseBatchNorm(unittest.TestCase):
else: else:
bn = paddle.nn.BatchNorm3D(shape[-1], data_format=data_format) bn = paddle.nn.BatchNorm3D(shape[-1], data_format=data_format)
y = bn(x) y = bn(x)
y.backward() np.random.seed(5)
loss_data = np.random.uniform(-0.01, 0.01, y.shape).astype("float32")
loss = paddle.to_tensor(loss_data)
y.backward(loss)
sp_x = paddle.to_tensor(data).to_sparse_coo(dim - 1) sp_x = paddle.to_tensor(data).to_sparse_coo(dim - 1)
sp_x.stop_gradient = False sp_x.stop_gradient = False
sp_bn = paddle.sparse.nn.BatchNorm(shape[-1], data_format=data_format) sp_bn = paddle.sparse.nn.BatchNorm(shape[-1], data_format=data_format)
sp_y = sp_bn(sp_x) sp_y = sp_bn(sp_x)
sp_y.backward() sp_loss = loss.to_sparse_coo(dim - 1)
sp_y.backward(sp_loss)
np.testing.assert_allclose( np.testing.assert_allclose(
y.numpy(), sp_y.to_dense().numpy(), rtol=1e-5 sp_y.to_dense().numpy(), y.numpy(), rtol=1e-5
) )
np.testing.assert_allclose( np.testing.assert_allclose(
x.grad.numpy(), sp_x.grad.to_dense().numpy(), rtol=1e-5 sp_x.grad.to_dense().numpy(), x.grad.numpy(), rtol=1e-5
) )
def test_nd(self): def test_nd(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册