From aac00f6a08142d17c1a4c8c09db936c742801f9f Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Tue, 9 Nov 2021 14:22:53 +0800 Subject: [PATCH] optimize backward (#37055) --- paddle/fluid/operators/index_select_op.cu | 41 +++++++++++++++-------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/index_select_op.cu b/paddle/fluid/operators/index_select_op.cu index 43761d97962..2353781daaa 100644 --- a/paddle/fluid/operators/index_select_op.cu +++ b/paddle/fluid/operators/index_select_op.cu @@ -54,14 +54,18 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad, int64_t pre_idx = idx / (stride * size); int64_t dim_idx = idx % (stride * size) / stride; - int64_t begin_idx = idx + (delta * pre_idx - dim_idx) * stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]); +} - input_grad[idx] = 0.0; - for (int64_t i = 0; i < nums; i++) { - if (index[i] == dim_idx) { - input_grad[idx] += output_grad[begin_idx + i * stride]; - } +template +__global__ void index_select_grad_init(T* input_grad, int64_t N) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; } + input_grad[idx] = 0.0; } template @@ -143,8 +147,8 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel { dim = dim >= 0 ? dim : dim + input_dim.size(); auto stride_dim = framework::stride(input_dim); int64_t stride = stride_dim[dim]; - int64_t size = input_dim[dim]; - int64_t delta = output_dim[dim] - size; + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; const auto& index_type = index->type(); bool index_type_match = index_type == framework::proto::VarType::INT64 || @@ -161,17 +165,22 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel { int64_t numel = in_grad->numel(); int64_t index_nums = index->numel(); + int64_t out_nums = output_grad->numel(); auto stream = context.template device_context().stream(); + index_select_grad_init< + T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel); + if (index_type == framework::proto::VarType::INT64) { const int64_t* index_data = index->data(); index_select_grad_cuda_kernel<<< - (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, - index_data, index_nums, numel, - stride, size, delta); + index_data, index_nums, + out_nums, stride, size, delta); #ifdef PADDLE_WITH_HIP PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); #else @@ -180,10 +189,10 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel { } else { const int* index_data = index->data(); index_select_grad_cuda_kernel<<< - (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, - index_data, index_nums, numel, - stride, size, delta); + index_data, index_nums, + out_nums, stride, size, delta); #ifdef PADDLE_WITH_HIP PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); #else @@ -201,12 +210,16 @@ REGISTER_OP_CUDA_KERNEL( index_select, ops::IndexSelectCUDAKernel, ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel, ops::IndexSelectCUDAKernel, ops::IndexSelectCUDAKernel); REGISTER_OP_CUDA_KERNEL( index_select_grad, ops::IndexSelectGradCUDAKernel, ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel, ops::IndexSelectGradCUDAKernel, ops::IndexSelectGradCUDAKernel); -- GitLab