未验证 提交 f311a927 编写于 作者: S Siming Dai 提交者: GitHub

fix fisher yates sample (#55329)

上级 27fd2bc2
...@@ -201,7 +201,7 @@ void SampleNeighbors(const Context& dev_ctx, ...@@ -201,7 +201,7 @@ void SampleNeighbors(const Context& dev_ctx,
return_eids); return_eids);
} }
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE> template <typename T>
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed, __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
int k, int k,
const int64_t num_rows, const int64_t num_rows,
...@@ -209,11 +209,6 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, ...@@ -209,11 +209,6 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
const T* in_rows, const T* in_rows,
T* src, T* src,
const T* dst_count) { const T* dst_count) {
assert(blockDim.x == CTA_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
hiprandState rng; hiprandState rng;
hiprand_init( hiprand_init(
...@@ -224,10 +219,9 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, ...@@ -224,10 +219,9 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng); rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#endif #endif
while (out_row < last_row) { CUDA_KERNEL_LOOP(out_row, num_rows) {
const T row = in_rows[out_row]; const T row = in_rows[out_row];
if (row > len_col_ptr - 1) { if (row > len_col_ptr - 1) {
out_row += BLOCK_CTAS;
continue; continue;
} }
const T in_row_start = dst_count[row]; const T in_row_start = dst_count[row];
...@@ -239,7 +233,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, ...@@ -239,7 +233,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
} else { } else {
split = deg - k; split = deg - k;
} }
for (int idx = split + threadIdx.x; idx <= deg - 1; idx += CTA_SIZE) { for (int idx = split; idx <= deg - 1; idx++) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
const int num = hiprand(&rng) % (idx + 1); const int num = hiprand(&rng) % (idx + 1);
#else #else
...@@ -251,11 +245,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed, ...@@ -251,11 +245,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
static_cast<unsigned long long int>( // NOLINT static_cast<unsigned long long int>( // NOLINT
src[in_row_start + idx]))); src[in_row_start + idx])));
} }
#ifdef PADDLE_WITH_CUDA
__syncthreads();
#endif
} }
out_row += BLOCK_CTAS;
} }
} }
...@@ -340,7 +330,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx, ...@@ -340,7 +330,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
const dim3 block(CTA_SIZE, BLOCK_CTAS); const dim3 block(CTA_SIZE, BLOCK_CTAS);
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE); const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
FisherYatesSampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE> FisherYatesSampleKernel<T>
<<<grid, block, 0, dev_ctx.stream()>>>(0, <<<grid, block, 0, dev_ctx.stream()>>>(0,
sample_size, sample_size,
bs, bs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册