未验证 提交 8f5eae47 编写于 作者: Z Zhan Rongrui 提交者: GitHub

[Bug fixes] Fix bugs in some sparse test (#53428)

上级 4ccbcce5
......@@ -36,6 +36,8 @@ class TestReshape(unittest.TestCase):
paddle.sparse.reshape.
"""
mask = np.random.randint(0, 2, x_shape)
while np.sum(mask) == 0:
mask = paddle.randint(0, 2, x_shape)
np_x = np.random.randint(-100, 100, x_shape) * mask
# check cpu kernel
......
......@@ -23,6 +23,8 @@ class TestTranspose(unittest.TestCase):
# x: sparse, out: sparse
def check_result(self, x_shape, dims, format):
mask = paddle.randint(0, 2, x_shape).astype("float32")
while paddle.sum(mask) == 0:
mask = paddle.randint(0, 2, x_shape).astype("float32")
# "+ 1" to make sure that all zero elements in "origin_x" is caused by multiplying by "mask",
# or the backward checks may fail.
origin_x = (paddle.rand(x_shape, dtype='float32') + 1) * mask
......
......@@ -30,6 +30,8 @@ class TestSparseUnary(unittest.TestCase):
def check_result(self, dense_func, sparse_func, format, *args):
origin_x = paddle.rand([8, 16, 32], dtype='float32')
mask = paddle.randint(0, 2, [8, 16, 32]).astype('float32')
while paddle.sum(mask) == 0:
mask = paddle.randint(0, 2, [8, 16, 32]).astype("float32")
# --- check sparse coo with dense --- #
dense_x = origin_x * mask
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册