From 7a1cf277543fc67c9f7db423eca933e53577d035 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Wed, 2 Nov 2022 14:22:45 +0800 Subject: [PATCH] [geometric] Optimize graph sample speed (#47531) (#47548) --- .../gpu/graph_sample_neighbors_kernel.cu | 75 +++++++++---------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu index 6632d3f8b2e..3ea1dbc8e19 100644 --- a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu @@ -58,7 +58,7 @@ struct MaxFunctor { } }; -template +template __global__ void SampleKernel(const uint64_t rand_seed, int k, const int64_t num_nodes, @@ -71,8 +71,7 @@ __global__ void SampleKernel(const uint64_t rand_seed, T* output_eids, int* output_ptr, bool return_eids) { - assert(blockDim.x == WARP_SIZE); - assert(blockDim.y == BLOCK_WARPS); + assert(blockDim.x == CTA_SIZE); int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y; const int64_t last_row = @@ -80,13 +79,13 @@ __global__ void SampleKernel(const uint64_t rand_seed, #ifdef PADDLE_WITH_HIP hiprandState rng; hiprand_init(rand_seed * gridDim.x + blockIdx.x, - threadIdx.y * WARP_SIZE + threadIdx.x, + threadIdx.y * CTA_SIZE + threadIdx.x, 0, &rng); #else - curandState rng; + curandStatePhilox4_32_10_t rng; curand_init(rand_seed * gridDim.x + blockIdx.x, - threadIdx.y * WARP_SIZE + threadIdx.x, + threadIdx.y * CTA_SIZE + threadIdx.x, 0, &rng); #endif @@ -94,7 +93,7 @@ __global__ void SampleKernel(const uint64_t rand_seed, while (out_row < last_row) { T node = nodes[out_row]; if (node > len_col_ptr - 1) { - out_row += BLOCK_WARPS; + out_row += BLOCK_CTAS; continue; } T in_row_start = col_ptr[node]; @@ -102,21 +101,21 @@ __global__ void SampleKernel(const uint64_t rand_seed, int out_row_start = output_ptr[out_row]; if (deg <= k) { - for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) { + for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) { output[out_row_start + idx] = row[in_row_start + idx]; if (return_eids) { output_eids[out_row_start + idx] = eids[in_row_start + idx]; } } } else { - for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { + for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) { output[out_row_start + idx] = idx; } #ifdef PADDLE_WITH_CUDA - __syncwarp(); + __syncthreads(); #endif - for (int idx = k + threadIdx.x; idx < deg; idx += WARP_SIZE) { + for (int idx = k + threadIdx.x; idx < deg; idx += CTA_SIZE) { #ifdef PADDLE_WITH_HIP const int num = hiprand(&rng) % (idx + 1); #else @@ -129,10 +128,10 @@ __global__ void SampleKernel(const uint64_t rand_seed, } } #ifdef PADDLE_WITH_CUDA - __syncwarp(); + __syncthreads(); #endif - for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) { + for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) { T perm_idx = output[out_row_start + idx] + in_row_start; output[out_row_start + idx] = row[perm_idx]; if (return_eids) { @@ -141,7 +140,7 @@ __global__ void SampleKernel(const uint64_t rand_seed, } } - out_row += BLOCK_WARPS; + out_row += BLOCK_CTAS; } } @@ -181,12 +180,12 @@ void SampleNeighbors(const Context& dev_ctx, thrust::exclusive_scan( output_count, output_count + bs, output_ptr.begin(), 0); - constexpr int WARP_SIZE = 32; - constexpr int BLOCK_WARPS = 128 / WARP_SIZE; - constexpr int TILE_SIZE = BLOCK_WARPS * 16; - const dim3 block(WARP_SIZE, BLOCK_WARPS); + constexpr int CTA_SIZE = 128; + constexpr int BLOCK_CTAS = 128 / CTA_SIZE; + constexpr int TILE_SIZE = BLOCK_CTAS; + const dim3 block(CTA_SIZE, BLOCK_CTAS); const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE); - SampleKernel + SampleKernel <<>>( 0, sample_size, @@ -202,7 +201,7 @@ void SampleNeighbors(const Context& dev_ctx, return_eids); } -template +template __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, int k, const int64_t num_rows, @@ -210,8 +209,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, const T* in_rows, T* src, const T* dst_count) { - assert(blockDim.x == WARP_SIZE); - assert(blockDim.y == BLOCK_WARPS); + assert(blockDim.x == CTA_SIZE); int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y; const int64_t last_row = @@ -221,7 +219,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, hiprand_init( rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng); #else - curandState rng; + curandStatePhilox4_32_10_t rng; curand_init( rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng); #endif @@ -229,7 +227,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, while (out_row < last_row) { const T row = in_rows[out_row]; if (row > len_col_ptr - 1) { - out_row += BLOCK_WARPS; + out_row += BLOCK_CTAS; continue; } const T in_row_start = dst_count[row]; @@ -241,7 +239,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, } else { split = deg - k; } - for (int idx = split + threadIdx.x; idx <= deg - 1; idx += WARP_SIZE) { + for (int idx = split + threadIdx.x; idx <= deg - 1; idx += CTA_SIZE) { #ifdef PADDLE_WITH_HIP const int num = hiprand(&rng) % (idx + 1); #else @@ -254,14 +252,14 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, src[in_row_start + idx]))); } #ifdef PADDLE_WITH_CUDA - __syncwarp(); + __syncthreads(); #endif } - out_row += BLOCK_WARPS; + out_row += BLOCK_CTAS; } } -template +template __global__ void GatherEdge(int k, int64_t num_rows, const T* in_rows, @@ -273,8 +271,7 @@ __global__ void GatherEdge(int k, int* output_ptr, T* perm_data, bool return_eids) { - assert(blockDim.x == WARP_SIZE); - assert(blockDim.y == BLOCK_WARPS); + assert(blockDim.x == CTA_SIZE); int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y; const int64_t last_row = @@ -287,7 +284,7 @@ __global__ void GatherEdge(int k, const T out_row_start = output_ptr[out_row]; if (deg <= k) { - for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) { + for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) { outputs[out_row_start + idx] = src[in_row_start + idx]; if (return_eids) { output_eids[out_row_start + idx] = eids[in_row_start + idx]; @@ -304,7 +301,7 @@ __global__ void GatherEdge(int k, end = deg; } - for (int idx = begin + threadIdx.x; idx < end; idx += WARP_SIZE) { + for (int idx = begin + threadIdx.x; idx < end; idx += CTA_SIZE) { outputs[out_row_start + idx - begin] = src[perm_data[in_row_start + idx]]; if (return_eids) { @@ -313,7 +310,7 @@ __global__ void GatherEdge(int k, } } } - out_row += BLOCK_WARPS; + out_row += BLOCK_CTAS; } } @@ -337,13 +334,13 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx, thrust::exclusive_scan( output_count, output_count + bs, output_ptr.begin(), 0); - constexpr int WARP_SIZE = 32; - constexpr int BLOCK_WARPS = 128 / WARP_SIZE; - constexpr int TILE_SIZE = BLOCK_WARPS * 16; - const dim3 block(WARP_SIZE, BLOCK_WARPS); + constexpr int CTA_SIZE = 128; + constexpr int BLOCK_CTAS = 128 / CTA_SIZE; + constexpr int TILE_SIZE = BLOCK_CTAS; + const dim3 block(CTA_SIZE, BLOCK_CTAS); const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE); - FisherYatesSampleKernel + FisherYatesSampleKernel <<>>(0, sample_size, bs, @@ -352,7 +349,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx, perm_data, col_ptr); - GatherEdge + GatherEdge <<>>( sample_size, bs, -- GitLab