未验证 提交 a1174973 编写于 作者: L Lijunhui 提交者: GitHub

Optimize bilinear interpolation foward (#39243)

* bilinear_fw init

* optimize code

* pre-compute linear_interp input index
上级 c86765ed
...@@ -59,6 +59,17 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D( ...@@ -59,6 +59,17 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
return config; return config;
} }
template <typename T>
__forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
int* in_img_idx, int* w_id, T* w1lambda, T* w2lambda, T src_w,
const int in_img_w) {
src_w = (src_w > 0) ? src_w : 0.f;
*in_img_idx = static_cast<int>(src_w);
*w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
*w1lambda = src_w - *in_img_idx;
*w2lambda = 1.f - *w1lambda;
}
struct FastDivModForInterpolate { struct FastDivModForInterpolate {
public: public:
FastDivMod channels_div; FastDivMod channels_div;
...@@ -417,96 +428,93 @@ __global__ void KeLinearInterpBw(T* in, const size_t in_img_w, ...@@ -417,96 +428,93 @@ __global__ void KeLinearInterpBw(T* in, const size_t in_img_w,
} }
template <typename T> template <typename T>
__global__ void KeBilinearInterpFw( __global__ void KeBilinearInterpNCHWFw(const T* in, const size_t in_img_h,
const T* in, const size_t in_img_h, const size_t in_img_w, const size_t in_img_w, T* out,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w, const size_t out_img_w, const size_t nc,
const size_t num_channels, const float ratio_h, const float ratio_w, const float ratio_h, const float ratio_w,
const bool align_corners, const int align_mode, const T align_type_value) {
const DataLayout data_layout) { int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int nthreads = output_h * output_w; int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int stride = blockDim.x * gridDim.x; int nc_stride = blockDim.z * gridDim.z;
bool align_flag = (align_mode == 0 && !align_corners);
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx; int in_img_idx, in_img_idy, h_id, w_id;
if (data_layout == DataLayout::kNCHW) { T h1lambda, w1lambda, h2lambda, w2lambda;
channel_id = out_id_w / out_img_size; T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
out_img_idy = (out_id_w % out_img_size) / out_img_w; T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idy = align_flag PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5) &w2lambda, src_w, in_img_w);
: static_cast<int>(ratio_h * out_img_idy); PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0; &h2lambda, src_h, in_img_h);
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
src_h = (src_h > 0) ? src_h : 0;
T h1lambda =
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
T h2lambda = 1.f - h1lambda;
int in_img_idx = align_flag int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5) int in_index_stride = nc_stride * in_img_h * in_img_w;
: static_cast<int>(ratio_w * out_img_idx);
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
src_w = (src_w > 0) ? src_w : 0;
T w1lambda =
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
if (data_layout == DataLayout::kNCHW) { int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size + int out_index_stride = nc_stride * out_img_h * out_img_w;
in_img_idy * in_img_w + in_img_idx];
// bilinear interpolation // prevent from multiple threads writing
out[out_id_h * output_w + out_id_w] = if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
while (nc_id < nc) {
const T* in_pos = &in[in_index];
out[out_index] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) + h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] + h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]); w1lambda * in_pos[h_id * in_img_w + w_id]);
} else {
const T* in_pos =
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
// bilinear interpolation in_index += in_index_stride;
out[out_id_h * output_w + out_id_w] = out_index += out_index_stride;
h2lambda * nc_id += nc_stride;
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w * num_channels] +
w1lambda * in_pos[h_id * in_img_w * num_channels +
w_id * num_channels]);
} }
} }
} }
template <typename T> template <typename T>
__forceinline__ __device__ void PreCalculatorForInputIndex( __global__ void KeBilinearInterpFw(
int* in_img_idx, int* in_img_idy, int* w_id, int* h_id, T* w1lambda, const T* in, const size_t in_img_h, const size_t in_img_w,
T* h1lambda, T* w2lambda, T* h2lambda, T src_w, T src_h, const int in_img_w, const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const int in_img_h) { const size_t out_img_w, const size_t output_h, const size_t output_w,
src_w = (src_w > 0) ? src_w : 0.f; const size_t num_channels, const float ratio_h, const float ratio_w,
src_h = (src_h > 0) ? src_h : 0.f; const T align_type_value, FastDivModForInterpolate divmods) {
*in_img_idx = static_cast<int>(src_w); int nthreads = output_h * output_w;
*in_img_idy = static_cast<int>(src_h); int tid = blockIdx.x * blockDim.x + threadIdx.x;
*w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0; int stride = blockDim.x * gridDim.x;
*h_id = (*in_img_idy < in_img_h - 1) ? 1 : 0;
*w1lambda = src_w - *in_img_idx; for (; tid < nthreads; tid += stride) {
*h1lambda = src_h - *in_img_idy; auto out_id_divmod = divmods.output_w_div.Divmod(tid);
*w2lambda = 1.f - *w1lambda; int out_id_h = out_id_divmod.val[0];
*h2lambda = 1.f - *h1lambda; int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
&w2lambda, src_w, in_img_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_img_h);
// bilinear interpolation
const T* in_pos =
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
out[tid] =
h2lambda *
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
h1lambda *
(w2lambda * in_pos[h_id * in_img_w * num_channels] +
w1lambda *
in_pos[h_id * in_img_w * num_channels + w_id * num_channels]);
}
} }
/* Calculate the minimum of partial elements in a block */ /* Calculate the minimum of partial elements in a block */
...@@ -574,9 +582,11 @@ __global__ void KeBilinearInterpBwShareMemory( ...@@ -574,9 +582,11 @@ __global__ void KeBilinearInterpBwShareMemory(
T w1lambda, h1lambda, w2lambda, h2lambda; T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
&w1lambda, &h1lambda, &w2lambda, &h2lambda, PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
src_w, src_h, in_w, in_h); &w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_h);
// top_left_index is just input_index. // top_left_index is just input_index.
int input_index = out_id_h * in_chw + channel_id * in_img_size + int input_index = out_id_h * in_chw + channel_id * in_img_size +
...@@ -661,9 +671,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w, ...@@ -661,9 +671,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
&w1lambda, &h1lambda, &w2lambda, &h2lambda, PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
src_w, src_h, in_w, in_h); &w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_h);
T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size + T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size +
in_img_idy * in_w + in_img_idx]; in_img_idy * in_w + in_img_idx];
...@@ -690,9 +702,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w, ...@@ -690,9 +702,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
T w1lambda, h1lambda, w2lambda, h2lambda; T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value; T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value; T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
&w1lambda, &h1lambda, &w2lambda, &h2lambda, PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
src_w, src_h, in_w, in_h); &w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_h);
T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels + T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
in_img_idx * num_channels + channel_id]; in_img_idx * num_channels + channel_id];
...@@ -1398,11 +1412,25 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, ...@@ -1398,11 +1412,25 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
thread_num = 512; thread_num = 512;
} }
#endif #endif
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0, if (data_layout == DataLayout::kNCHW) {
ctx.cuda_device_context().stream()>>>( // get launch 3D config
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, int nc = n * c;
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout); platform::GpuLaunchConfig config_3d =
GetGpuLaunchConfig3D(ctx.cuda_device_context(), nc, out_h, out_w);
KeBilinearInterpNCHWFw<
T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, output_data, out_h, out_w, nc, ratio_h,
ratio_w, align_type_value);
} else {
int64_t cw = c * out_w;
auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods);
}
} else if ("bicubic" == interp_method) { } else if ("bicubic" == interp_method) {
#ifdef __HIPCC__ #ifdef __HIPCC__
constexpr int thread_per_block = 256; constexpr int thread_per_block = 256;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册