Unverified Commit 0e563da6, authored by 5u13, committed by GitHub

optimization of max_pool3d grad (#45934)

Parent 6d067860
@@ -2319,87 +2319,52 @@ __global__ void KernelMaxPool3DWithIdx(const int ncd,
 }
 
 template <typename T1, typename T2>
-__global__ void KernelMaxPool3DWithIdxGrad(const int nthreads,
-                                           const T1* output_grad,
-                                           const T2* mask,
-                                           const int channels,
-                                           const int input_depth,
-                                           const int input_height,
-                                           const int input_width,
-                                           const int output_depth,
-                                           const int output_height,
-                                           const int output_width,
-                                           const int ksize_depth,
-                                           const int ksize_height,
-                                           const int ksize_width,
-                                           const int stride_depth,
-                                           const int stride_height,
-                                           const int stride_width,
-                                           const int padding_depth,
-                                           const int padding_height,
-                                           const int padding_width,
-                                           bool adaptive,
-                                           T1* input_grad) {
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
-       index += blockDim.x * gridDim.x) {
-    int w_offset = index % input_width;
-    int h_offset = (index / input_width) % input_height;
-    int d_offset = (index / input_width / input_height) % input_depth;
-    int c_offset =
-        (index / input_width / input_height / input_depth) % channels;
-    int batch_idx = index / input_width / input_height / input_depth / channels;
-
-    int pdstart, pdend;
-    int phstart, phend;
-    int pwstart, pwend;
-    if (adaptive) {
-      pdstart = d_offset * output_depth / input_depth;
-      pdend =
-          min((d_offset + 1) * output_depth / input_depth + 1, output_depth);
-      phstart = h_offset * output_height / input_height;
-      phend =
-          min((h_offset + 1) * output_height / input_height + 1, output_height);
-      pwstart = w_offset * output_width / input_width;
-      pwend =
-          min((w_offset + 1) * output_width / input_width + 1, output_width);
-    } else {
-      pdstart =
-          (d_offset + padding_depth < ksize_depth)
-              ? 0
-              : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
-      phstart =
-          (h_offset + padding_height < ksize_height)
-              ? 0
-              : (h_offset + padding_height - ksize_height) / stride_height + 1;
-      pwstart =
-          (w_offset + padding_width < ksize_width)
-              ? 0
-              : (w_offset + padding_width - ksize_width) / stride_width + 1;
-      pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth);
-      phend =
-          min((h_offset + padding_height) / stride_height + 1, output_height);
-      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
-    }
-
-    T1 input_grad_data = 0;
-    int input_current_feature_map_idx =
-        (d_offset * input_height + h_offset) * input_width + w_offset;
-    int output_idx = (batch_idx * channels + c_offset) * output_depth *
-                     output_height * output_width;
-    mask += output_idx;
-    output_grad += output_idx;
-
-    for (int pd = pdstart; pd < pdend; ++pd) {
-      for (int ph = phstart; ph < phend; ++ph) {
-        for (int pw = pwstart; pw < pwend; ++pw) {
-          if (mask[(pd * output_height + ph) * output_width + pw] ==
-              input_current_feature_map_idx)
-            input_grad_data +=
-                output_grad[(pd * output_height + ph) * output_width + pw];
-        }
-      }
-    }
-    input_grad[index] = input_grad_data;
-  }
-}
+__global__ void KernelMaxPool3DWithIdxGrad(
+    const int ncd,
+    const T1* output_grad,
+    const T2* mask,
+    const int channels,
+    const int input_depth,
+    const int input_height,
+    const int input_width,
+    const int output_depth,
+    const int output_height,
+    const int output_width,
+    const int ksize_depth,
+    const int ksize_height,
+    const int ksize_width,
+    const int stride_depth,
+    const int stride_height,
+    const int stride_width,
+    const int padding_depth,
+    const int padding_height,
+    const int padding_width,
+    bool adaptive,
+    T1* input_grad,
+    FastDivModForPooling3D divmods_output) {
+  int w_offset, h_offset, d_offset, nc_offset;
+
+  w_offset = blockIdx.x * blockDim.x + threadIdx.x;
+  h_offset = blockIdx.y * blockDim.y + threadIdx.y;
+
+  if (w_offset < output_width && h_offset < output_height) {
+    for (int index_z = blockIdx.z * blockDim.z + threadIdx.z; index_z < ncd;
+         index_z += gridDim.z * blockDim.z) {
+      auto output_depth_divmod = divmods_output.depth.Divmod(index_z);
+      d_offset = output_depth_divmod.val[1];
+      nc_offset = output_depth_divmod.val[0];
+      int output_index =
+          nc_offset * output_depth * output_height * output_width +
+          d_offset * output_height * output_width + h_offset * output_width +
+          w_offset;
+      int max_index = mask[output_index];
+      if (max_index != -1) {
+        paddle::platform::CudaAtomicAdd(
+            &input_grad[nc_offset * input_depth * input_height * input_width +
+                        max_index],
+            output_grad[output_index]);
+      }
+    }
+  }
+}
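The structural change is easier to see outside the diff. The old kernel was a gather: one thread per input element, re-deriving every pooling window that could contain it and summing the matching output gradients. The new kernel is a scatter: one thread per output element reads the argmax index the forward pass stored in the mask and atomically adds its gradient there, so no window arithmetic is needed on the backward path. Below is a minimal standalone sketch of that scatter pattern; the kernel name and flattened one-axis launch are hypothetical simplifications, and the real kernel additionally fuses N*C*D into the grid's z axis and decodes it with FastDivModForPooling3D.

// Minimal sketch of a scatter-style max-pool backward pass, assuming the
// forward pass stored, for every output element, the flat argmax index
// inside the corresponding (D*H*W) input feature map.
__global__ void MaxPoolGradScatter(const float* output_grad,
                                   const int* mask,       // argmax per output
                                   float* input_grad,     // zero-initialized
                                   int num_outputs,
                                   int input_spatial,     // D*H*W of input
                                   int output_spatial) {  // D*H*W of output
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_outputs) return;
  int nc = idx / output_spatial;  // which (batch, channel) plane
  int max_index = mask[idx];      // flat argmax within that plane
  if (max_index != -1) {
    // Overlapping windows may share an argmax, so two output threads can
    // target the same input cell; the add must therefore be atomic.
    atomicAdd(&input_grad[nc * input_spatial + max_index], output_grad[idx]);
  }
}

The trade-off: the gather version needs no atomics but makes every input thread scan all of its candidate windows, while the scatter version does one read and one atomic add per output element, which is why the launch grid below is now sized by the output shape.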
@@ -2523,14 +2488,25 @@ class MaxPool3dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
     const T2* mask_data = mask.data<T2>();
     T1* input_grad_data = context.template Alloc<T1>(input_grad);
 
-    int nthreads =
-        batch_size * input_channels * input_depth * input_height * input_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
-    dim3 grid(blocks, 1);
+    int ncd = batch_size * input_channels * output_depth;
+
+    int thread_x = 32;
+    int thread_y = 8;
+    int thread_z = 1;
+    dim3 threads(thread_x, thread_y, thread_z);
+    std::array<int, 3> max_grid_dim = context.GetCUDAMaxGridDimSize();
+    int block_x = (output_width + threads.x - 1) / threads.x;
+    int block_y = (output_height + threads.y - 1) / threads.y;
+    int block_z = (ncd > max_grid_dim[2] * threads.z)
+                      ? max_grid_dim[2]
+                      : (ncd + threads.z - 1) / threads.z;
+    dim3 grid(block_x, block_y, block_z);
 
+    auto pool_divmods_output = FastDivModForPooling3D(
+        input_channels, output_width, output_height, output_depth);
     KernelMaxPool3DWithIdxGrad<T1, T2>
-        <<<grid, threads, 0, context.stream()>>>(nthreads,
+        <<<grid, threads, 0, context.stream()>>>(ncd,
                                                  output_grad_data,
                                                  mask_data,
                                                  input_channels,
@@ -2550,7 +2526,8 @@ class MaxPool3dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
                                                  padding_height,
                                                  padding_width,
                                                  adaptive,
-                                                 input_grad_data);
+                                                 input_grad_data,
+                                                 pool_divmods_output);
   }
 };
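For a concrete feel of the new launch configuration, here is the grid-sizing arithmetic as a standalone helper with assumed shapes. The function name is hypothetical; since threads.z == 1, the ternary in the diff reduces to a min against the device's maximum z grid dimension, and the kernel's grid-stride loop over index_z covers whatever the cap cuts off.

#include <cuda_runtime.h>  // dim3

// Hypothetical helper mirroring the grid sizing in the functor above.
dim3 MakePool3dGradGrid(int ncd, int output_height, int output_width,
                        int max_grid_dim_z) {
  dim3 threads(32, 8, 1);
  int block_x = (output_width + threads.x - 1) / threads.x;   // ceil-div over W
  int block_y = (output_height + threads.y - 1) / threads.y;  // ceil-div over H
  int block_z = ncd < max_grid_dim_z ? ncd : max_grid_dim_z;  // capped N*C*D
  return dim3(block_x, block_y, block_z);
}

// Worked example with assumed shapes N=2, C=3, D_out=16, H_out=W_out=64:
// ncd = 2 * 3 * 16 = 96, so the grid is (ceil(64/32), ceil(64/8), 96)
// = (2, 8, 96), with 32 * 8 * 1 = 256 threads per block.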