Unverified commit 2a932e55, authored by Siming Dai, committed by GitHub

[geometric] Optimize graph sample speed (#47531)

Parent commit: 32efda3d
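What the visible hunks change, in brief: the sampling kernels used to assign one 32-thread warp per output row (`WARP_SIZE`, `BLOCK_WARPS`, `__syncwarp()`); they now assign a full 128-thread block (CTA) per row (`CTA_SIZE`, `BLOCK_CTAS`, `__syncthreads()`), and the per-thread random state switches from the generic `curandState` (XORWOW) to the cheaper-to-initialize `curandStatePhilox4_32_10_t`. Below is a minimal, hypothetical sketch of that block-per-row pattern, not the file's actual kernel; the name `BlockPerRowSample` and the simplified `int` output type are illustrative assumptions.

```cuda
// Minimal, hypothetical sketch of the block-per-row pattern the diff moves to
// (not the actual Paddle kernel): one 128-thread CTA cooperates on one output
// row, strides over that row's neighbors, and uses the Philox RNG.
#include <assert.h>
#include <cstdint>
#include <curand_kernel.h>

template <int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
__global__ void BlockPerRowSample(uint64_t rand_seed,
                                  int k,                   // sample size per row
                                  int64_t num_rows,
                                  const int64_t* col_ptr,  // CSC offsets
                                  const int64_t* row,      // neighbor ids
                                  const int* output_ptr,   // scanned output offsets
                                  int* output) {           // sampled neighbors
  assert(blockDim.x == CTA_SIZE);  // only the x-dimension is asserted now
  int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
  int64_t last_row = (blockIdx.x + 1) * static_cast<int64_t>(TILE_SIZE);
  if (last_row > num_rows) last_row = num_rows;

  curandStatePhilox4_32_10_t rng;  // counter-based; cheaper to set up than curandState
  curand_init(rand_seed * gridDim.x + blockIdx.x,
              threadIdx.y * CTA_SIZE + threadIdx.x, 0, &rng);

  while (out_row < last_row) {
    int64_t begin = col_ptr[out_row];
    int deg = static_cast<int>(col_ptr[out_row + 1] - begin);
    int base = output_ptr[out_row];
    if (deg <= k) {
      // Fewer neighbors than requested: copy all of them, CTA-strided.
      for (int i = threadIdx.x; i < deg; i += CTA_SIZE)
        output[base + i] = static_cast<int>(row[begin + i]);
    } else {
      // Reservoir sampling of k indices from [0, deg).
      for (int i = threadIdx.x; i < k; i += CTA_SIZE) output[base + i] = i;
      __syncthreads();  // all 128 threads share the row, so a full barrier is needed
      for (int i = k + threadIdx.x; i < deg; i += CTA_SIZE) {
        int num = curand(&rng) % (i + 1);
        if (num < k) atomicMax(output + base + num, i);  // keep the larger index
      }
      __syncthreads();
      // Replace the sampled indices with the actual neighbor ids.
      for (int i = threadIdx.x; i < k; i += CTA_SIZE)
        output[base + i] = static_cast<int>(row[begin + output[base + i]]);
    }
    out_row += BLOCK_CTAS;  // advance by the number of rows each block covers per pass
  }
}
```

With the constants from the diff (CTA_SIZE = 128, BLOCK_CTAS = 1, TILE_SIZE = 1), such a kernel would be launched as `BlockPerRowSample<128, 1, 1><<<grid, dim3(128, 1), 0, stream>>>(...)`, one block per output row.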
@@ -58,7 +58,7 @@ struct MaxFunctor {
}
};
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
__global__ void SampleKernel(const uint64_t rand_seed,
int k,
const int64_t num_nodes,
@@ -71,8 +71,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
T* output_eids,
int* output_ptr,
bool return_eids) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
assert(blockDim.x == CTA_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
const int64_t last_row =
@@ -80,13 +79,13 @@ __global__ void SampleKernel(const uint64_t rand_seed,
#ifdef PADDLE_WITH_HIP
hiprandState rng;
hiprand_init(rand_seed * gridDim.x + blockIdx.x,
threadIdx.y * WARP_SIZE + threadIdx.x,
threadIdx.y * CTA_SIZE + threadIdx.x,
0,
&rng);
#else
curandState rng;
curandStatePhilox4_32_10_t rng;
curand_init(rand_seed * gridDim.x + blockIdx.x,
threadIdx.y * WARP_SIZE + threadIdx.x,
threadIdx.y * CTA_SIZE + threadIdx.x,
0,
&rng);
#endif
@@ -94,7 +93,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
while (out_row < last_row) {
T node = nodes[out_row];
if (node > len_col_ptr - 1) {
out_row += BLOCK_WARPS;
out_row += BLOCK_CTAS;
continue;
}
T in_row_start = col_ptr[node];
@@ -102,21 +101,21 @@ __global__ void SampleKernel(const uint64_t rand_seed,
int out_row_start = output_ptr[out_row];
if (deg <= k) {
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
output[out_row_start + idx] = row[in_row_start + idx];
if (return_eids) {
output_eids[out_row_start + idx] = eids[in_row_start + idx];
}
}
} else {
for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
output[out_row_start + idx] = idx;
}
#ifdef PADDLE_WITH_CUDA
__syncwarp();
__syncthreads();
#endif
for (int idx = k + threadIdx.x; idx < deg; idx += WARP_SIZE) {
for (int idx = k + threadIdx.x; idx < deg; idx += CTA_SIZE) {
#ifdef PADDLE_WITH_HIP
const int num = hiprand(&rng) % (idx + 1);
#else
@@ -129,10 +128,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
}
}
#ifdef PADDLE_WITH_CUDA
__syncwarp();
__syncthreads();
#endif
for (int idx = threadIdx.x; idx < k; idx += WARP_SIZE) {
for (int idx = threadIdx.x; idx < k; idx += CTA_SIZE) {
T perm_idx = output[out_row_start + idx] + in_row_start;
output[out_row_start + idx] = row[perm_idx];
if (return_eids) {
@@ -141,7 +140,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
}
}
out_row += BLOCK_WARPS;
out_row += BLOCK_CTAS;
}
}
@@ -181,12 +180,12 @@ void SampleNeighbors(const Context& dev_ctx,
thrust::exclusive_scan(
output_count, output_count + bs, output_ptr.begin(), 0);
constexpr int WARP_SIZE = 32;
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
constexpr int CTA_SIZE = 128;
constexpr int BLOCK_CTAS = 128 / CTA_SIZE;
constexpr int TILE_SIZE = BLOCK_CTAS;
const dim3 block(CTA_SIZE, BLOCK_CTAS);
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
SampleKernel<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
SampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
<<<grid, block, 0, dev_ctx.stream()>>>(
0,
sample_size,
@@ -202,7 +201,7 @@ void SampleNeighbors(const Context& dev_ctx,
return_eids);
}
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
int k,
const int64_t num_rows,
@@ -210,8 +209,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
const T* in_rows,
T* src,
const T* dst_count) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
assert(blockDim.x == CTA_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
const int64_t last_row =
@@ -221,7 +219,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
hiprand_init(
rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#else
curandState rng;
curandStatePhilox4_32_10_t rng;
curand_init(
rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#endif
@@ -229,7 +227,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
while (out_row < last_row) {
const T row = in_rows[out_row];
if (row > len_col_ptr - 1) {
out_row += BLOCK_WARPS;
out_row += BLOCK_CTAS;
continue;
}
const T in_row_start = dst_count[row];
@@ -241,7 +239,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
} else {
split = deg - k;
}
for (int idx = split + threadIdx.x; idx <= deg - 1; idx += WARP_SIZE) {
for (int idx = split + threadIdx.x; idx <= deg - 1; idx += CTA_SIZE) {
#ifdef PADDLE_WITH_HIP
const int num = hiprand(&rng) % (idx + 1);
#else
@@ -254,14 +252,14 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
src[in_row_start + idx])));
}
#ifdef PADDLE_WITH_CUDA
__syncwarp();
__syncthreads();
#endif
}
out_row += BLOCK_WARPS;
out_row += BLOCK_CTAS;
}
}
template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
__global__ void GatherEdge(int k,
int64_t num_rows,
const T* in_rows,
@@ -273,8 +271,7 @@ __global__ void GatherEdge(int k,
int* output_ptr,
T* perm_data,
bool return_eids) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);
assert(blockDim.x == CTA_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
const int64_t last_row =
@@ -287,7 +284,7 @@ __global__ void GatherEdge(int k,
const T out_row_start = output_ptr[out_row];
if (deg <= k) {
for (int idx = threadIdx.x; idx < deg; idx += WARP_SIZE) {
for (int idx = threadIdx.x; idx < deg; idx += CTA_SIZE) {
outputs[out_row_start + idx] = src[in_row_start + idx];
if (return_eids) {
output_eids[out_row_start + idx] = eids[in_row_start + idx];
@@ -304,7 +301,7 @@ __global__ void GatherEdge(int k,
end = deg;
}
for (int idx = begin + threadIdx.x; idx < end; idx += WARP_SIZE) {
for (int idx = begin + threadIdx.x; idx < end; idx += CTA_SIZE) {
outputs[out_row_start + idx - begin] =
src[perm_data[in_row_start + idx]];
if (return_eids) {
@@ -313,7 +310,7 @@ __global__ void GatherEdge(int k,
}
}
}
out_row += BLOCK_WARPS;
out_row += BLOCK_CTAS;
}
}
@@ -337,13 +334,13 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
thrust::exclusive_scan(
output_count, output_count + bs, output_ptr.begin(), 0);
constexpr int WARP_SIZE = 32;
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
constexpr int CTA_SIZE = 128;
constexpr int BLOCK_CTAS = 128 / CTA_SIZE;
constexpr int TILE_SIZE = BLOCK_CTAS;
const dim3 block(CTA_SIZE, BLOCK_CTAS);
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
FisherYatesSampleKernel<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
FisherYatesSampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
<<<grid, block, 0, dev_ctx.stream()>>>(0,
sample_size,
bs,
@@ -352,7 +349,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
perm_data,
col_ptr);
GatherEdge<T, WARP_SIZE, BLOCK_WARPS, TILE_SIZE>
GatherEdge<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
<<<grid, block, 0, dev_ctx.stream()>>>(
sample_size,
bs,
......
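One consequence of the constants in the two launch-configuration hunks above: with CTA_SIZE = 128, BLOCK_CTAS = 128 / CTA_SIZE evaluates to 1 and TILE_SIZE = BLOCK_CTAS to 1, so each thread block now handles exactly one output row and the grid launches one block per input node, whereas the old WARP_SIZE = 32 / BLOCK_WARPS = 4 / TILE_SIZE = 64 layout packed 64 rows into each block. A small annotated restatement follows; `bs` (the number of nodes to sample) is a made-up example value here.

```cuda
// Annotated restatement of the new launch shape from the hunks above.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  const int bs = 1024;                        // e.g. 1024 input nodes (hypothetical)
  constexpr int CTA_SIZE = 128;               // threads cooperating on one row
  constexpr int BLOCK_CTAS = 128 / CTA_SIZE;  // = 1 CTA per block in the y-dimension
  constexpr int TILE_SIZE = BLOCK_CTAS;       // = 1 output row per block
  const dim3 block(CTA_SIZE, BLOCK_CTAS);     // (128, 1, 1)
  const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);  // 1024 blocks, one per row
  std::printf("grid=%u block=(%u,%u)\n", grid.x, block.x, block.y);
  return 0;
}
```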