diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index 6e4fc414afd4a34c2cdcf70fa8091b87c8982397..d8cc11e02eeea398ce239ca7af0f52214c6c6d8f 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -38,6 +38,24 @@ struct FastDivModForPooling { } }; +struct FastDivModForPooling3D { + public: + paddle::platform::FastDivMod channel; + paddle::platform::FastDivMod width; + paddle::platform::FastDivMod height; + paddle::platform::FastDivMod depth; + + explicit HOSTDEVICE FastDivModForPooling3D(const int channels, + const int output_width, + const int output_height, + const int output_depth) { + channel = paddle::platform::FastDivMod(channels); + width = paddle::platform::FastDivMod(output_width); + height = paddle::platform::FastDivMod(output_height); + depth = paddle::platform::FastDivMod(output_depth); + } +}; + struct FastDivModForPoolingWithMoreStaff { public: paddle::platform::FastDivMod channel; @@ -2003,7 +2021,7 @@ template class MaxPool2dWithIndexFunctor; template class MaxPool2dWithIndexGradFunctor; template -__global__ void KernelMaxPool3DWithIdx(const int nthreads, +__global__ void KernelMaxPool3DWithIdx(const int ncd, const T1* input_data, const int channels, const int input_depth, @@ -2023,57 +2041,65 @@ __global__ void KernelMaxPool3DWithIdx(const int nthreads, const int padding_width, bool adaptive, T1* output_data, - T2* mask_data) { - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; - index += blockDim.x * gridDim.x) { - int pw = index % output_width; - int ph = (index / output_width) % output_height; - int pd = (index / output_width / output_height) % output_depth; - int c = (index / output_width / output_height / output_depth) % channels; - int batch_idx = - index / output_width / output_height / output_depth / channels; - - int dstart, dend; - int hstart, hend; - int wstart, wend; - if (adaptive) { - dstart = AdaptStartIndex(pd, input_depth, output_depth); - dend = AdaptEndIndex(pd, input_depth, output_depth); - - hstart = AdaptStartIndex(ph, input_height, output_height); - hend = AdaptEndIndex(ph, input_height, output_height); - - wstart = AdaptStartIndex(pw, input_width, output_width); - wend = AdaptEndIndex(pw, input_width, output_width); - } else { - dstart = pd * stride_depth - padding_depth; - hstart = ph * stride_height - padding_height; - wstart = pw * stride_width - padding_width; - dend = min(dstart + ksize_depth, input_depth); - hend = min(hstart + ksize_height, input_height); - wend = min(wstart + ksize_width, input_width); - dstart = max(dstart, 0); - hstart = max(hstart, 0); - wstart = max(wstart, 0); - } - - T1 ele = -FLT_MAX; - int max_index = -1; - input_data += - (batch_idx * channels + c) * input_depth * input_height * input_width; + T2* mask_data, + FastDivModForPooling3D divmods_output) { + int w_offset, h_offset, d_offset, nc_offset; + int dstart, dend, hstart, hend, wstart, wend; + const T1* input_data_cur; + + w_offset = blockIdx.x * blockDim.x + threadIdx.x; + h_offset = blockIdx.y * blockDim.y + threadIdx.y; + + if (w_offset < output_width && h_offset < output_height) { + for (int index_z = blockIdx.z * blockDim.z + threadIdx.z; index_z < ncd; + index_z += gridDim.z * blockDim.z) { + auto output_depth_divmod = divmods_output.depth.Divmod(index_z); + d_offset = output_depth_divmod.val[1]; + nc_offset = output_depth_divmod.val[0]; + int output_index = + nc_offset * output_depth * output_height * output_width + + d_offset * output_height * output_width + h_offset * output_width + + w_offset; + int input_offset = nc_offset * input_depth * input_height * input_width; + input_data_cur = input_data + input_offset; + + if (adaptive) { + dstart = AdaptStartIndex(d_offset, input_depth, output_depth); + dend = AdaptEndIndex(d_offset, input_depth, output_depth); + + hstart = AdaptStartIndex(h_offset, input_height, output_height); + hend = AdaptEndIndex(h_offset, input_height, output_height); + + wstart = AdaptStartIndex(w_offset, input_width, output_width); + wend = AdaptEndIndex(w_offset, input_width, output_width); + } else { + dstart = d_offset * stride_depth - padding_depth; + hstart = h_offset * stride_height - padding_height; + wstart = w_offset * stride_width - padding_width; + dend = min(dstart + ksize_depth, input_depth); + hend = min(hstart + ksize_height, input_height); + wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + } - for (int d = dstart; d < dend; ++d) { - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - if (ele < input_data[(d * input_height + h) * input_width + w]) { - max_index = (d * input_height + h) * input_width + w; - ele = input_data[max_index]; + T1 ele = -FLT_MAX; + int max_index = -1; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (ele < + input_data_cur[(d * input_height + h) * input_width + w]) { + max_index = (d * input_height + h) * input_width + w; + ele = input_data_cur[max_index]; + } } } } + output_data[output_index] = ele; + mask_data[output_index] = max_index; } - output_data[index] = ele; - mask_data[index] = max_index; } } @@ -2201,19 +2227,25 @@ class MaxPool3dWithIndexFunctor { T1* output_data = context.template Alloc(output); T2* mask_data = context.template Alloc(mask); - int nthreads = batch_size * output_channels * output_depth * output_height * - output_width; - int thread_num = 1024; -#ifdef WITH_NV_JETSON - backends::gpu::ChangeThreadNum(context, &thread_num); -#endif + int ncd = batch_size * input_channels * output_depth; - int blocks = (nthreads + thread_num - 1) / thread_num; - dim3 threads(thread_num, 1); - dim3 grid(blocks, 1); + int thread_x = 32; + int thread_y = 8; + int thread_z = 1; + dim3 threads(thread_x, thread_y, thread_z); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + int block_x = (output_width + threads.x - 1) / threads.x; + int block_y = (output_height + threads.y - 1) / threads.y; + int block_z = (ncd > max_grid_dim[2] * threads.z) + ? max_grid_dim[2] + : (ncd + threads.z - 1) / threads.z; + dim3 grid(block_x, block_y, block_z); + + auto pool_divmods_output = FastDivModForPooling3D( + input_channels, output_width, output_height, output_depth); KernelMaxPool3DWithIdx - <<>>(nthreads, + <<>>(ncd, input_data, input_channels, input_depth, @@ -2233,7 +2265,8 @@ class MaxPool3dWithIndexFunctor { padding_width, adaptive, output_data, - mask_data); + mask_data, + pool_divmods_output); } };