diff --git a/paddle/phi/kernels/cpu/graph_reindex_kernel.cc b/paddle/phi/kernels/cpu/graph_reindex_kernel.cc index 428bcb031704cdfb53a8a035a96260d0575b292f..21fb8fb16977f45554b4a531752fbe47c5801c1e 100644 --- a/paddle/phi/kernels/cpu/graph_reindex_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_reindex_kernel.cc @@ -42,6 +42,10 @@ void GraphReindexKernel(const Context& dev_ctx, std::unordered_map node_map; std::vector unique_nodes; int reindex_id = 0; + PADDLE_ENFORCE_NE( + 0, + bs, + errors::InvalidArgument("The first of dims should not be equal to 0.")); for (int i = 0; i < bs; i++) { T node = x_data[i]; unique_nodes.emplace_back(node); diff --git a/paddle/phi/kernels/gpu/graph_reindex_kernel.cu b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu index f9a6bf2f682629ed5f8b5bc7993d1cba6bda0853..23094eaa42bf951088208e654af7a3ccb23a7055 100644 --- a/paddle/phi/kernels/gpu/graph_reindex_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu @@ -389,6 +389,10 @@ void GraphReindexKernel(const Context& dev_ctx, const T* neighbors_data = neighbors.data(); const int* count_data = count.data(); const int bs = x.dims()[0]; + PADDLE_ENFORCE_NE( + 0, + bs, + errors::InvalidArgument("The first of dims should not be equal to 0.")); const int num_edges = neighbors.dims()[0]; reindex_src->Resize({num_edges}); diff --git a/python/paddle/fluid/tests/unittests/test_graph_reindex.py b/python/paddle/fluid/tests/unittests/test_graph_reindex.py index db767504559d89fdbe52686bc1a494b08427b1ca..275c49b1cbd47f4aaef6d06d360a7eb6ae7653e9 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_reindex.py +++ b/python/paddle/fluid/tests/unittests/test_graph_reindex.py @@ -212,6 +212,20 @@ class TestGraphReindex(unittest.TestCase): ) np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05) + def test_reindex_div_zero(self): + paddle.disable_static() + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32') + with self.assertRaises(ValueError): + paddle.incubate.graph_reindex( + x=x, + neighbors=x, + count=x, + value_buffer=x, + index_buffer=x, + flag_buffer_hashtable=False, + ) + class TestGeometricGraphReindex(unittest.TestCase): def setUp(self):