Unverified commit aac00f6a, authored by Haohongxiang, committed by GitHub

optimize backward (#37055)

Parent: 71816707
@@ -54,14 +54,18 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
 
   int64_t pre_idx = idx / (stride * size);
   int64_t dim_idx = idx % (stride * size) / stride;
-  int64_t begin_idx = idx + (delta * pre_idx - dim_idx) * stride;
-
-  input_grad[idx] = 0.0;
-  for (int64_t i = 0; i < nums; i++) {
-    if (index[i] == dim_idx) {
-      input_grad[idx] += output_grad[begin_idx + i * stride];
-    }
+  IndexT src_dim_idx = index[dim_idx];
+  int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
+  paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
+}
+
+template <typename T>
+__global__ void index_select_grad_init(T* input_grad, int64_t N) {
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= N) {
+    return;
   }
+  input_grad[idx] = 0.0;
 }
 
 template <typename DeviceContext, typename T>
@@ -143,8 +147,8 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
     dim = dim >= 0 ? dim : dim + input_dim.size();
     auto stride_dim = framework::stride(input_dim);
     int64_t stride = stride_dim[dim];
-    int64_t size = input_dim[dim];
-    int64_t delta = output_dim[dim] - size;
+    int64_t size = output_dim[dim];
+    int64_t delta = input_dim[dim] - size;
 
     const auto& index_type = index->type();
     bool index_type_match = index_type == framework::proto::VarType::INT64 ||
@@ -161,17 +165,22 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
 
     int64_t numel = in_grad->numel();
     int64_t index_nums = index->numel();
+    int64_t out_nums = output_grad->numel();
 
     auto stream =
         context.template device_context<platform::CUDADeviceContext>().stream();
 
+    index_select_grad_init<
+        T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
+             PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel);
+
     if (index_type == framework::proto::VarType::INT64) {
       const int64_t* index_data = index->data<int64_t>();
       index_select_grad_cuda_kernel<T, int64_t><<<
-          (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
-          PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
-                                                index_data, index_nums, numel,
-                                                stride, size, delta);
+          (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
+          PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
+                                                index_data, index_nums,
+                                                out_nums, stride, size, delta);
 #ifdef PADDLE_WITH_HIP
       PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
 #else
@@ -180,10 +189,10 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
     } else {
       const int* index_data = index->data<int>();
       index_select_grad_cuda_kernel<T, int><<<
-          (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
-          PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
-                                                index_data, index_nums, numel,
-                                                stride, size, delta);
+          (out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
+          PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
+                                                index_data, index_nums,
+                                                out_nums, stride, size, delta);
 #ifdef PADDLE_WITH_HIP
       PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
 #else
@@ -201,12 +210,16 @@ REGISTER_OP_CUDA_KERNEL(
     index_select,
     ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext,
+                               paddle::platform::float16>,
     ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
     ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     index_select_grad,
     ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
     ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
+                                   paddle::platform::float16>,
     ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
     ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                    int64_t>);
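In short, the backward no longer has every input-gradient element scan the whole index array. Instead, `in_grad` is first zeroed by the new `index_select_grad_init` kernel, and the gradient kernel is launched over `out_nums` output-gradient elements, each of which looks up its source position via `index[dim_idx]` and accumulates into `input_grad` with `paddle::platform::CudaAtomicAdd`. The commit also registers `float16` kernels for `index_select` and `index_select_grad`. Below is a minimal standalone sketch of the same two-pass pattern, assuming plain `atomicAdd`, `float` data, a toy 2x4 input with `dim = 1`, and illustrative names (`grad_init`, `grad_scatter`, `kThreads`); it mirrors the indexing in the new kernel but is not the Paddle code itself.

```cuda
// Sketch only: two-pass index_select backward (zero-init, then atomic scatter).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kThreads = 256;  // stand-in for PADDLE_CUDA_NUM_THREADS

// Pass 1: zero the input gradient (one thread per input element).
__global__ void grad_init(float* input_grad, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) input_grad[i] = 0.0f;
}

// Pass 2: one thread per output_grad element scatters into input_grad.
// stride = product of dims after `dim`, size = output_dim[dim],
// delta = input_dim[dim] - size (same meaning as in the new kernel).
__global__ void grad_scatter(const float* output_grad, float* input_grad,
                             const int64_t* index, int64_t out_nums,
                             int64_t stride, int64_t size, int64_t delta) {
  int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (idx >= out_nums) return;
  int64_t pre_idx = idx / (stride * size);
  int64_t dim_idx = idx % (stride * size) / stride;
  int64_t src_dim_idx = index[dim_idx];
  int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
  atomicAdd(&input_grad[input_idx], output_grad[idx]);
}

int main() {
  // Toy case: input shape [2, 4], dim = 1, index = {0, 2, 2} -> output [2, 3].
  const int64_t in_rows = 2, in_cols = 4, out_cols = 3;
  const int64_t numel = in_rows * in_cols;      // input_grad elements
  const int64_t out_nums = in_rows * out_cols;  // output_grad elements
  const int64_t stride = 1, size = out_cols, delta = in_cols - out_cols;

  const int64_t h_index[out_cols] = {0, 2, 2};
  float h_out_grad[out_nums];
  for (int64_t i = 0; i < out_nums; ++i) h_out_grad[i] = 1.0f;  // d(out) = 1

  float *d_out_grad, *d_in_grad;
  int64_t* d_index;
  cudaMalloc(&d_out_grad, out_nums * sizeof(float));
  cudaMalloc(&d_in_grad, numel * sizeof(float));
  cudaMalloc(&d_index, out_cols * sizeof(int64_t));
  cudaMemcpy(d_out_grad, h_out_grad, out_nums * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_index, h_index, out_cols * sizeof(int64_t),
             cudaMemcpyHostToDevice);

  grad_init<<<(numel + kThreads - 1) / kThreads, kThreads>>>(d_in_grad, numel);
  grad_scatter<<<(out_nums + kThreads - 1) / kThreads, kThreads>>>(
      d_out_grad, d_in_grad, d_index, out_nums, stride, size, delta);

  float h_in_grad[numel];
  cudaMemcpy(h_in_grad, d_in_grad, numel * sizeof(float),
             cudaMemcpyDeviceToHost);
  // Expected per row: col 0 -> 1, col 1 -> 0, col 2 -> 2 (selected twice), col 3 -> 0.
  for (int64_t r = 0; r < in_rows; ++r) {
    for (int64_t c = 0; c < in_cols; ++c) printf("%.0f ", h_in_grad[r * in_cols + c]);
    printf("\n");
  }
  cudaFree(d_out_grad);
  cudaFree(d_in_grad);
  cudaFree(d_index);
  return 0;
}
```

One trade-off worth noting: the scatter form performs a single atomic add per output-gradient element instead of an O(index_nums) loop per input element, but when indices repeat, the floating-point accumulation order is no longer fixed across runs.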