From 8f5eae47cd13e5615cc5abafe95560370fb4f489 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <46243324+zrr1999@users.noreply.github.com> Date: Fri, 28 Apr 2023 14:35:03 +0800 Subject: [PATCH] [Bug fixes] Fix bugs in some sparse test (#53428) --- python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py | 2 ++ python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py | 2 ++ python/paddle/fluid/tests/unittests/test_sparse_unary_op.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py b/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py index 34b7a95299f..84315ede556 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_reshape_op.py @@ -36,6 +36,8 @@ class TestReshape(unittest.TestCase): paddle.sparse.reshape. """ mask = np.random.randint(0, 2, x_shape) + while np.sum(mask) == 0: + mask = paddle.randint(0, 2, x_shape) np_x = np.random.randint(-100, 100, x_shape) * mask # check cpu kernel diff --git a/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py b/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py index 32772d38913..74a9d7e3308 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_transpose_op.py @@ -23,6 +23,8 @@ class TestTranspose(unittest.TestCase): # x: sparse, out: sparse def check_result(self, x_shape, dims, format): mask = paddle.randint(0, 2, x_shape).astype("float32") + while paddle.sum(mask) == 0: + mask = paddle.randint(0, 2, x_shape).astype("float32") # "+ 1" to make sure that all zero elements in "origin_x" is caused by multiplying by "mask", # or the backward checks may fail. origin_x = (paddle.rand(x_shape, dtype='float32') + 1) * mask diff --git a/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py b/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py index edb7393bf30..908121ace01 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py @@ -30,6 +30,8 @@ class TestSparseUnary(unittest.TestCase): def check_result(self, dense_func, sparse_func, format, *args): origin_x = paddle.rand([8, 16, 32], dtype='float32') mask = paddle.randint(0, 2, [8, 16, 32]).astype('float32') + while paddle.sum(mask) == 0: + mask = paddle.randint(0, 2, [8, 16, 32]).astype("float32") # --- check sparse coo with dense --- # dense_x = origin_x * mask -- GitLab