未验证 提交 aac00f6a 编写于 作者: H Haohongxiang 提交者: GitHub

optimize backward (#37055)

上级 71816707
......@@ -54,14 +54,18 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
int64_t begin_idx = idx + (delta * pre_idx - dim_idx) * stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
input_grad[idx] = 0.0;
for (int64_t i = 0; i < nums; i++) {
if (index[i] == dim_idx) {
input_grad[idx] += output_grad[begin_idx + i * stride];
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}
template <typename DeviceContext, typename T>
......@@ -143,8 +147,8 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = framework::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = input_dim[dim];
int64_t delta = output_dim[dim] - size;
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
......@@ -161,17 +165,22 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
int64_t numel = in_grad->numel();
int64_t index_nums = index->numel();
int64_t out_nums = output_grad->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
index_select_grad_init<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel);
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums, numel,
stride, size, delta);
index_data, index_nums,
out_nums, stride, size, delta);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
......@@ -180,10 +189,10 @@ class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
} else {
const int* index_data = index->data<int>();
index_select_grad_cuda_kernel<T, int><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums, numel,
stride, size, delta);
index_data, index_nums,
out_nums, stride, size, delta);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
#else
......@@ -201,12 +210,16 @@ REGISTER_OP_CUDA_KERNEL(
index_select,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_select_grad,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册