未验证 提交 f311a927 编写于 作者: S Siming Dai 提交者: GitHub

fix fisher yates sample (#55329)

上级 27fd2bc2
......@@ -201,7 +201,7 @@ void SampleNeighbors(const Context& dev_ctx,
return_eids);
}
template <typename T, int CTA_SIZE, int BLOCK_CTAS, int TILE_SIZE>
template <typename T>
__global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
int k,
const int64_t num_rows,
......@@ -209,11 +209,6 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
const T* in_rows,
T* src,
const T* dst_count) {
assert(blockDim.x == CTA_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE + threadIdx.y;
const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
#ifdef PADDLE_WITH_HIP
hiprandState rng;
hiprand_init(
......@@ -224,10 +219,9 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
rand_seed * gridDim.x + blockIdx.x, threadIdx.y + threadIdx.x, 0, &rng);
#endif
while (out_row < last_row) {
CUDA_KERNEL_LOOP(out_row, num_rows) {
const T row = in_rows[out_row];
if (row > len_col_ptr - 1) {
out_row += BLOCK_CTAS;
continue;
}
const T in_row_start = dst_count[row];
......@@ -239,7 +233,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
} else {
split = deg - k;
}
for (int idx = split + threadIdx.x; idx <= deg - 1; idx += CTA_SIZE) {
for (int idx = split; idx <= deg - 1; idx++) {
#ifdef PADDLE_WITH_HIP
const int num = hiprand(&rng) % (idx + 1);
#else
......@@ -251,11 +245,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
static_cast<unsigned long long int>( // NOLINT
src[in_row_start + idx])));
}
#ifdef PADDLE_WITH_CUDA
__syncthreads();
#endif
}
out_row += BLOCK_CTAS;
}
}
......@@ -340,7 +330,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
const dim3 block(CTA_SIZE, BLOCK_CTAS);
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
FisherYatesSampleKernel<T, CTA_SIZE, BLOCK_CTAS, TILE_SIZE>
FisherYatesSampleKernel<T>
<<<grid, block, 0, dev_ctx.stream()>>>(0,
sample_size,
bs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册