diff --git a/paddle/fluid/operators/graph_khop_sampler_op.cu b/paddle/fluid/operators/graph_khop_sampler_op.cu index 39767b5e20a8783568f1fb66bbd46fa08f79726b..c9e4dac74a85a919999a96186c83d64fc3d1c0d3 100644 --- a/paddle/fluid/operators/graph_khop_sampler_op.cu +++ b/paddle/fluid/operators/graph_khop_sampler_op.cu @@ -423,6 +423,31 @@ class GraphKhopSamplerOpCUDAKernel : public framework::OpKernel { std::vector sample_sizes = ctx.Attr>("sample_sizes"); bool return_eids = ctx.Attr("return_eids"); + auto row_dims = src->dims(); + auto row_dims_lens = row_dims.size(); + auto col_dims = dst_count->dims(); + auto col_dims_lens = col_dims.size(); + auto x_dims = vertices->dims(); + auto x_dims_lens = x_dims.size(); + for (int i = 0; i < row_dims_lens; i++) { + PADDLE_ENFORCE_NE( + row_dims[i], + 0, + phi::errors::InvalidArgument("The size of Row(X) should not be 0.")); + } + for (int i = 0; i < col_dims_lens; i++) { + PADDLE_ENFORCE_NE(col_dims[i], + 0, + phi::errors::InvalidArgument( + "The size of Col_Ptr(X) should not be 0.")); + } + for (int i = 0; i < x_dims_lens; i++) { + PADDLE_ENFORCE_NE(x_dims[i], + 0, + phi::errors::InvalidArgument( + "The size of Input_Node(X) should not be 0.")); + } + const T* src_data = src->data(); const T* dst_count_data = dst_count->data(); const T* p_vertices = vertices->data(); diff --git a/paddle/fluid/operators/graph_khop_sampler_op.h b/paddle/fluid/operators/graph_khop_sampler_op.h index f5ec87f23c88bf93a6e7ac9ffb95199d5bbc0430..a22b7a6ee20d8f4531728ffbaabd0a37e2cecbec 100644 --- a/paddle/fluid/operators/graph_khop_sampler_op.h +++ b/paddle/fluid/operators/graph_khop_sampler_op.h @@ -199,6 +199,31 @@ class GraphKhopSamplerOpKernel : public framework::OpKernel { auto* src = ctx.Input("Row"); auto* dst_count = ctx.Input("Col_Ptr"); auto* vertices = ctx.Input("X"); + auto row_dims = src->dims(); + auto row_dims_lens = row_dims.size(); + auto col_dims = dst_count->dims(); + auto col_dims_lens = col_dims.size(); + auto x_dims = vertices->dims(); + auto x_dims_lens = x_dims.size(); + for (int i = 0; i < row_dims_lens; i++) { + PADDLE_ENFORCE_NE( + row_dims[i], + 0, + phi::errors::InvalidArgument("The size of Row(X) should not be 0.")); + } + for (int i = 0; i < col_dims_lens; i++) { + PADDLE_ENFORCE_NE(col_dims[i], + 0, + phi::errors::InvalidArgument( + "The size of Col_Ptr(X) should not be 0.")); + } + for (int i = 0; i < x_dims_lens; i++) { + PADDLE_ENFORCE_NE(x_dims[i], + 0, + phi::errors::InvalidArgument( + "The size of Input_Node(X) should not be 0.")); + } + std::vector sample_sizes = ctx.Attr>("sample_sizes"); bool return_eids = ctx.Attr("return_eids"); diff --git a/python/paddle/fluid/tests/unittests/test_graph_khop_sampler.py b/python/paddle/fluid/tests/unittests/test_graph_khop_sampler.py index 019757518d6c9a28f341854dd0b572fb11113ada..4d4d42b0fa36133bae60a5317c59875bd527f3f0 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_khop_sampler.py +++ b/python/paddle/fluid/tests/unittests/test_graph_khop_sampler.py @@ -259,6 +259,36 @@ class TestGraphKhopSampler(unittest.TestCase): in_neighbors = np.isin(edge_src_n, self.dst_src_dict[n]) self.assertTrue(np.sum(in_neighbors) == in_neighbors.shape[0]) + def test_for_null_pointer_error(self): + def test_in_row(): + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32') + y = paddle.to_tensor([10], dtype='int32') + layer = paddle.incubate.graph_khop_sampler( + row=x, colptr=x, input_nodes=y, sample_sizes=[0] + ) + + def test_in_col(): + array = np.array([], dtype=np.float32) + x = paddle.to_tensor([10], dtype='int32') + col = paddle.to_tensor(np.reshape(array, [0]), dtype='int32') + y = paddle.to_tensor([10], dtype='int32') + layer = paddle.incubate.graph_khop_sampler( + row=x, colptr=col, input_nodes=y, sample_sizes=[0] + ) + + def test_in_input_nodes(): + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32') + y = paddle.to_tensor([10], dtype='int32') + layer = paddle.incubate.graph_khop_sampler( + row=y, colptr=y, input_nodes=x, sample_sizes=[0] + ) + + self.assertRaises(ValueError, test_in_row) + self.assertRaises(ValueError, test_in_col) + self.assertRaises(ValueError, test_in_input_nodes) + if __name__ == "__main__": unittest.main()