diff --git a/paddle/phi/kernels/cpu/graph_reindex_kernel.cc b/paddle/phi/kernels/cpu/graph_reindex_kernel.cc index c0a88f3222717fc02da1e445bdacf99ef8f14ffd..92f2dc41e65fb114e0170a4bf818bf26d420c3c2 100644 --- a/paddle/phi/kernels/cpu/graph_reindex_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_reindex_kernel.cc @@ -59,11 +59,15 @@ void GraphReindexKernel(const Context& dev_ctx, src[i] = node_map[node]; } // Reindex Dst + // Add support for multi-type edges reindex + int num_edge_types = count.dims()[0] / bs; int cnt = 0; - for (int i = 0; i < bs; i++) { - for (int j = 0; j < count_data[i]; j++) { - T node = x_data[i]; - dst[cnt++] = node_map[node]; + for (int i = 0; i < num_edge_types; i++) { + for (int j = 0; j < bs; j++) { + for (int k = 0; k < count_data[i * bs + j]; k++) { + T node = x_data[j]; + dst[cnt++] = node_map[node]; + } } } diff --git a/paddle/phi/kernels/gpu/graph_reindex_kernel.cu b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu index 9869d5a517bcbc0ebd5a63ae4ca21d29067c3293..9c6c83a738f89b67eaca29caf63a786657f4580c 100644 --- a/paddle/phi/kernels/gpu/graph_reindex_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu @@ -331,26 +331,37 @@ void GraphReindexKernel(const Context& dev_ctx, } // Get reindex dst edge. + // Add support for multi-type edges reindex. + int num_ac_count = count.dims()[0]; + int num_edge_types = num_ac_count / bs; thrust::device_vector unique_dst_reindex(bs); thrust::sequence(unique_dst_reindex.begin(), unique_dst_reindex.end()); - thrust::device_vector dst_ptr(bs); - thrust::exclusive_scan(count_data, count_data + bs, dst_ptr.begin()); constexpr int BLOCK_WARPS = 128 / WARP_SIZE; constexpr int TILE_SIZE = BLOCK_WARPS * 16; const dim3 block(WARP_SIZE, BLOCK_WARPS); const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE); - reindex_dst->Resize({num_edges}); T* reindex_dst_data = dev_ctx.template Alloc(reindex_dst); - - GetDstEdgeCUDAKernel<<>>( - bs, - thrust::raw_pointer_cast(unique_dst_reindex.data()), - count_data, - thrust::raw_pointer_cast(dst_ptr.data()), - reindex_dst_data); + int begin = 0; + for (int i = 0; i < num_edge_types; i++) { + thrust::device_vector dst_ptr(bs); + thrust::exclusive_scan( + count_data + i * bs, count_data + (i + 1) * bs, dst_ptr.begin()); + + GetDstEdgeCUDAKernel<<>>( + bs, + thrust::raw_pointer_cast(unique_dst_reindex.data()), + count_data + i * bs, + thrust::raw_pointer_cast(dst_ptr.data()), + reindex_dst_data + begin); + + int count_i = + thrust::reduce(thrust::device_pointer_cast(count_data) + i * bs, + thrust::device_pointer_cast(count_data) + (i + 1) * bs); + begin += count_i; + } out_nodes->Resize({static_cast(unique_nodes.size())}); T* out_nodes_data = dev_ctx.template Alloc(out_nodes); diff --git a/python/paddle/fluid/tests/unittests/test_graph_reindex.py b/python/paddle/fluid/tests/unittests/test_graph_reindex.py index 52abbbe81aef93046b4b2458ad350eeac6a5bba3..4a98beb0cceb9917c332c71dccb8506e2f13e2bc 100644 --- a/python/paddle/fluid/tests/unittests/test_graph_reindex.py +++ b/python/paddle/fluid/tests/unittests/test_graph_reindex.py @@ -62,6 +62,63 @@ class TestGraphReindex(unittest.TestCase): self.assertTrue(np.allclose(self.reindex_dst, reindex_dst)) self.assertTrue(np.allclose(self.out_nodes, out_nodes)) + def test_heter_reindex_result(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + neighbors = paddle.to_tensor(self.neighbors) + neighbors = paddle.concat([neighbors, neighbors]) + count = paddle.to_tensor(self.count) + count = paddle.concat([count, count]) + + reindex_src, reindex_dst, out_nodes = \ + paddle.incubate.graph_reindex(x, neighbors, count) + self.assertTrue( + np.allclose(self.reindex_src, reindex_src[:self.neighbors.shape[ + 0]])) + self.assertTrue( + np.allclose(self.reindex_src, reindex_src[self.neighbors.shape[ + 0]:])) + self.assertTrue( + np.allclose(self.reindex_dst, reindex_dst[:self.neighbors.shape[ + 0]])) + self.assertTrue( + np.allclose(self.reindex_dst, reindex_dst[self.neighbors.shape[ + 0]:])) + self.assertTrue(np.allclose(self.out_nodes, out_nodes)) + + def test_heter_reindex_result_v2(self): + paddle.disable_static() + x = np.arange(5).astype("int64") + neighbors1 = np.random.randint(100, size=20).astype("int64") + count1 = np.array([2, 8, 4, 3, 3], dtype="int32") + neighbors2 = np.random.randint(100, size=20).astype("int64") + count2 = np.array([4, 5, 1, 6, 4], dtype="int32") + neighbors = np.concatenate([neighbors1, neighbors2]) + counts = np.concatenate([count1, count2]) + + # Get numpy result. + out_nodes = list(x) + for neighbor in neighbors: + if neighbor not in out_nodes: + out_nodes.append(neighbor) + out_nodes = np.array(out_nodes, dtype="int64") + reindex_dict = {node: ind for ind, node in enumerate(out_nodes)} + reindex_src = np.array([reindex_dict[node] for node in neighbors]) + reindex_dst = [] + for count in [count1, count2]: + for node, c in zip(x, count): + for i in range(c): + reindex_dst.append(reindex_dict[node]) + reindex_dst = np.array(reindex_dst, dtype="int64") + + reindex_src_, reindex_dst_, out_nodes_ = \ + paddle.incubate.graph_reindex(paddle.to_tensor(x), + paddle.to_tensor(neighbors), + paddle.to_tensor(counts)) + self.assertTrue(np.allclose(reindex_src, reindex_src_)) + self.assertTrue(np.allclose(reindex_dst, reindex_dst_)) + self.assertTrue(np.allclose(out_nodes, out_nodes_)) + def test_reindex_result_static(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): diff --git a/python/paddle/incubate/operators/graph_khop_sampler.py b/python/paddle/incubate/operators/graph_khop_sampler.py index 5442b213ceb476f83f411dd67c7c1fc38eb4bc41..64aecca8411abc7a5c755ec998a55b22e211dc58 100644 --- a/python/paddle/incubate/operators/graph_khop_sampler.py +++ b/python/paddle/incubate/operators/graph_khop_sampler.py @@ -38,10 +38,6 @@ def graph_khop_sampler(row, and `sample_sizes` means the number of neighbors and number of layers we want to sample. - **Note**: - Currently the API will reindex the output edges after finishing sampling. We - will add a choice or a new API for whether to reindex the edges in the near future. - Args: row (Tensor): One of the components of the CSC format of the input graph, and the shape should be [num_edges, 1] or [num_edges]. The available diff --git a/python/paddle/incubate/operators/graph_reindex.py b/python/paddle/incubate/operators/graph_reindex.py index 328b87a699750c173b12f3545c8cde6107c1c969..4cfd96ebf44562764be13f521e1d74e051642e09 100644 --- a/python/paddle/incubate/operators/graph_reindex.py +++ b/python/paddle/incubate/operators/graph_reindex.py @@ -35,6 +35,12 @@ def graph_reindex(x, is to reindex the ids information of the input nodes, and return the corresponding graph edges after reindex. + **Notes**: + The number in x should be unique, otherwise it would cause potential errors. + Besides, we also support multi-edge-types neighbors reindexing. If we have different + edge_type neighbors for x, we should concatenate all the neighbors and count of x. + We will reindex all the nodes from 0. + Take input nodes x = [0, 1, 2] as an example. If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2], then we know that the neighbors of 0 is [8, 9], the neighbors of 1 @@ -70,18 +76,31 @@ def graph_reindex(x, import paddle x = [0, 1, 2] - neighbors = [8, 9, 0, 4, 7, 6, 7] - count = [2, 3, 2] + neighbors_e1 = [8, 9, 0, 4, 7, 6, 7] + count_e1 = [2, 3, 2] x = paddle.to_tensor(x, dtype="int64") - neighbors = paddle.to_tensor(neighbors, dtype="int64") - count = paddle.to_tensor(count, dtype="int32") + neighbors_e1 = paddle.to_tensor(neighbors_e1, dtype="int64") + count_e1 = paddle.to_tensor(count_e1, dtype="int32") reindex_src, reindex_dst, out_nodes = \ - paddle.incubate.graph_reindex(x, neighbors, count) + paddle.incubate.graph_reindex(x, neighbors_e1, count_e1) # reindex_src: [3, 4, 0, 5, 6, 7, 6] # reindex_dst: [0, 0, 1, 1, 1, 2, 2] # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6] + neighbors_e2 = [0, 2, 3, 5, 1] + count_e2 = [1, 3, 1] + neighbors_e2 = paddle.to_tensor(neighbors_e2, dtype="int64") + count_e2 = paddle.to_tensor(count_e2, dtype="int32") + + neighbors = paddle.concat([neighbors_e1, neighbors_e2]) + count = paddle.concat([count_e1, count_e2]) + reindex_src, reindex_dst, out_nodes = \ + paddle.incubate.graph_reindex(x, neighbors, count) + # reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1] + # reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2] + # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5] + """ if flag_buffer_hashtable: if value_buffer is None or index_buffer is None: