From 8f7c02f2ced02f7930f8d5ae485e91c65cd376e2 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Wed, 30 Mar 2022 16:04:42 +0800 Subject: [PATCH] [Op] Fix uncontrolled randomness of index_select op (#41078) * fix uncontrolled randomness of op * fix bugs --- .../kernels/gpu/index_select_grad_kernel.cu | 70 ++++++++++--------- .../tests/unittests/test_index_select_op.py | 11 +++ 2 files changed, 48 insertions(+), 33 deletions(-) diff --git a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu index c63063bc578..b3bd307e2aa 100644 --- a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu @@ -19,6 +19,8 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/utils/data_type.h" +DECLARE_bool(cudnn_deterministic); + namespace phi { using paddle::platform::PADDLE_CUDA_NUM_THREADS; @@ -32,16 +34,14 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad, int64_t stride, int64_t size, int64_t delta) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= N) { - return; + CUDA_KERNEL_LOOP(idx, N) { + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = + idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]); } - - int64_t pre_idx = idx / (stride * size); - int64_t dim_idx = idx % (stride * size) / stride; - IndexT src_dim_idx = index[dim_idx]; - int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; - paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]); } template @@ -95,34 +95,38 @@ void IndexSelectGradKernel(const Context& ctx, 0, stream>>>(in_grad_data, numel); + int blocks = + (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; + int threads = PADDLE_CUDA_NUM_THREADS; + + if (FLAGS_cudnn_deterministic) { + VLOG(2) << "Run grad kernel of index_select with single thread."; + blocks = 1; + threads = 1; + } + if (index_type == phi::DataType::INT64) { const int64_t* index_data = index.data(); - index_select_grad_cuda_kernel<<< - (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(output_grad_data, - in_grad_data, - index_data, - index_nums, - out_nums, - stride, - size, - delta); + index_select_grad_cuda_kernel<<>>( + output_grad_data, + in_grad_data, + index_data, + index_nums, + out_nums, + stride, + size, + delta); } else { const int* index_data = index.data(); - index_select_grad_cuda_kernel<<< - (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, - PADDLE_CUDA_NUM_THREADS, - 0, - stream>>>(output_grad_data, - in_grad_data, - index_data, - index_nums, - out_nums, - stride, - size, - delta); + index_select_grad_cuda_kernel<<>>( + output_grad_data, + in_grad_data, + index_data, + index_nums, + out_nums, + stride, + size, + delta); } } diff --git a/python/paddle/fluid/tests/unittests/test_index_select_op.py b/python/paddle/fluid/tests/unittests/test_index_select_op.py index e551989ed32..f4545d40690 100644 --- a/python/paddle/fluid/tests/unittests/test_index_select_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_select_op.py @@ -69,6 +69,17 @@ class TestIndexSelectOpCase2(TestIndexSelectOp): self.index_size = 10 +class TestIndexSelectOpCaseSingleThread(TestIndexSelectOp): + def init_dtype_type(self): + if fluid.is_compiled_with_cuda(): + fluid.set_flags({'FLAGS_cudnn_deterministic': True}) + self.x_type = np.float32 + self.index_type = np.int32 + self.dim = -2 + self.x_shape = (10, 10, 4, 10) + self.index_size = 10 + + class TestIndexSelectAPI(unittest.TestCase): def input_data(self): self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], -- GitLab