Unverified commit 7a1cf277, authored by Siming Dai, committed by GitHub

[geometric] Optimize graph sample speed (#47531) (#47548)

Parent 61953b90
@@ -58,7 +58,7 @@ struct MaxFunctor {
   }
 };
-template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
+template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
 __global__ void SampleKernel(const uint64_t rand_seed,
                              int k,
                              const int64_t num_nodes,
@@ -71,8 +71,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
                              T* output_eids,
                              int* output_ptr,
                              bool return_eids) {
-  assert(blockDim.x == WARP_SIZE);
-  assert(blockDim.y == BLOCK_WARPS);
+  assert(blockDim.x == CTA_SIZE);
   int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
   const int64_t last_row =
@@ -80,13 +79,13 @@ __global__ void SampleKernel(const uint64_t rand_seed,
 #ifdef PADDLE_WITH_HIP
   hiprandState rng;
   hiprand_init(rand_seed * gridDim.x + blockIdx.x,
-               threadIdx.y * WARP_SIZE + threadIdx.x,
+               threadIdx.y * CTA_SIZE + threadIdx.x,
                0,
                &rng);
 #else
-  curandState rng;
+  curandStatePhilox4_32_10_t rng;
   curand_init(rand_seed * gridDim.x + blockIdx.x,
-              threadIdx.y * WARP_SIZE + threadIdx.x,
+              threadIdx.y * CTA_SIZE + threadIdx.x,
               0,
               &rng);
 #endif
@@ -94,7 +93,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
   while (out_row < last_row) {
     T node = nodes[out_row];
     if (node > len_col_ptr - 1) {
-      out_row += BLOCK_WARPS;
+      out_row += BLOCK_CTAS;
       continue;
     }
     T in_row_start = col_ptr[node];
@@ -102,21 +101,21 @@ __global__ void SampleKernel(const uint64_t rand_seed,
     int out_row_start = output_ptr[out_row];
     if (deg <= k) {
-      for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
+      for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
         output[out_row_start + idx] = row[in_row_start + idx];
         if (return_eids) {
           output_eids[out_row_start + idx] = eids[in_row_start + idx];
         }
       }
     } else {
-      for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
+      for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
         output[out_row_start + idx] = idx;
       }
 #ifdef PADDLE_WITH_CUDA
-      __syncwarp();
+      __syncthreads();
 #endif
-      for (int idx = k + threadIdx.x; idx < deg; idx += WARP_SIZE) {
+      for (int idx = k + threadIdx.x; idx < deg; idx += CTA_SIZE) {
 #ifdef PADDLE_WITH_HIP
         const int num = hiprand(&rng) % (idx + 1);
 #else
@@ -129,10 +128,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
         }
       }
 #ifdef PADDLE_WITH_CUDA
-      __syncwarp();
+      __syncthreads();
 #endif
-      for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
+      for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
         T perm_idx = output[out_row_start + idx] + in_row_start;
         output[out_row_start + idx] = row[perm_idx];
         if (return_eids) {
@@ -141,7 +140,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
       }
     }
-    out_row += BLOCK_WARPS;
+    out_row += BLOCK_CTAS;
   }
 }
@@ -181,12 +180,12 @@ void SampleNeighbors(const Context& dev_ctx,
   thrust::exclusive_scan(
       output_count, output_count + bs, output_ptr.begin(), 0);
-  constexpr int WARP_SIZE = 32;
-  constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
-  constexpr int TILE_SIZE = BLOCK_WARPS * 16;
-  const dim3 block(WARP_SIZE, BLOCK_WARPS);
+  constexpr int CTA_SIZE = 128;
+  constexpr int BLOCK_CTAS = 128 / CTA_SIZE;
+  constexpr int TILE_SIZE = BLOCK_CTAS;
+  const dim3 block(CTA_SIZE, BLOCK_CTAS);
   const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
-  SampleKernel<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
+  SampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
       <<<grid, block, 0, dev_ctx.stream()>>>(
           0,
           sample_size,
@@ -202,7 +201,7 @@ void SampleNeighbors(const Context& dev_ctx,
           return_eids);
 }
-template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
+template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
 __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
                                         int k,
                                         const int64_t num_rows,
@@ -210,8 +209,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
                                         const T* in_rows,
                                         T* src,
                                         const T* dst_count) {
-  assert(blockDim.x == WARP_SIZE);
-  assert(blockDim.y == BLOCK_WARPS);
+  assert(blockDim.x == CTA_SIZE);
   int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
   const int64_t last_row =
@@ -221,7 +219,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
   hiprand_init(
       rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
 #else
-  curandState rng;
+  curandStatePhilox4_32_10_t rng;
   curand_init(
       rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
 #endif
@@ -229,7 +227,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
   while (out_row < last_row) {
     const T row = in_rows[out_row];
     if (row > len_col_ptr - 1) {
-      out_row += BLOCK_WARPS;
+      out_row += BLOCK_CTAS;
       continue;
     }
     const T in_row_start = dst_count[row];
@@ -241,7 +239,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
       } else {
         split = deg - k;
       }
-      for (int idx = split + threadIdx.x; idx <= deg - 1; idx += WARP_SIZE) {
+      for (int idx = split + threadIdx.x; idx <= deg - 1; idx += CTA_SIZE) {
 #ifdef PADDLE_WITH_HIP
         const int num = hiprand(&rng) % (idx + 1);
 #else
@@ -254,14 +252,14 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
                                    src[in_row_start + idx])));
       }
 #ifdef PADDLE_WITH_CUDA
-      __syncwarp();
+      __syncthreads();
 #endif
     }
-    out_row += BLOCK_WARPS;
+    out_row += BLOCK_CTAS;
   }
 }
-template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
+template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
 __global__ void GatherEdge(int k,
                            int64_t num_rows,
                            const T* in_rows,
@@ -273,8 +271,7 @@ __global__ void GatherEdge(int k,
                            int* output_ptr,
                            T* perm_data,
                            bool return_eids) {
-  assert(blockDim.x == WARP_SIZE);
-  assert(blockDim.y == BLOCK_WARPS);
+  assert(blockDim.x == CTA_SIZE);
   int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
   const int64_t last_row =
@@ -287,7 +284,7 @@ __global__ void GatherEdge(int k,
     const T out_row_start = output_ptr[out_row];
     if (deg <= k) {
-      for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
+      for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
         outputs[out_row_start + idx] = src[in_row_start + idx];
         if (return_eids) {
           output_eids[out_row_start + idx] = eids[in_row_start + idx];
@@ -304,7 +301,7 @@ __global__ void GatherEdge(int k,
         end = deg;
       }
-      for (int idx = begin + threadIdx.x; idx < end; idx += WARP_SIZE) {
+      for (int idx = begin + threadIdx.x; idx < end; idx += CTA_SIZE) {
         outputs[out_row_start + idx - begin] =
             src[perm_data[in_row_start + idx]];
         if (return_eids) {
@@ -313,7 +310,7 @@ __global__ void GatherEdge(int k,
         }
       }
     }
-    out_row += BLOCK_WARPS;
+    out_row += BLOCK_CTAS;
   }
 }
@@ -337,13 +334,13 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
   thrust::exclusive_scan(
       output_count, output_count + bs, output_ptr.begin(), 0);
-  constexpr int WARP_SIZE = 32;
-  constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
-  constexpr int TILE_SIZE = BLOCK_WARPS * 16;
-  const dim3 block(WARP_SIZE, BLOCK_WARPS);
+  constexpr int CTA_SIZE = 128;
+  constexpr int BLOCK_CTAS = 128 / CTA_SIZE;
+  constexpr int TILE_SIZE = BLOCK_CTAS;
+  const dim3 block(CTA_SIZE, BLOCK_CTAS);
   const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
-  FisherYatesSampleKernel<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
+  FisherYatesSampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
       <<<grid, block, 0, dev_ctx.stream()>>>(0,
                                              sample_size,
                                              bs,
@@ -352,7 +349,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
                                              perm_data,
                                              col_ptr);
-  GatherEdge<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
+  GatherEdge<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
       <<<grid, block, 0, dev_ctx.stream()>>>(
           sample_size,
          bs,
...
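In summary, the diff replaces the warp-per-row layout (WARP_SIZE = 32, BLOCK_WARPS = 128 / 32 = 4, TILE_SIZE = 64) with a CTA-per-row layout (CTA_SIZE = 128, BLOCK_CTAS = 1, TILE_SIZE = 1), so each output row is now processed by a full 128-thread block; the per-thread RNG changes from the default curandState (XORWOW) to curandStatePhilox4_32_10_t, whose per-thread initialization is cheaper, and the intra-row barrier changes from __syncwarp() to __syncthreads() to match the wider cooperating group. The code below is a minimal standalone sketch of that block-per-row reservoir-sampling pattern under these parameters; it is an illustration only, not the Paddle kernel, and the kernel name ToyReservoirSampleKernel and its simplified int-typed CSR arguments are hypothetical.

// Minimal sketch (NOT the Paddle kernel): one 128-thread block (CTA)
// cooperates on one output row, striding over that row's neighbors,
// with a Philox RNG state per thread and block-wide synchronization.
#include <curand_kernel.h>
#include <cstdint>

constexpr int CTA_SIZE = 128;               // threads cooperating on one row
constexpr int BLOCK_CTAS = 128 / CTA_SIZE;  // = 1, rows handled per block
constexpr int TILE_SIZE = BLOCK_CTAS;

__global__ void ToyReservoirSampleKernel(uint64_t rand_seed,
                                         int k,
                                         int64_t num_rows,
                                         const int* col_ptr,
                                         const int* row,
                                         const int* output_ptr,
                                         int* output) {
  int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
  if (out_row >= num_rows) return;

  // Philox is a counter-based generator; its per-thread init is cheaper
  // than the default XORWOW curandState used before this change.
  curandStatePhilox4_32_10_t rng;
  curand_init(rand_seed * gridDim.x + blockIdx.x,
              threadIdx.y * CTA_SIZE + threadIdx.x,
              0,
              &rng);

  const int in_row_start = col_ptr[out_row];
  const int deg = col_ptr[out_row + 1] - in_row_start;
  const int out_row_start = output_ptr[out_row];

  if (deg <= k) {  // fewer neighbors than requested: copy them all
    for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
      output[out_row_start + idx] = row[in_row_start + idx];
    }
    return;
  }
  // Reservoir sampling: seed the reservoir with local indices 0..k-1,
  // then let later candidates replace a random slot (concurrent writes
  // resolved here with atomicMax, a common choice in GPU samplers).
  for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
    output[out_row_start + idx] = idx;
  }
  __syncthreads();  // the whole CTA cooperates, so block-wide sync
  for (int idx = k + threadIdx.x; idx < deg; idx += CTA_SIZE) {
    const int num = curand(&rng) % (idx + 1);
    if (num < k) {
      atomicMax(output + out_row_start + num, idx);
    }
  }
  __syncthreads();
  // Map the sampled local indices back to actual neighbor ids.
  for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
    output[out_row_start + idx] =
        row[in_row_start + output[out_row_start + idx]];
  }
}

A launch of this sketch would mirror the host-side setup in SampleNeighbors above: dim3 block(CTA_SIZE, BLOCK_CTAS); dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE); ToyReservoirSampleKernel<<<grid, block, 0, stream>>>(...).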