Unverified commit d8afe407, authored by limingshu, committed by GitHub

Optimization of bilinear backward OP CUDA kernel. (#30950)

Parent af374ae6
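
The patch splits the bilinear backward pass into two kernels: KeBilinearInterpBwShareMemory, used for NCHW layout when the input grid is far smaller than the output grid, and a rewritten generic KeBilinearInterpBw. Both drop the old align_corners/align_mode pair in favor of a single align_type_value. A minimal sketch of the source-coordinate mapping the two kernels share (the helper name below is hypothetical and not part of the patch):

// Maps an output pixel index to a clamped, possibly fractional source
// coordinate. align_type_value is 0.5 when (align_mode == 0 && !align_corners),
// otherwise 0.
template <typename T>
__forceinline__ __device__ T SourceCoord(int out_idx, float ratio,
                                         T align_type_value) {
  T src = ratio * (out_idx + align_type_value) - align_type_value;
  return src > 0 ? src : static_cast<T>(0);
}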
@@ -12,6 +12,8 @@
#include <algorithm>
#include <string>
#include "paddle/fluid/operators/interpolate_v2_op.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
@@ -302,81 +304,214 @@ __global__ void KeBilinearInterpFw(
}
template <typename T>
__global__ void KeBilinearInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratio_w,
const bool align_corners, const int align_mode,
const DataLayout data_layout) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
__forceinline__ __device__ void PreCalculatorForInputIndex(
int* in_img_idx, int* in_img_idy, int* w_id, int* h_id, T* w1lambda,
T* h1lambda, T* w2lambda, T* h2lambda, T src_w, T src_h, const int in_img_w,
const int in_img_h) {
src_w = (src_w > 0) ? src_w : 0.f;
src_h = (src_h > 0) ? src_h : 0.f;
*in_img_idx = static_cast<int>(src_w);
*in_img_idy = static_cast<int>(src_h);
*w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
*h_id = (*in_img_idy < in_img_h - 1) ? 1 : 0;
*w1lambda = src_w - *in_img_idx;
*h1lambda = src_h - *in_img_idy;
*w2lambda = 1.f - *w1lambda;
*h2lambda = 1.f - *h1lambda;
}
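// The helper above clamps src_w/src_h at zero, splits them into integer cell
// indices (in_img_idx, in_img_idy), border-guard steps w_id/h_id, and the
// bilinear weights, with w1lambda + w2lambda == 1 and h1lambda + h2lambda == 1.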
int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
/* Calculate the minimum of partial elements in a block */
template <typename T>
__inline__ __device__ T PartialBlockMin(T val, size_t threads_num_in_block,
unsigned mask) {
__shared__ T shared[WARP_SIZE];
__shared__ T shared_last_val;
__shared__ int shared_last_idx;
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
int threshold = (threads_num_in_block & (-WARP_SIZE));
if (threadIdx.x < threshold) {
shared_last_idx = (threshold >> 5) - 1;
val = math::warpReduceMin(val, mask);
if (lane == 0) {
shared[wid] = val;
}
} else {
shared_last_val = std::numeric_limits<T>::max();
platform::CudaAtomicMin(&shared_last_val, val);
shared[wid] = shared_last_val;
shared_last_idx = wid;
}
__syncthreads();
int in_img_idy = align_flag ? ratio_h * (out_img_idy + 0.5) - 0.5
: ratio_h * out_img_idy;
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
if (threadIdx.x < threshold) {
val = (lane <= shared_last_idx) ? shared[lane]
: std::numeric_limits<T>::max();
val = math::warpReduceMin(val, mask);
shared_last_val = val;
}
__syncthreads();
if (threadIdx.x >= threshold) {
val = shared_last_val;
}
return val;
}
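// PartialBlockMin handles the tail iteration of the grid-stride loop, where
// only threads_num_in_block threads still hold valid indices: the complete
// warps reduce with warpReduceMin, the trailing partial warp is folded in via
// CudaAtomicMin on shared memory, and the combined minimum is broadcast back.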
int in_img_idx = align_flag ? ratio_w * (out_img_idx + 0.5) - 0.5
: ratio_w * out_img_idx;
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
template <typename T>
__global__ void KeBilinearInterpBwShareMemory(
T* in, const int in_h, const int in_w, const T* __restrict__ out,
const int out_h, const int out_w, const int n, const int num_channels,
float ratio_h, float ratio_w, const T align_type_value, bool is_nchw) {
__shared__ T s_data[2][1024];
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_chw = in_h * in_w * num_channels;
int out_chw = num_channels * out_h * out_w;
int nthreads = n * out_chw;
T* in_pos;
if (data_layout == DataLayout::kNCHW) {
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / out_chw;
int out_id_w = tid % out_chw;
const int in_img_size = in_h * in_w;
const int out_img_size = out_h * out_w;
T value = out[out_id_h * out_chw + out_id_w];
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_w;
int out_img_idx = tid % out_w;
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
&w1lambda, &h1lambda, &w2lambda, &h2lambda,
src_w, src_h, in_w, in_h);
// top_left_index is just input_index.
int input_index = out_id_h * in_chw + channel_id * in_img_size +
in_img_idy * in_w + in_img_idx;
int top_right_index = input_index + w_id;
int bot_left_index = input_index + h_id * in_w;
int bot_right_index = input_index + h_id * in_w + w_id;
int in_top_min_index, in_bot_min_index;
s_data[0][threadIdx.x] = 0.f;
s_data[1][threadIdx.x] = 0.f;
int remain = nthreads - (tid & (-blockDim.x));
int in_top_max_index = math::blockReduceMax(top_right_index, FINAL_MASK);
int in_bot_max_index = math::blockReduceMax(bot_right_index, FINAL_MASK);
if (remain > blockDim.x) {
in_top_min_index = math::blockReduceMin(input_index, FINAL_MASK);
in_bot_min_index = math::blockReduceMin(bot_left_index, FINAL_MASK);
} else {
in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
in_top_min_index = PartialBlockMin(input_index, remain, FINAL_MASK);
in_bot_min_index = PartialBlockMin(bot_left_index, remain, FINAL_MASK);
}
int upper_limit_share_idx = (in_top_max_index - in_top_min_index) >
(in_bot_max_index - in_bot_min_index)
? (in_top_max_index - in_top_min_index)
: (in_bot_max_index - in_bot_min_index);
if (h_id != 0) {
platform::CudaAtomicAdd(&s_data[0][input_index - in_top_min_index],
h2lambda * w2lambda * value);
platform::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index],
h2lambda * w1lambda * value);
platform::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index],
h1lambda * w2lambda * value);
platform::CudaAtomicAdd(&s_data[1][bot_right_index - in_bot_min_index],
h1lambda * w1lambda * value);
} else {
platform::CudaAtomicAdd(&s_data[0][top_right_index - in_top_min_index],
(h2lambda + h1lambda) * w1lambda * value);
platform::CudaAtomicAdd(&s_data[1][bot_left_index - in_bot_min_index],
(h1lambda + h2lambda) * w2lambda * value);
}
__syncthreads();
const T* out_pos = &out[out_id_h * output_w + out_id_w];
if (threadIdx.x <= upper_limit_share_idx) {
platform::CudaAtomicAdd(&in[in_top_min_index + threadIdx.x],
s_data[0][threadIdx.x]);
platform::CudaAtomicAdd(&in[in_bot_min_index + threadIdx.x],
s_data[1][threadIdx.x]);
}
}
}
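// KeBilinearInterpBwShareMemory stages the four gradient contributions of each
// output pixel in two shared-memory rows (the top and bottom input rows touched
// by the block), offset by the block-wide minimum input index computed above,
// and then flushes each shared element to global memory with one atomic add,
// cutting global atomic traffic when many outputs map onto the same inputs.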
if (data_layout == DataLayout::kNCHW) {
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w],
h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id],
h1lambda * w1lambda * out_pos[0]);
} else {
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
template <typename T>
__global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
const T* __restrict__ out, const int out_h,
const int out_w, const int n,
const int num_channels, float ratio_h,
float ratio_w, const T align_type_value,
bool is_nchw) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_chw = in_h * in_w * num_channels;
int out_chw = num_channels * out_h * out_w;
int nthreads = n * out_chw;
if (is_nchw) {
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / out_chw;
int out_id_w = tid % out_chw;
const int in_img_size = in_h * in_w;
const int out_img_size = out_h * out_w;
T value = out[out_id_h * out_chw + out_id_w];
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_w;
int out_img_idx = tid % out_w;
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
&w1lambda, &h1lambda, &w2lambda, &h2lambda,
src_w, src_h, in_w, in_h);
T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size +
in_img_idy * in_w + in_img_idx];
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * value);
platform::CudaAtomicAdd(&in_pos[h_id * in_w],
h1lambda * w2lambda * value);
platform::CudaAtomicAdd(&in_pos[h_id * in_w + w_id],
h1lambda * w1lambda * value);
}
} else {
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / out_chw;
int out_id_w = tid % out_chw;
const int in_img_size = in_h * in_w;
const int out_img_size = out_h * out_w;
T value = out[out_id_h * out_chw + out_id_w];
int out_img_idy = out_id_w / (out_w * num_channels);
int out_img_idx = out_id_w % (out_w * num_channels) / num_channels;
int channel_id = tid % num_channels;
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
&w1lambda, &h1lambda, &w2lambda, &h2lambda,
src_w, src_h, in_w, in_h);
T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
in_img_idx * num_channels + channel_id];
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w * num_channels],
h1lambda * w2lambda * out_pos[0]);
h2lambda * w1lambda * value);
platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
h1lambda * w2lambda * value);
platform::CudaAtomicAdd(
&in_pos[h_id * in_img_w * num_channels + w_id * num_channels],
h1lambda * w1lambda * out_pos[0]);
&in_pos[h_id * in_w * num_channels + w_id * num_channels],
h1lambda * w1lambda * value);
}
}
}
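// The rewritten generic KeBilinearInterpBw still issues four global atomic adds
// per output pixel, but the NCHW/NHWC layout branch is now taken once outside
// the grid-stride loop instead of per element.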
@@ -1373,7 +1508,6 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
int out_hw = out_h * out_w;
int in_chw = c * in_hw;
int out_chw = c * out_hw;
int pixelNum = n * out_chw;
platform::GpuLaunchConfig config =
@@ -1386,11 +1520,25 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
data_layout);
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false;
bool optimize_flag = false;
optimize_flag = (in_h < (out_h >> 6) && in_w < (out_w >> 6))
? true
: ((in_h == 1 && in_w == 1) ? true : false);
if (optimize_flag & is_nchw) {
KeBilinearInterpBwShareMemory<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
ratio_h, ratio_w, align_type_value, is_nchw);
} else {
KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
ratio_h, ratio_w, align_type_value, is_nchw);
}
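// The shared-memory kernel is dispatched only for NCHW layout and only when the
// input is roughly 64x smaller than the output in both dimensions (or is a
// single pixel), i.e. when many output gradients collapse onto few input cells;
// every other case falls back to the generic kernel.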
} else if ("bicubic" == interp_method) {
KeBicubicInterpBw<T><<<config.block_per_grid, 512, 0,
ctx.cuda_device_context().stream()>>>(