From ef599afe52aa49f0b06a6a5652990c6e397f7984 Mon Sep 17 00:00:00 2001 From: chenxujun Date: Fri, 17 Mar 2023 11:55:39 +0800 Subject: [PATCH] Fix paddle.incubate.graph_reindex divide by 0 error (#51714) --- paddle/phi/kernels/cpu/graph_reindex_kernel.cc | 4 ++++ paddle/phi/kernels/gpu/graph_reindex_kernel.cu | 4 ++++ .../fluid/tests/unittests/test_graph_reindex.py | 14 ++++++++++++++ 3 files changed, 22 insertions(+) diff --git a/paddle/phi/kernels/cpu/graph_reindex_kernel.cc b/paddle/phi/kernels/cpu/graph_reindex_kernel.cc index 428bcb03170..21fb8fb1697 100644 --- a/paddle/phi/kernels/cpu/graph_reindex_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_reindex_kernel.cc @@ -42,6 +42,10 @@ void GraphReindexKernel(const Context& dev_ctx, std::unordered_map node_map; std::vector unique_nodes; int reindex_id = 0; + PADDLE_ENFORCE_NE( + 0, + bs, + errors::InvalidArgument("The first of dims should not be equal to 0.")); for (int i = 0; i < bs; i++) { T node = x_data[i]; unique_nodes.emplace_back(node); diff --git a/paddle/phi/kernels/gpu/graph_reindex_kernel.cu b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu index f9a6bf2f682..23094eaa42b 100644 --- a/paddle/phi/kernels/gpu/graph_reindex_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu @@ -389,6 +389,10 @@ void GraphReindexKernel(const Context& dev_ctx, const T* neighbors_data = neighbors.data(); const int* count_data = count.data(); const int bs = x.dims()[0]; + PADDLE_ENFORCE_NE( + 0, + bs, + errors::InvalidArgument("The first of dims should not be equal to 0.")); const int num_edges = neighbors.dims()[0]; reindex_src->Resize({num_edges}); diff --git a/python/paddle/fluid/tests/unittests/test_graph_reindex.py b/python/paddle/fluid/tests/unittests/test_graph_reindex.py index db767504559..275c49b1cbd 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_reindex.py +++ b/python/paddle/fluid/tests/unittests/test_graph_reindex.py @@ -212,6 +212,20 @@ class TestGraphReindex(unittest.TestCase): ) np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05) + def test_reindex_div_zero(self): + paddle.disable_static() + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32') + with self.assertRaises(ValueError): + paddle.incubate.graph_reindex( + x=x, + neighbors=x, + count=x, + value_buffer=x, + index_buffer=x, + flag_buffer_hashtable=False, + ) + class TestGeometricGraphReindex(unittest.TestCase): def setUp(self): -- GitLab