Unverified · Commit 2632d77d · authored by 5u13, committed by GitHub

optimization of max_pool3d forward (#45820)

Parent a001f263
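This commit reworks the max_pool3d-with-index forward kernel: the flat 1D grid-stride loop, which recovered every output coordinate with a chain of integer divisions and modulos, becomes a 3D launch in which blockIdx.x/y map straight to output width and height while blockIdx.z strides over the fused batch * channel * depth range (ncd), decoded with a single FastDivMod. A minimal host-side sketch of the two index decompositions, with toy shapes and plain / and % standing in for FastDivMod:

#include <cassert>
#include <cstdio>

int main() {
  // Assumed toy output shape: N=2, C=3, D=4, H=5, W=6.
  const int N = 2, C = 3, D = 4, H = 5, W = 6;

  // Old scheme: one flat index per output element; four divisions
  // and four modulos recover (batch, channel, pd, ph, pw).
  int index = 700;
  assert(index < N * C * D * H * W);
  int pw = index % W;
  int ph = (index / W) % H;
  int pd = (index / W / H) % D;
  int c = (index / W / H / D) % C;
  int batch_idx = index / W / H / D / C;

  // New scheme: blockIdx.x/y give pw and ph directly; the z index
  // only fuses batch, channel and depth, so one divmod recovers them.
  int index_z = (batch_idx * C + c) * D + pd;  // the kernel's index_z
  int d_offset = index_z % D;   // FastDivMod Divmod(...).val[1]
  int nc_offset = index_z / D;  // FastDivMod Divmod(...).val[0]
  assert(d_offset == pd && nc_offset == batch_idx * C + c);
  printf("pw=%d ph=%d d=%d nc=%d\n", pw, ph, d_offset, nc_offset);
  return 0;
}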
@@ -38,6 +38,24 @@ struct FastDivModForPooling {
   }
 };
 
+struct FastDivModForPooling3D {
+ public:
+  paddle::platform::FastDivMod channel;
+  paddle::platform::FastDivMod width;
+  paddle::platform::FastDivMod height;
+  paddle::platform::FastDivMod depth;
+
+  explicit HOSTDEVICE FastDivModForPooling3D(const int channels,
+                                             const int output_width,
+                                             const int output_height,
+                                             const int output_depth) {
+    channel = paddle::platform::FastDivMod(channels);
+    width = paddle::platform::FastDivMod(output_width);
+    height = paddle::platform::FastDivMod(output_height);
+    depth = paddle::platform::FastDivMod(output_depth);
+  }
+};
+
 struct FastDivModForPoolingWithMoreStaff {
  public:
   paddle::platform::FastDivMod channel;
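The new FastDivModForPooling3D simply bundles one paddle::platform::FastDivMod per pooled dimension, so the magic constants are computed once on the host and shipped to the kernel by value. As background, here is a self-contained sketch of the invariant-division trick such helpers rely on (the classic Granlund-Montgomery round-up scheme; illustrative only, not Paddle's exact implementation, and assumed valid for dividends below 2^31):

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

struct DivModSketch {
  uint32_t d, shift, multiplier;

  explicit DivModSketch(uint32_t divisor) : d(divisor) {
    // Smallest shift with 2^shift >= d.
    for (shift = 0; (1u << shift) < d; ++shift) {
    }
    uint64_t one = 1;
    multiplier =
        static_cast<uint32_t>(((one << 32) * ((one << shift) - d)) / d + 1);
  }
  // Quotient from the high 32 bits of a 64-bit multiply: no divide
  // instruction in the hot path, which is what the GPU kernel avoids.
  uint32_t Div(uint32_t n) const {
    uint32_t hi =
        static_cast<uint32_t>((static_cast<uint64_t>(n) * multiplier) >> 32);
    return (hi + n) >> shift;
  }
  uint32_t Mod(uint32_t n) const { return n - Div(n) * d; }
};

int main() {
  DivModSketch dm(28);  // e.g. an output_depth of 28
  for (uint32_t n : {0u, 27u, 28u, 1000u, 123456789u})
    assert(dm.Div(n) == n / 28 && dm.Mod(n) == n % 28);
  printf("fast divmod sketch OK\n");
  return 0;
}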
@@ -2003,7 +2021,7 @@ template class MaxPool2dWithIndexFunctor<phi::GPUContext, double, int>;
 template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, double, int>;
 
 template <typename T1, typename T2>
-__global__ void KernelMaxPool3DWithIdx(const int nthreads,
+__global__ void KernelMaxPool3DWithIdx(const int ncd,
                                        const T1* input_data,
                                        const int channels,
                                        const int input_depth,
@@ -2023,32 +2041,41 @@ __global__ void KernelMaxPool3DWithIdx(const int nthreads,
                                        const int padding_width,
                                        bool adaptive,
                                        T1* output_data,
-                                       T2* mask_data) {
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
-       index += blockDim.x * gridDim.x) {
-    int pw = index % output_width;
-    int ph = (index / output_width) % output_height;
-    int pd = (index / output_width / output_height) % output_depth;
-    int c = (index / output_width / output_height / output_depth) % channels;
-    int batch_idx =
-        index / output_width / output_height / output_depth / channels;
-
-    int dstart, dend;
-    int hstart, hend;
-    int wstart, wend;
-    if (adaptive) {
-      dstart = AdaptStartIndex(pd, input_depth, output_depth);
-      dend = AdaptEndIndex(pd, input_depth, output_depth);
-      hstart = AdaptStartIndex(ph, input_height, output_height);
-      hend = AdaptEndIndex(ph, input_height, output_height);
-      wstart = AdaptStartIndex(pw, input_width, output_width);
-      wend = AdaptEndIndex(pw, input_width, output_width);
-    } else {
-      dstart = pd * stride_depth - padding_depth;
-      hstart = ph * stride_height - padding_height;
-      wstart = pw * stride_width - padding_width;
-      dend = min(dstart + ksize_depth, input_depth);
-      hend = min(hstart + ksize_height, input_height);
-      wend = min(wstart + ksize_width, input_width);
+                                       T2* mask_data,
+                                       FastDivModForPooling3D divmods_output) {
+  int w_offset, h_offset, d_offset, nc_offset;
+  int dstart, dend, hstart, hend, wstart, wend;
+  const T1* input_data_cur;
+
+  w_offset = blockIdx.x * blockDim.x + threadIdx.x;
+  h_offset = blockIdx.y * blockDim.y + threadIdx.y;
+
+  if (w_offset < output_width && h_offset < output_height) {
+    for (int index_z = blockIdx.z * blockDim.z + threadIdx.z; index_z < ncd;
+         index_z += gridDim.z * blockDim.z) {
+      auto output_depth_divmod = divmods_output.depth.Divmod(index_z);
+      d_offset = output_depth_divmod.val[1];
+      nc_offset = output_depth_divmod.val[0];
+      int output_index =
+          nc_offset * output_depth * output_height * output_width +
+          d_offset * output_height * output_width + h_offset * output_width +
+          w_offset;
+      int input_offset = nc_offset * input_depth * input_height * input_width;
+      input_data_cur = input_data + input_offset;
+
+      if (adaptive) {
+        dstart = AdaptStartIndex(d_offset, input_depth, output_depth);
+        dend = AdaptEndIndex(d_offset, input_depth, output_depth);
+        hstart = AdaptStartIndex(h_offset, input_height, output_height);
+        hend = AdaptEndIndex(h_offset, input_height, output_height);
+        wstart = AdaptStartIndex(w_offset, input_width, output_width);
+        wend = AdaptEndIndex(w_offset, input_width, output_width);
+      } else {
+        dstart = d_offset * stride_depth - padding_depth;
+        hstart = h_offset * stride_height - padding_height;
+        wstart = w_offset * stride_width - padding_width;
+        dend = min(dstart + ksize_depth, input_depth);
+        hend = min(hstart + ksize_height, input_height);
+        wend = min(wstart + ksize_width, input_width);
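The adaptive branch is unchanged apart from the renamed coordinates. Assuming the usual proportional definitions of AdaptStartIndex and AdaptEndIndex (floor and ceil of the scaled position; the helpers below are hypothetical stand-ins, not Paddle's source), the pooling windows look like this:

#include <cstdio>

// Assumed definitions: start = floor(o*in/out), end = ceil((o+1)*in/out).
int AdaptStart(int o, int in, int out) { return (o * in) / out; }
int AdaptEnd(int o, int in, int out) { return ((o + 1) * in + out - 1) / out; }

int main() {
  const int input_depth = 10, output_depth = 4;  // toy sizes
  for (int d = 0; d < output_depth; ++d)
    printf("output d=%d pools input depths [%d, %d)\n", d,
           AdaptStart(d, input_depth, output_depth),
           AdaptEnd(d, input_depth, output_depth));
  return 0;
}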
@@ -2059,21 +2086,20 @@ __global__ void KernelMaxPool3DWithIdx(const int nthreads,
-    T1 ele = -FLT_MAX;
-    int max_index = -1;
-    input_data +=
-        (batch_idx * channels + c) * input_depth * input_height * input_width;
-    for (int d = dstart; d < dend; ++d) {
-      for (int h = hstart; h < hend; ++h) {
-        for (int w = wstart; w < wend; ++w) {
-          if (ele < input_data[(d * input_height + h) * input_width + w]) {
-            max_index = (d * input_height + h) * input_width + w;
-            ele = input_data[max_index];
-          }
-        }
-      }
-    }
-    output_data[index] = ele;
-    mask_data[index] = max_index;
+      T1 ele = -FLT_MAX;
+      int max_index = -1;
+      for (int d = dstart; d < dend; ++d) {
+        for (int h = hstart; h < hend; ++h) {
+          for (int w = wstart; w < wend; ++w) {
+            if (ele <
+                input_data_cur[(d * input_height + h) * input_width + w]) {
+              max_index = (d * input_height + h) * input_width + w;
+              ele = input_data_cur[max_index];
+            }
+          }
+        }
+      }
+      output_data[output_index] = ele;
+      mask_data[output_index] = max_index;
+    }
   }
 }
@@ -2201,19 +2227,25 @@ class MaxPool3dWithIndexFunctor<phi::GPUContext, T1, T2> {
     T1* output_data = context.template Alloc<T1>(output);
     T2* mask_data = context.template Alloc<T2>(mask);
 
-    int nthreads = batch_size * output_channels * output_depth * output_height *
-                   output_width;
-    int thread_num = 1024;
-#ifdef WITH_NV_JETSON
-    backends::gpu::ChangeThreadNum(context, &thread_num);
-#endif
-    int blocks = (nthreads + thread_num - 1) / thread_num;
-    dim3 threads(thread_num, 1);
-    dim3 grid(blocks, 1);
+    int ncd = batch_size * input_channels * output_depth;
+
+    int thread_x = 32;
+    int thread_y = 8;
+    int thread_z = 1;
+    dim3 threads(thread_x, thread_y, thread_z);
+    std::array<int, 3> max_grid_dim = context.GetCUDAMaxGridDimSize();
+    int block_x = (output_width + threads.x - 1) / threads.x;
+    int block_y = (output_height + threads.y - 1) / threads.y;
+    int block_z = (ncd > max_grid_dim[2] * threads.z)
+                      ? max_grid_dim[2]
+                      : (ncd + threads.z - 1) / threads.z;
+    dim3 grid(block_x, block_y, block_z);
+
+    auto pool_divmods_output = FastDivModForPooling3D(
+        input_channels, output_width, output_height, output_depth);
 
     KernelMaxPool3DWithIdx<T1, T2>
-        <<<grid, threads, 0, context.stream()>>>(nthreads,
+        <<<grid, threads, 0, context.stream()>>>(ncd,
                                                  input_data,
                                                  input_channels,
                                                  input_depth,
@@ -2233,7 +2265,8 @@ class MaxPool3dWithIndexFunctor<phi::GPUContext, T1, T2> {
                                                  padding_width,
                                                  adaptive,
                                                  output_data,
-                                                 mask_data);
+                                                 mask_data,
+                                                 pool_divmods_output);
   }
 };
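For a sense of the launch shape this produces, a standalone sketch of the grid arithmetic in MaxPool3dWithIndexFunctor above (shapes are invented; max_grid_z stands in for context.GetCUDAMaxGridDimSize()[2], commonly 65535, and the kernel's z-dimension grid-stride loop covers whatever the cap cuts off):

#include <cstdio>

int main() {
  // Invented shapes for illustration.
  const int output_width = 56, output_height = 56;
  const int batch_size = 8, input_channels = 64, output_depth = 28;
  const int ncd = batch_size * input_channels * output_depth;

  const int thread_x = 32, thread_y = 8, thread_z = 1;
  const int max_grid_z = 65535;  // assumed device limit for gridDim.z

  int block_x = (output_width + thread_x - 1) / thread_x;
  int block_y = (output_height + thread_y - 1) / thread_y;
  // Cap grid.z at the device limit; the kernel strides over the rest.
  int block_z = (ncd > max_grid_z * thread_z)
                    ? max_grid_z
                    : (ncd + thread_z - 1) / thread_z;
  printf("grid=(%d,%d,%d) threads=(%d,%d,%d)\n", block_x, block_y, block_z,
         thread_x, thread_y, thread_z);
  return 0;
}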