Unverified commit 8f7c02f2, authored by: H Haohongxiang, committed by: GitHub

[Op] Fix uncontrolled randomness of index_select op (#41078)

* fix uncontrolled randomness of op

* fix bugs
Parent eef46770
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/core/utils/data_type.h"
DECLARE_bool(cudnn_deterministic);
namespace phi { namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS; using paddle::platform::PADDLE_CUDA_NUM_THREADS;
...@@ -32,16 +34,14 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad, ...@@ -32,16 +34,14 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
int64_t stride, int64_t stride,
int64_t size, int64_t size,
int64_t delta) { int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; CUDA_KERNEL_LOOP(idx, N) {
if (idx >= N) { int64_t pre_idx = idx / (stride * size);
return; int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx =
idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
} }
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
} }
template <typename T> template <typename T>
...@@ -95,34 +95,38 @@ void IndexSelectGradKernel(const Context& ctx, ...@@ -95,34 +95,38 @@ void IndexSelectGradKernel(const Context& ctx,
0, 0,
stream>>>(in_grad_data, numel); stream>>>(in_grad_data, numel);
int blocks =
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
int threads = PADDLE_CUDA_NUM_THREADS;
if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of index_select with single thread.";
blocks = 1;
threads = 1;
}
if (index_type == phi::DataType::INT64) { if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>(); const int64_t* index_data = index.data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<< index_select_grad_cuda_kernel<T, int64_t><<<blocks, threads, 0, stream>>>(
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, output_grad_data,
PADDLE_CUDA_NUM_THREADS, in_grad_data,
0, index_data,
stream>>>(output_grad_data, index_nums,
in_grad_data, out_nums,
index_data, stride,
index_nums, size,
out_nums, delta);
stride,
size,
delta);
} else { } else {
const int* index_data = index.data<int>(); const int* index_data = index.data<int>();
index_select_grad_cuda_kernel<T, int><<< index_select_grad_cuda_kernel<T, int><<<blocks, threads, 0, stream>>>(
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, output_grad_data,
PADDLE_CUDA_NUM_THREADS, in_grad_data,
0, index_data,
stream>>>(output_grad_data, index_nums,
in_grad_data, out_nums,
index_data, stride,
index_nums, size,
out_nums, delta);
stride,
size,
delta);
} }
} }
......
...@@ -69,6 +69,17 @@ class TestIndexSelectOpCase2(TestIndexSelectOp): ...@@ -69,6 +69,17 @@ class TestIndexSelectOpCase2(TestIndexSelectOp):
self.index_size = 10 self.index_size = 10
class TestIndexSelectOpCaseSingleThread(TestIndexSelectOp):
    """Exercises the deterministic code path of the index_select grad kernel.

    Enabling ``FLAGS_cudnn_deterministic`` makes the CUDA grad kernel launch
    with a single block and a single thread, so the atomicAdd accumulation
    order is reproducible.
    """

    def init_dtype_type(self):
        # Switch the grad kernel onto its deterministic (single-thread)
        # branch before the op runs.
        if fluid.is_compiled_with_cuda():
            fluid.set_flags({'FLAGS_cudnn_deterministic': True})
        self.x_type = np.float32
        self.index_type = np.int32
        self.dim = -2
        self.x_shape = (10, 10, 4, 10)
        self.index_size = 10

    def tearDown(self):
        # Restore the process-global flag so deterministic mode does not
        # leak into other tests running in the same process.
        # NOTE(review): assumes the flag's default is False — confirm.
        if fluid.is_compiled_with_cuda():
            fluid.set_flags({'FLAGS_cudnn_deterministic': False})
        super(TestIndexSelectOpCaseSingleThread, self).tearDown()
class TestIndexSelectAPI(unittest.TestCase): class TestIndexSelectAPI(unittest.TestCase):
def input_data(self): def input_data(self):
self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.