Unverified commit 2632d77d authored by 5u13, committed by GitHub

optimization of max_pool3d forward (#45820)

Parent a001f263
@@ -38,6 +38,24 @@ struct FastDivModForPooling {
}
};
struct FastDivModForPooling3D {
public:
paddle::platform::FastDivMod channel;
paddle::platform::FastDivMod width;
paddle::platform::FastDivMod height;
paddle::platform::FastDivMod depth;
explicit HOSTDEVICE FastDivModForPooling3D(const int channels,
const int output_width,
const int output_height,
const int output_depth) {
channel = paddle::platform::FastDivMod(channels);
width = paddle::platform::FastDivMod(output_width);
height = paddle::platform::FastDivMod(output_height);
depth = paddle::platform::FastDivMod(output_depth);
}
};
struct FastDivModForPoolingWithMoreStaff {
public:
paddle::platform::FastDivMod channel;
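For reference, `paddle::platform::FastDivMod` precomputes a multiplier/shift pair for a fixed divisor so the kernel can replace hardware integer division and modulo with multiply-and-shift; `Divmod(i)` packs the quotient into `val[0]` and the remainder into `val[1]`, which is how the kernel below recovers `nc_offset` and `d_offset` from the fused z-index. A standalone sketch of that decomposition in plain C++ (illustrative only, not Paddle's implementation):

```cpp
#include <cassert>

// Reference decomposition: the fused z-index enumerates
// batch * channels * output_depth slices; dividing by output_depth
// yields the (batch, channel) slice id and the depth offset within it.
struct DivmodResult {
  int val[2];  // val[0] = quotient, val[1] = remainder (mirrors FastDivMod)
};

inline DivmodResult ReferenceDivmod(int index_z, int output_depth) {
  return {{index_z / output_depth, index_z % output_depth}};
}

int main() {
  const int batch = 2, channels = 3, output_depth = 5;
  for (int index_z = 0; index_z < batch * channels * output_depth; ++index_z) {
    DivmodResult r = ReferenceDivmod(index_z, output_depth);
    int nc_offset = r.val[0];  // which (batch, channel) slice
    int d_offset = r.val[1];   // which depth plane inside that slice
    assert(nc_offset * output_depth + d_offset == index_z);
  }
  return 0;
}
```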
@@ -2003,7 +2021,7 @@ template class MaxPool2dWithIndexFunctor<phi::GPUContext, double, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, double, int>;
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(const int nthreads,
__global__ void KernelMaxPool3DWithIdx(const int ncd,
const T1* input_data,
const int channels,
const int input_depth,
@@ -2023,57 +2041,65 @@ __global__ void KernelMaxPool3DWithIdx(const int nthreads,
const int padding_width,
bool adaptive,
T1* output_data,
T2* mask_data) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int pd = (index / output_width / output_height) % output_depth;
int c = (index / output_width / output_height / output_depth) % channels;
int batch_idx =
index / output_width / output_height / output_depth / channels;
int dstart, dend;
int hstart, hend;
int wstart, wend;
if (adaptive) {
dstart = AdaptStartIndex(pd, input_depth, output_depth);
dend = AdaptEndIndex(pd, input_depth, output_depth);
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
} else {
dstart = pd * stride_depth - padding_depth;
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
dend = min(dstart + ksize_depth, input_depth);
hend = min(hstart + ksize_height, input_height);
wend = min(wstart + ksize_width, input_width);
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
}
T1 ele = -FLT_MAX;
int max_index = -1;
input_data +=
(batch_idx * channels + c) * input_depth * input_height * input_width;
T2* mask_data,
FastDivModForPooling3D divmods_output) {
int w_offset, h_offset, d_offset, nc_offset;
int dstart, dend, hstart, hend, wstart, wend;
const T1* input_data_cur;
w_offset = blockIdx.x * blockDim.x + threadIdx.x;
h_offset = blockIdx.y * blockDim.y + threadIdx.y;
if (w_offset < output_width && h_offset < output_height) {
for (int index_z = blockIdx.z * blockDim.z + threadIdx.z; index_z < ncd;
index_z += gridDim.z * blockDim.z) {
auto output_depth_divmod = divmods_output.depth.Divmod(index_z);
d_offset = output_depth_divmod.val[1];
nc_offset = output_depth_divmod.val[0];
int output_index =
nc_offset * output_depth * output_height * output_width +
d_offset * output_height * output_width + h_offset * output_width +
w_offset;
int input_offset = nc_offset * input_depth * input_height * input_width;
input_data_cur = input_data + input_offset;
if (adaptive) {
dstart = AdaptStartIndex(d_offset, input_depth, output_depth);
dend = AdaptEndIndex(d_offset, input_depth, output_depth);
hstart = AdaptStartIndex(h_offset, input_height, output_height);
hend = AdaptEndIndex(h_offset, input_height, output_height);
wstart = AdaptStartIndex(w_offset, input_width, output_width);
wend = AdaptEndIndex(w_offset, input_width, output_width);
} else {
dstart = d_offset * stride_depth - padding_depth;
hstart = h_offset * stride_height - padding_height;
wstart = w_offset * stride_width - padding_width;
dend = min(dstart + ksize_depth, input_depth);
hend = min(hstart + ksize_height, input_height);
wend = min(wstart + ksize_width, input_width);
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
}
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele < input_data[(d * input_height + h) * input_width + w]) {
max_index = (d * input_height + h) * input_width + w;
ele = input_data[max_index];
T1 ele = -FLT_MAX;
int max_index = -1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele <
input_data_cur[(d * input_height + h) * input_width + w]) {
max_index = (d * input_height + h) * input_width + w;
ele = input_data_cur[max_index];
}
}
}
}
output_data[output_index] = ele;
mask_data[output_index] = max_index;
}
output_data[index] = ele;
mask_data[index] = max_index;
}
}
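The rewritten kernel drops the old one-thread-per-flattened-index scheme (which decoded `pw`, `ph`, `pd`, `c`, and `batch_idx` with repeated division) in favor of a 3D mapping: `x`/`y` threads tile the output width and height, and the `z` dimension grid-strides over the fused batch x channel x depth range `ncd`. A host-side sketch (plain loops standing in for the CUDA grid, shapes chosen arbitrarily) that checks this mapping still visits every output element exactly once:

```cpp
#include <cassert>
#include <vector>

int main() {
  const int batch = 2, channels = 3;
  const int output_depth = 4, output_height = 5, output_width = 7;
  const int ncd = batch * channels * output_depth;

  std::vector<int> visits(ncd * output_height * output_width, 0);
  for (int index_z = 0; index_z < ncd; ++index_z) {  // z grid-stride loop
    int nc_offset = index_z / output_depth;          // FastDivMod quotient
    int d_offset = index_z % output_depth;           // FastDivMod remainder
    for (int h_offset = 0; h_offset < output_height; ++h_offset) {   // y threads
      for (int w_offset = 0; w_offset < output_width; ++w_offset) {  // x threads
        int output_index =
            nc_offset * output_depth * output_height * output_width +
            d_offset * output_height * output_width +
            h_offset * output_width + w_offset;
        ++visits[output_index];
      }
    }
  }
  for (int v : visits) assert(v == 1);  // each output element written once
  return 0;
}
```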
@@ -2201,19 +2227,25 @@ class MaxPool3dWithIndexFunctor<phi::GPUContext, T1, T2> {
T1* output_data = context.template Alloc<T1>(output);
T2* mask_data = context.template Alloc<T2>(mask);
int nthreads = batch_size * output_channels * output_depth * output_height *
output_width;
int thread_num = 1024;
#ifdef WITH_NV_JETSON
backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
int ncd = batch_size * input_channels * output_depth;
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
int thread_x = 32;
int thread_y = 8;
int thread_z = 1;
dim3 threads(thread_x, thread_y, thread_z);
std::array<int, 3> max_grid_dim = context.GetCUDAMaxGridDimSize();
int block_x = (output_width + threads.x - 1) / threads.x;
int block_y = (output_height + threads.y - 1) / threads.y;
int block_z = (ncd > max_grid_dim[2] * threads.z)
? max_grid_dim[2]
: (ncd + threads.z - 1) / threads.z;
dim3 grid(block_x, block_y, block_z);
auto pool_divmods_output = FastDivModForPooling3D(
input_channels, output_width, output_height, output_depth);
KernelMaxPool3DWithIdx<T1, T2>
<<<grid, threads, 0, context.stream()>>>(nthreads,
<<<grid, threads, 0, context.stream()>>>(ncd,
input_data,
input_channels,
input_depth,
......@@ -2233,7 +2265,8 @@ class MaxPool3dWithIndexFunctor<phi::GPUContext, T1, T2> {
padding_width,
adaptive,
output_data,
mask_data);
mask_data,
pool_divmods_output);
}
};
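The launch configuration fixes the block at 32x8x1 threads, tiles grid `x`/`y` over the output width and height, and clamps grid `z` to the device limit from `GetCUDAMaxGridDimSize()`, relying on the kernel's z grid-stride loop to cover any remaining slices. A minimal sketch of that arithmetic, with an assumed grid-z limit of 65535 standing in for the queried value and example tensor shapes:

```cpp
#include <cstdio>

int main() {
  // Example shapes; in Paddle these come from the input/output tensors.
  const int batch_size = 8, input_channels = 64;
  const int output_depth = 32, output_height = 56, output_width = 56;
  const int ncd = batch_size * input_channels * output_depth;

  const int thread_x = 32, thread_y = 8, thread_z = 1;
  const int max_grid_z = 65535;  // assumed; Paddle queries the device at runtime

  int block_x = (output_width + thread_x - 1) / thread_x;
  int block_y = (output_height + thread_y - 1) / thread_y;
  int block_z = (ncd > max_grid_z * thread_z)
                    ? max_grid_z
                    : (ncd + thread_z - 1) / thread_z;

  // Each block covers a 32x8 tile of one (n, c, d) output plane; if block_z
  // is clamped, the kernel's grid-stride loop picks up the remaining slices.
  std::printf("grid = (%d, %d, %d), block = (%d, %d, %d)\n",
              block_x, block_y, block_z, thread_x, thread_y, thread_z);
  return 0;
}
```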