未验证 提交 ef599afe 编写于 作者: C chenxujun 提交者: GitHub

Fix paddle.incubate.graph_reindex divide by 0 error (#51714)

上级 6aa3670f
...@@ -42,6 +42,10 @@ void GraphReindexKernel(const Context& dev_ctx, ...@@ -42,6 +42,10 @@ void GraphReindexKernel(const Context& dev_ctx,
std::unordered_map<T, T> node_map; std::unordered_map<T, T> node_map;
std::vector<T> unique_nodes; std::vector<T> unique_nodes;
int reindex_id = 0; int reindex_id = 0;
PADDLE_ENFORCE_NE(
0,
bs,
errors::InvalidArgument("The first of dims should not be equal to 0."));
for (int i = 0; i < bs; i++) { for (int i = 0; i < bs; i++) {
T node = x_data[i]; T node = x_data[i];
unique_nodes.emplace_back(node); unique_nodes.emplace_back(node);
......
...@@ -389,6 +389,10 @@ void GraphReindexKernel(const Context& dev_ctx, ...@@ -389,6 +389,10 @@ void GraphReindexKernel(const Context& dev_ctx,
const T* neighbors_data = neighbors.data<T>(); const T* neighbors_data = neighbors.data<T>();
const int* count_data = count.data<int>(); const int* count_data = count.data<int>();
const int bs = x.dims()[0]; const int bs = x.dims()[0];
PADDLE_ENFORCE_NE(
0,
bs,
errors::InvalidArgument("The first of dims should not be equal to 0."));
const int num_edges = neighbors.dims()[0]; const int num_edges = neighbors.dims()[0];
reindex_src->Resize({num_edges}); reindex_src->Resize({num_edges});
......
...@@ -212,6 +212,20 @@ class TestGraphReindex(unittest.TestCase): ...@@ -212,6 +212,20 @@ class TestGraphReindex(unittest.TestCase):
) )
np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05) np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05)
def test_reindex_div_zero(self):
paddle.disable_static()
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32')
with self.assertRaises(ValueError):
paddle.incubate.graph_reindex(
x=x,
neighbors=x,
count=x,
value_buffer=x,
index_buffer=x,
flag_buffer_hashtable=False,
)
class TestGeometricGraphReindex(unittest.TestCase): class TestGeometricGraphReindex(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册