diff --git a/paddle/fluid/operators/interpolate_v2_op.cu b/paddle/fluid/operators/interpolate_v2_op.cu index d61eb46d97e98972963f5871a4c6e7b06468337c..cd297c53f89a0f7efc622de7c385b9f75dc7462b 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cu +++ b/paddle/fluid/operators/interpolate_v2_op.cu @@ -61,13 +61,13 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D( template __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex( - int* in_img_idx, int* w_id, T* w1lambda, T* w2lambda, T src_w, - const int in_img_w) { - src_w = (src_w > 0) ? src_w : 0.f; - *in_img_idx = static_cast(src_w); - *w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0; - *w1lambda = src_w - *in_img_idx; - *w2lambda = 1.f - *w1lambda; + int* in_img_idx, int* x_id, T* lambda1, T* lambda2, T src_x, + const int in_img_x) { + src_x = (src_x > 0) ? src_x : 0.f; + *in_img_idx = static_cast(src_x); + *x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0; + *lambda1 = src_x - *in_img_idx; + *lambda2 = 1.f - *lambda1; } struct FastDivModForInterpolate { @@ -670,83 +670,102 @@ __global__ void KeBilinearInterpBwShareMemory( } } +__device__ __forceinline__ int GetInputIndex(const size_t nc, const int height, + const int width, const int h, + const int w) { + return (nc * height + h) * width + w; +} + +template +__global__ void KeBilinearInterpNCHWBw(T* in, const int in_h, const int in_w, + const int out_h, const int out_w, + const int n, const int num_channels, + float ratio_h, float ratio_w, + const T* __restrict__ out, + const T align_type_value) { + int index = threadIdx.x + blockDim.x * blockIdx.x; + int stride = blockDim.x * gridDim.x; + int num_out = n * num_channels * out_h * out_w; + int num_in = n * num_channels * in_h * in_w; + + for (; index < num_out; index += stride) { + int index_tmp = index; + int w2 = index_tmp % out_w; + index_tmp /= out_w; + int h2 = index_tmp % out_h; + int nc = index_tmp / out_h; + + int h1, y_id; + T h1lambda, h0lambda; + T src_y = ratio_h * (h2 + align_type_value) - align_type_value; + + PreCalculatorForLinearInterpInputIndex(&h1, &y_id, &h1lambda, &h0lambda, + src_y, in_h); + int w1, x_id; + T w1lambda, w0lambda; + T src_x = ratio_w * (w2 + align_type_value) - align_type_value; + PreCalculatorForLinearInterpInputIndex(&w1, &x_id, &w1lambda, &w0lambda, + src_x, in_w); + + T d2val = out[index]; + + platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1), + h0lambda * w0lambda * d2val); + platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id), + h0lambda * w1lambda * d2val); + platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1), + h1lambda * w0lambda * d2val); + platform::CudaAtomicAdd( + in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id), + h1lambda * w1lambda * d2val); + } +} + template __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w, const T* __restrict__ out, const int out_h, const int out_w, const int n, - const int num_channels, float ratio_h, - float ratio_w, const T align_type_value, - bool is_nchw) { + const int out_chw, const int num_channels, + float ratio_h, float ratio_w, + const T align_type_value, + FastDivModForInterpolate divmods) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; int in_chw = in_h * in_w * num_channels; - int out_chw = num_channels * out_h * out_w; int nthreads = n * out_chw; - if (is_nchw) { - for (; tid < nthreads; tid += stride) { - int out_id_h = tid / out_chw; - int out_id_w = tid % out_chw; - const int in_img_size = in_h * in_w; - const int out_img_size = out_h * out_w; - T value = out[out_id_h * out_chw + out_id_w]; - - int channel_id = out_id_w / out_img_size; - int out_img_idy = (out_id_w % out_img_size) / out_w; - int out_img_idx = tid % out_w; - int in_img_idx, in_img_idy, w_id, h_id; - T w1lambda, h1lambda, w2lambda, h2lambda; - - T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; - T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; - - PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda, - &w2lambda, src_w, in_w); - PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda, - &h2lambda, src_h, in_h); - - T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size + - in_img_idy * in_w + in_img_idx]; - platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value); - platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * value); - platform::CudaAtomicAdd(&in_pos[h_id * in_w], - h1lambda * w2lambda * value); - platform::CudaAtomicAdd(&in_pos[h_id * in_w + w_id], - h1lambda * w1lambda * value); - } - } else { - for (; tid < nthreads; tid += stride) { - int out_id_h = tid / out_chw; - int out_id_w = tid % out_chw; - const int in_img_size = in_h * in_w; - const int out_img_size = out_h * out_w; - T value = out[out_id_h * out_chw + out_id_w]; - - int out_img_idy = out_id_w / (out_w * num_channels); - int out_img_idx = out_id_w % (out_w * num_channels) / num_channels; - int channel_id = tid % num_channels; - - int in_img_idx, in_img_idy, w_id, h_id; - T w1lambda, h1lambda, w2lambda, h2lambda; - T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; - T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; - - PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda, - &w2lambda, src_w, in_w); - PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda, - &h2lambda, src_h, in_h); - - T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels + - in_img_idx * num_channels + channel_id]; - platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value); - platform::CudaAtomicAdd(&in_pos[w_id * num_channels], - h2lambda * w1lambda * value); - platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels], - h1lambda * w2lambda * value); - platform::CudaAtomicAdd( - &in_pos[h_id * in_w * num_channels + w_id * num_channels], - h1lambda * w1lambda * value); - } + for (; tid < nthreads; tid += stride) { + auto out_id_divmod = divmods.output_w_div.Divmod(tid); + int out_id_h = out_id_divmod.val[0]; + int out_id_w = out_id_divmod.val[1]; + + int channel_id = divmods.channels_div.Divmod(tid).val[1]; + auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w); + int out_img_idy = outimg_id_divmod.val[0]; + int out_img_idx = + divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0]; + + int in_img_idx, in_img_idy, w_id, h_id; + T w1lambda, h1lambda, w2lambda, h2lambda; + T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; + T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; + + PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda, + &w2lambda, src_w, in_w); + PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda, + &h2lambda, src_h, in_h); + + T value = out[tid]; + T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels + + in_img_idx * num_channels + channel_id]; + platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value); + platform::CudaAtomicAdd(&in_pos[w_id * num_channels], + h2lambda * w1lambda * value); + platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels], + h1lambda * w2lambda * value); + platform::CudaAtomicAdd( + &in_pos[h_id * in_w * num_channels + w_id * num_channels], + h1lambda * w1lambda * value); } } @@ -1907,11 +1926,23 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, ctx.cuda_device_context().stream()>>>( input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c, ratio_h, ratio_w, align_type_value, is_nchw); + } else if (!optimize_flag & is_nchw) { + // + const int num_kernels = n * c * out_h * out_w; + const int num_threads = + std::min(ctx.cuda_device_context().GetMaxThreadsPerBlock(), 1024); + KeBilinearInterpNCHWBw< + T><<>>( + input_grad_data, in_h, in_w, out_h, out_w, n, c, ratio_h, ratio_w, + output_grad_data, align_type_value); } else { + int64_t cw = c * out_w; + auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw); KeBilinearInterpBw<<>>( - input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c, - ratio_h, ratio_w, align_type_value, is_nchw); + input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods); } } else if ("bicubic" == interp_method) { #ifdef __HIPCC__