diff --git a/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py b/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py index 34b7a95299f3e05fb35157e0849f3f8732fb8ab9..84315ede556ac22b47b3708605a488e063c5ff32 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py @@ -36,6 +36,8 @@ class TestReshape(unittest.TestCase): paddle.sparse.reshape. """ mask = np.random.randint(0, 2, x_shape) + while np.sum(mask) == 0: + mask = paddle.randint(0, 2, x_shape) np_x = np.random.randint(-100, 100, x_shape) * mask # check cpu kernel diff --git a/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py b/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py index 32772d389138300786b2057a184b21690f7704a6..74a9d7e3308b3d3089871fc66b3f4b17f804c33c 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py @@ -23,6 +23,8 @@ class TestTranspose(unittest.TestCase): # x: sparse, out: sparse def check_result(self, x_shape, dims, format): mask = paddle.randint(0, 2, x_shape).astype("float32") + while paddle.sum(mask) == 0: + mask = paddle.randint(0, 2, x_shape).astype("float32") # "+ 1" to make sure that all zero elements in "origin_x" is caused by multiplying by "mask", # or the backward checks may fail. origin_x = (paddle.rand(x_shape, dtype='float32') + 1) * mask diff --git a/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py b/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py index edb7393bf305631e25af4d0dff8eaaafcc676c65..908121ace01181faf339f2f618f3d1fc4ccde86b 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py @@ -30,6 +30,8 @@ class TestSparseUnary(unittest.TestCase): def check_result(self, dense_func, sparse_func, format, *args): origin_x = paddle.rand([8, 16, 32], dtype='float32') mask = paddle.randint(0, 2, [8, 16, 32]).astype('float32') + while paddle.sum(mask) == 0: + mask = paddle.randint(0, 2, [8, 16, 32]).astype("float32") # --- check sparse coo with dense --- # dense_x = origin_x * mask