From d8afe40737e202b04bc03c2111afe34230ec5afa Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Mon, 12 Apr 2021 14:26:41 +0800 Subject: [PATCH] Optimization of bilinear backward OP CUDA kernel. (#30950) --- paddle/fluid/operators/interpolate_v2_op.cu | 288 +++++++++++++++----- 1 file changed, 218 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/operators/interpolate_v2_op.cu b/paddle/fluid/operators/interpolate_v2_op.cu index 90abcaa8b47..9c19278ac4d 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cu +++ b/paddle/fluid/operators/interpolate_v2_op.cu @@ -12,6 +12,8 @@ #include #include #include "paddle/fluid/operators/interpolate_v2_op.h" +#include "paddle/fluid/operators/math/math_cuda_utils.h" +#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_launch_config.h" @@ -302,81 +304,214 @@ __global__ void KeBilinearInterpFw( } template -__global__ void KeBilinearInterpBw( - T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, - const size_t input_w, const T* out, const size_t out_img_h, - const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const T ratio_h, const T ratio_w, - const bool align_corners, const int align_mode, - const DataLayout data_layout) { - int nthreads = output_h * output_w; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - bool align_flag = (align_mode == 0 && !align_corners); - for (; tid < nthreads; tid += stride) { - int out_id_h = tid / output_w; - int out_id_w = tid % output_w; - int in_img_size = input_w / num_channels; - int out_img_size = output_w / num_channels; +__forceinline__ __device__ void PreCalculatorForInputIndex( + int* in_img_idx, int* in_img_idy, int* w_id, int* h_id, T* w1lambda, + T* h1lambda, T* w2lambda, T* h2lambda, T src_w, T src_h, const int in_img_w, + const int in_img_h) { + src_w = (src_w > 0) ? src_w : 0.f; + src_h = (src_h > 0) ? src_h : 0.f; + *in_img_idx = static_cast(src_w); + *in_img_idy = static_cast(src_h); + *w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0; + *h_id = (*in_img_idy < in_img_h - 1) ? 1 : 0; + *w1lambda = src_w - *in_img_idx; + *h1lambda = src_h - *in_img_idy; + *w2lambda = 1.f - *w1lambda; + *h2lambda = 1.f - *h1lambda; +} - int channel_id, out_img_idy, out_img_idx; - if (data_layout == DataLayout::kNCHW) { - channel_id = out_id_w / out_img_size; - out_img_idy = (out_id_w % out_img_size) / out_img_w; - out_img_idx = tid % out_img_w; - } else { - out_img_idy = out_id_w / (out_img_w * num_channels); - out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; - channel_id = tid % num_channels; +/* Calculate the minimum of partial elements in a block */ +template +__inline__ __device__ T PartialBlockMin(T val, size_t threads_num_in_block, + unsigned mask) { + __shared__ T shared[WARP_SIZE]; + __shared__ T shared_last_val; + __shared__ int shared_last_idx; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + int threshold = (threads_num_in_block & (-WARP_SIZE)); + + if (threadIdx.x < threshold) { + shared_last_idx = (threshold >> 5) - 1; + val = math::warpReduceMin(val, mask); + if (lane == 0) { + shared[wid] = val; } + } else { + shared_last_val = std::numeric_limits::max(); + platform::CudaAtomicMin(&shared_last_val, val); + shared[wid] = shared_last_val; + shared_last_idx = wid; + } + __syncthreads(); - int in_img_idy = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5 - : ratio_h * out_img_idy; - in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; - int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0; - T src_h = ratio_h * (out_img_idy + 0.5) - 0.5; - src_h = (src_h > 0) ? src_h : 0; - T h1lambda = - align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy; - T h2lambda = 1.f - h1lambda; + if (threadIdx.x < threshold) { + val = (lane <= shared_last_idx) ? shared[lane] + : std::numeric_limits::max(); + val = math::warpReduceMin(val, mask); + shared_last_val = val; + } + __syncthreads(); + if (threadIdx.x >= threshold) { + val = shared_last_val; + } + return val; +} - int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5 - : ratio_w * out_img_idx; - in_img_idx = (in_img_idx > 0) ? in_img_idx : 0; - int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0; - T src_w = ratio_w * (out_img_idx + 0.5) - 0.5; - src_w = (src_w > 0) ? src_w : 0; - T w1lambda = - align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx; - T w2lambda = 1.f - w1lambda; +template +__global__ void KeBilinearInterpBwShareMemory( + T* in, const int in_h, const int in_w, const T* __restrict__ out, + const int out_h, const int out_w, const int n, const int num_channels, + float ratio_h, float ratio_w, const T align_type_value, bool is_nchw) { + __shared__ T s_data[2][1024]; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + int in_chw = in_h * in_w * num_channels; + int out_chw = num_channels * out_h * out_w; + int nthreads = n * out_chw; - T* in_pos; - if (data_layout == DataLayout::kNCHW) { - in_pos = &in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / out_chw; + int out_id_w = tid % out_chw; + const int in_img_size = in_h * in_w; + const int out_img_size = out_h * out_w; + T value = out[out_id_h * out_chw + out_id_w]; + + int channel_id = out_id_w / out_img_size; + int out_img_idy = (out_id_w % out_img_size) / out_w; + int out_img_idx = tid % out_w; + + int in_img_idx, in_img_idy, w_id, h_id; + T w1lambda, h1lambda, w2lambda, h2lambda; + T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; + T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; + PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id, + &w1lambda, &h1lambda, &w2lambda, &h2lambda, + src_w, src_h, in_w, in_h); + + // top_left_index is just input_index. + int input_index = out_id_h * in_chw + channel_id * in_img_size + + in_img_idy * in_w + in_img_idx; + int top_right_index = input_index + w_id; + int bot_left_index = input_index + h_id * in_w; + int bot_right_index = input_index + h_id * in_w + w_id; + int in_top_min_index, in_bot_min_index; + + s_data[0][threadIdx.x] = 0.f; + s_data[1][threadIdx.x] = 0.f; + int remain = nthreads - (tid & (-blockDim.x)); + int in_top_max_index = math::blockReduceMax(top_right_index, FINAL_MASK); + int in_bot_max_index = math::blockReduceMax(bot_right_index, FINAL_MASK); + + if (remain > blockDim.x) { + in_top_min_index = math::blockReduceMin(input_index, FINAL_MASK); + in_bot_min_index = math::blockReduceMin(bot_left_index, FINAL_MASK); } else { - in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels + - in_img_idx * num_channels + channel_id]; + in_top_min_index = PartialBlockMin(input_index, remain, FINAL_MASK); + in_bot_min_index = PartialBlockMin(bot_left_index, remain, FINAL_MASK); } + int upper_limit_share_idx = (in_top_max_index - in_top_min_index) > + (in_bot_max_index - in_bot_min_index) + ? (in_top_max_index - in_top_min_index) + : (in_bot_max_index - in_bot_min_index); + if (h_id != 0) { + platform::CudaAtomicAdd(&s_data[0][input_index - in_top_min_index], + h2lambda * w2lambda * value); + platform::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index], + h2lambda * w1lambda * value); + platform::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index], + h1lambda * w2lambda * value); + platform::CudaAtomicAdd(&s_data[1][bot_right_index - in_bot_min_index], + h1lambda * w1lambda * value); + } else { + platform::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index], + (h2lambda + h1lambda) * w1lambda * value); + platform::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index], + (h1lambda + h2lambda) * w2lambda * value); + } + __syncthreads(); - const T* out_pos = &out[out_id_h * output_w + out_id_w]; + if (threadIdx.x <= upper_limit_share_idx) { + platform::CudaAtomicAdd(&in[in_top_min_index + threadIdx.x], + s_data[0][threadIdx.x]); + platform::CudaAtomicAdd(&in[in_bot_min_index + threadIdx.x], + s_data[1][threadIdx.x]); + } + } +} - if (data_layout == DataLayout::kNCHW) { - platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); - platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]); - platform::CudaAtomicAdd(&in_pos[h_id * in_img_w], - h1lambda * w2lambda * out_pos[0]); - platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id], - h1lambda * w1lambda * out_pos[0]); - } else { - platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]); +template +__global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w, + const T* __restrict__ out, const int out_h, + const int out_w, const int n, + const int num_channels, float ratio_h, + float ratio_w, const T align_type_value, + bool is_nchw) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + int in_chw = in_h * in_w * num_channels; + int out_chw = num_channels * out_h * out_w; + int nthreads = n * out_chw; + + if (is_nchw) { + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / out_chw; + int out_id_w = tid % out_chw; + const int in_img_size = in_h * in_w; + const int out_img_size = out_h * out_w; + T value = out[out_id_h * out_chw + out_id_w]; + + int channel_id = out_id_w / out_img_size; + int out_img_idy = (out_id_w % out_img_size) / out_w; + int out_img_idx = tid % out_w; + int in_img_idx, in_img_idy, w_id, h_id; + T w1lambda, h1lambda, w2lambda, h2lambda; + + T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; + T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; + PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id, + &w1lambda, &h1lambda, &w2lambda, &h2lambda, + src_w, src_h, in_w, in_h); + + T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size + + in_img_idy * in_w + in_img_idx]; + platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value); + platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * value); + platform::CudaAtomicAdd(&in_pos[h_id * in_w], + h1lambda * w2lambda * value); + platform::CudaAtomicAdd(&in_pos[h_id * in_w + w_id], + h1lambda * w1lambda * value); + } + } else { + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / out_chw; + int out_id_w = tid % out_chw; + const int in_img_size = in_h * in_w; + const int out_img_size = out_h * out_w; + T value = out[out_id_h * out_chw + out_id_w]; + + int out_img_idy = out_id_w / (out_w * num_channels); + int out_img_idx = out_id_w % (out_w * num_channels) / num_channels; + int channel_id = tid % num_channels; + + int in_img_idx, in_img_idy, w_id, h_id; + T w1lambda, h1lambda, w2lambda, h2lambda; + T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; + T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; + PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id, + &w1lambda, &h1lambda, &w2lambda, &h2lambda, + src_w, src_h, in_w, in_h); + + T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels + + in_img_idx * num_channels + channel_id]; + platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value); platform::CudaAtomicAdd(&in_pos[w_id * num_channels], - h2lambda * w1lambda * out_pos[0]); - platform::CudaAtomicAdd(&in_pos[h_id * in_img_w * num_channels], - h1lambda * w2lambda * out_pos[0]); + h2lambda * w1lambda * value); + platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels], + h1lambda * w2lambda * value); platform::CudaAtomicAdd( - &in_pos[h_id * in_img_w * num_channels + w_id * num_channels], - h1lambda * w1lambda * out_pos[0]); + &in_pos[h_id * in_w * num_channels + w_id * num_channels], + h1lambda * w1lambda * value); } } } @@ -1373,7 +1508,6 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, int out_hw = out_h * out_w; int in_chw = c * in_hw; int out_chw = c * out_hw; - int pixelNum = n * out_chw; platform::GpuLaunchConfig config = @@ -1386,11 +1520,25 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout); } else if ("bilinear" == interp_method) { - KeBilinearInterpBw<<>>( - input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, - n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode, - data_layout); + const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0; + bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false; + bool optimize_flag = false; + optimize_flag = (in_h < (out_h >> 6) && in_w < (out_w >> 6)) + ? true + : ((in_h == 1 && in_w == 1) ? true : false); + + if (optimize_flag & is_nchw) { + KeBilinearInterpBwShareMemory< + T><<>>( + input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c, + ratio_h, ratio_w, align_type_value, is_nchw); + } else { + KeBilinearInterpBw<<>>( + input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c, + ratio_h, ratio_w, align_type_value, is_nchw); + } } else if ("bicubic" == interp_method) { KeBicubicInterpBw<<>>( -- GitLab