Unverified · Commit a1174973 authored by Lijunhui, committed by GitHub

Optimize bilinear interpolation forward (#39243)

* bilinear_fw init

* optimize code

* pre-compute linear_interp input index
Parent c86765ed
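Note on the core trick in this commit: the old forward kernel computed an `align_flag` and branched per thread, while the new kernels fold both alignment modes into a single constant `align_type_value` (0.5f when `align_mode == 0 && !align_corners`, otherwise 0), so one formula covers both cases. A minimal sketch of that mapping (the helper name `SrcCoord` is invented for illustration; it is not part of the commit):

```cpp
// Source-coordinate mapping used by the new kernels: with v = 0.5f this is
// half-pixel alignment, and with v = 0.f it reduces to ratio * dst_idx.
inline float SrcCoord(float ratio, int dst_idx, float v) {
  return ratio * (dst_idx + v) - v;
}
```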
@@ -59,6 +59,17 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
   return config;
 }
 
+template <typename T>
+__forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
+    int* in_img_idx, int* w_id, T* w1lambda, T* w2lambda, T src_w,
+    const int in_img_w) {
+  src_w = (src_w > 0) ? src_w : 0.f;
+  *in_img_idx = static_cast<int>(src_w);
+  *w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
+  *w1lambda = src_w - *in_img_idx;
+  *w2lambda = 1.f - *w1lambda;
+}
+
 struct FastDivModForInterpolate {
  public:
   FastDivMod channels_div;
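To see what the new helper produces, here is a hedged host-side replica (hypothetical test scaffolding; `RefLinearInterpIndex` does not exist in the Paddle sources): for a source coordinate on one axis it yields the left-neighbor index, the 0/1 step to the right neighbor, and the two linear weights.

```cpp
#include <cstdio>

// Host-side replica of PreCalculatorForLinearInterpInputIndex for one axis.
void RefLinearInterpIndex(int* idx, int* id, float* l1, float* l2, float src,
                          int size) {
  src = (src > 0) ? src : 0.f;      // clamp negative source coordinates
  *idx = static_cast<int>(src);     // left (or top) neighbor index
  *id = (*idx < size - 1) ? 1 : 0;  // step to right neighbor, 0 at the border
  *l1 = src - *idx;                 // weight of the far neighbor
  *l2 = 1.f - *l1;                  // weight of the near neighbor
}

int main() {
  int idx, id;
  float l1, l2;
  RefLinearInterpIndex(&idx, &id, &l1, &l2, 2.25f, 8);
  printf("idx=%d id=%d l1=%.2f l2=%.2f\n", idx, id, l1, l2);
  // prints: idx=2 id=1 l1=0.25 l2=0.75
  return 0;
}
```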
@@ -416,99 +427,96 @@ __global__ void KeLinearInterpBw(T* in, const size_t in_img_w,
   }
 }
 
+template <typename T>
+__global__ void KeBilinearInterpNCHWFw(const T* in, const size_t in_img_h,
+                                       const size_t in_img_w, T* out,
+                                       const size_t out_img_h,
+                                       const size_t out_img_w, const size_t nc,
+                                       const float ratio_h, const float ratio_w,
+                                       const T align_type_value) {
+  int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
+  int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
+  int nc_stride = blockDim.z * gridDim.z;
+
+  int in_img_idx, in_img_idy, h_id, w_id;
+  T h1lambda, w1lambda, h2lambda, w2lambda;
+  T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
+  T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
+
+  PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
+                                         &w2lambda, src_w, in_img_w);
+  PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
+                                         &h2lambda, src_h, in_img_h);
+
+  int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
+  int in_index_stride = nc_stride * in_img_h * in_img_w;
+
+  int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
+  int out_index_stride = nc_stride * out_img_h * out_img_w;
+
+  // prevent from multiple threads writing
+  if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
+    while (nc_id < nc) {
+      const T* in_pos = &in[in_index];
+      out[out_index] =
+          h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
+          h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
+                      w1lambda * in_pos[h_id * in_img_w + w_id]);
+
+      in_index += in_index_stride;
+      out_index += out_index_stride;
+      nc_id += nc_stride;
+    }
+  }
+}
+
 template <typename T>
 __global__ void KeBilinearInterpFw(
     const T* in, const size_t in_img_h, const size_t in_img_w,
     const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
     const size_t out_img_w, const size_t output_h, const size_t output_w,
     const size_t num_channels, const float ratio_h, const float ratio_w,
-    const bool align_corners, const int align_mode,
-    const DataLayout data_layout) {
+    const T align_type_value, FastDivModForInterpolate divmods) {
   int nthreads = output_h * output_w;
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
-  bool align_flag = (align_mode == 0 && !align_corners);
-  for (; tid < nthreads; tid += stride) {
-    int out_id_h = tid / output_w;
-    int out_id_w = tid % output_w;
-    int in_img_size = input_w / num_channels;
-    int out_img_size = output_w / num_channels;
 
-    int channel_id, out_img_idy, out_img_idx;
-    if (data_layout == DataLayout::kNCHW) {
-      channel_id = out_id_w / out_img_size;
-      out_img_idy = (out_id_w % out_img_size) / out_img_w;
-      out_img_idx = tid % out_img_w;
-    } else {
-      out_img_idy = out_id_w / (out_img_w * num_channels);
-      out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
-      channel_id = tid % num_channels;
-    }
+  for (; tid < nthreads; tid += stride) {
+    auto out_id_divmod = divmods.output_w_div.Divmod(tid);
+    int out_id_h = out_id_divmod.val[0];
+    int out_id_w = out_id_divmod.val[1];
 
-    int in_img_idy = align_flag
-                         ? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
-                         : static_cast<int>(ratio_h * out_img_idy);
-    in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
-    int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
-    T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
-    src_h = (src_h > 0) ? src_h : 0;
-    T h1lambda =
-        align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
-    T h2lambda = 1.f - h1lambda;
+    int channel_id = divmods.channels_div.Divmod(tid).val[1];
+    auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
+    int out_img_idy = outimg_id_divmod.val[0];
+    int out_img_idx =
+        divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
 
-    int in_img_idx = align_flag
-                         ? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
-                         : static_cast<int>(ratio_w * out_img_idx);
-    in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
-    int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
-    T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
-    src_w = (src_w > 0) ? src_w : 0;
-    T w1lambda =
-        align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
-    T w2lambda = 1.f - w1lambda;
+    int in_img_idx, in_img_idy, h_id, w_id;
+    T h1lambda, w1lambda, h2lambda, w2lambda;
+    T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
+    T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
 
-    if (data_layout == DataLayout::kNCHW) {
-      const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
-                            in_img_idy * in_img_w + in_img_idx];
+    PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
+                                           &w2lambda, src_w, in_img_w);
+    PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
+                                           &h2lambda, src_h, in_img_h);
 
-      // bilinear interpolation
-      out[out_id_h * output_w + out_id_w] =
-          h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
-          h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
-                      w1lambda * in_pos[h_id * in_img_w + w_id]);
-    } else {
-      const T* in_pos =
-          &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
-              in_img_idx * num_channels + channel_id];
-
-      // bilinear interpolation
-      out[out_id_h * output_w + out_id_w] =
-          h2lambda *
-              (w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
-          h1lambda * (w2lambda * in_pos[h_id * in_img_w * num_channels] +
-                      w1lambda * in_pos[h_id * in_img_w * num_channels +
-                                        w_id * num_channels]);
-    }
+    // bilinear interpolation
+    const T* in_pos =
+        &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
+            in_img_idx * num_channels + channel_id];
+    out[tid] =
+        h2lambda *
+            (w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
+        h1lambda *
+            (w2lambda * in_pos[h_id * in_img_w * num_channels] +
+             w1lambda *
+                 in_pos[h_id * in_img_w * num_channels + w_id * num_channels]);
   }
 }
 
-template <typename T>
-__forceinline__ __device__ void PreCalculatorForInputIndex(
-    int* in_img_idx, int* in_img_idy, int* w_id, int* h_id, T* w1lambda,
-    T* h1lambda, T* w2lambda, T* h2lambda, T src_w, T src_h, const int in_img_w,
-    const int in_img_h) {
-  src_w = (src_w > 0) ? src_w : 0.f;
-  src_h = (src_h > 0) ? src_h : 0.f;
-  *in_img_idx = static_cast<int>(src_w);
-  *in_img_idy = static_cast<int>(src_h);
-  *w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
-  *h_id = (*in_img_idy < in_img_h - 1) ? 1 : 0;
-  *w1lambda = src_w - *in_img_idx;
-  *h1lambda = src_h - *in_img_idy;
-  *w2lambda = 1.f - *w1lambda;
-  *h2lambda = 1.f - *h1lambda;
-}
-
 /* Calculate the minimum of partial elements in a block */
 template <typename T>
 __inline__ __device__ T PartialBlockMin(T val, size_t threads_num_in_block,
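The rewritten `KeBilinearInterpFw` now handles only NHWC and replaces every `/` and `%` on the flat output index with precomputed `FastDivMod` divisions. A plain-arithmetic sketch of what the `Divmod` chain recovers (reference code written for this note, assuming `output_w` is the per-sample element count `H*W*C`, as in the old kernel):

```cpp
// Equivalent of the FastDivMod decomposition using ordinary / and %;
// FastDivMod returns the same {quotient, remainder} pairs via magic-number
// multiplication, avoiding slow integer division on the GPU.
struct OutIndex {
  int id_h, id_w, channel, img_idy, img_idx;
};

OutIndex DecomposeNHWC(int tid, int output_w, int out_img_w,
                       int num_channels) {
  OutIndex r;
  r.id_h = tid / output_w;         // sample (batch) index
  r.id_w = tid % output_w;         // flat offset inside one sample
  r.channel = tid % num_channels;  // channel: C is innermost in NHWC
  r.img_idy = r.id_w / (out_img_w * num_channels);                    // row
  r.img_idx = (r.id_w % (out_img_w * num_channels)) / num_channels;   // column
  return r;
}
```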
@@ -574,9 +582,11 @@ __global__ void KeBilinearInterpBwShareMemory(
     T w1lambda, h1lambda, w2lambda, h2lambda;
     T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
     T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
-    PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
-                               &w1lambda, &h1lambda, &w2lambda, &h2lambda,
-                               src_w, src_h, in_w, in_h);
+
+    PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
+                                           &w2lambda, src_w, in_w);
+    PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
+                                           &h2lambda, src_h, in_h);
 
     // top_left_index is just input_index.
     int input_index = out_id_h * in_chw + channel_id * in_img_size +
@@ -661,9 +671,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
     T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
     T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
-    PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
-                               &w1lambda, &h1lambda, &w2lambda, &h2lambda,
-                               src_w, src_h, in_w, in_h);
+
+    PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
+                                           &w2lambda, src_w, in_w);
+    PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
+                                           &h2lambda, src_h, in_h);
 
     T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size +
                     in_img_idy * in_w + in_img_idx];
@@ -690,9 +702,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
     T w1lambda, h1lambda, w2lambda, h2lambda;
     T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
     T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
-    PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
-                               &w1lambda, &h1lambda, &w2lambda, &h2lambda,
-                               src_w, src_h, in_w, in_h);
+
+    PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
+                                           &w2lambda, src_w, in_w);
+    PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
+                                           &h2lambda, src_h, in_h);
 
     T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
                     in_img_idx * num_channels + channel_id];
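All of the forward and backward kernels above combine the same four input neighbors with the lambdas from the two helper calls; written out once for reference (the scalar names `v00`..`v11` are illustrative only):

```cpp
// Standard bilinear blend: v00/v01/v10/v11 are the top-left, top-right,
// bottom-left and bottom-right input values.
inline float BilinearBlend(float v00, float v01, float v10, float v11,
                           float w1lambda, float w2lambda, float h1lambda,
                           float h2lambda) {
  return h2lambda * (w2lambda * v00 + w1lambda * v01) +
         h1lambda * (w2lambda * v10 + w1lambda * v11);
}
```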
@@ -1398,11 +1412,25 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
       thread_num = 512;
     }
 #endif
-    KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0,
-                            ctx.cuda_device_context().stream()>>>(
-        input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
-        out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
+    const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
+    if (data_layout == DataLayout::kNCHW) {
+      // get launch 3D config
+      int nc = n * c;
+      platform::GpuLaunchConfig config_3d =
+          GetGpuLaunchConfig3D(ctx.cuda_device_context(), nc, out_h, out_w);
+      KeBilinearInterpNCHWFw<
+          T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
+               ctx.cuda_device_context().stream()>>>(
+          input_data, in_h, in_w, output_data, out_h, out_w, nc, ratio_h,
+          ratio_w, align_type_value);
+    } else {
+      int64_t cw = c * out_w;
+      auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
+      KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0,
+                              ctx.cuda_device_context().stream()>>>(
+          input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
+          out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods);
+    }
   } else if ("bicubic" == interp_method) {
 #ifdef __HIPCC__
     constexpr int thread_per_block = 256;
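On the dispatch side, the NCHW forward path now uses a 3D launch: x covers output width, y output height, and z grid-strides over the fused `N*C` dimension. A rough sketch of the launch shape (the block size and the 65535 z-grid cap are assumptions for illustration; the commit derives the real configuration from `GetGpuLaunchConfig3D`):

```cpp
// Illustrative launch shape for KeBilinearInterpNCHWFw; the actual sizes
// come from GetGpuLaunchConfig3D in the Paddle sources.
dim3 block(32, 8, 1);  // x -> out_w, y -> out_h, z -> fused N*C
dim3 grid((out_w + block.x - 1) / block.x,
          (out_h + block.y - 1) / block.y,
          std::min(nc, 65535));  // z is capped; the kernel strides over nc
KeBilinearInterpNCHWFw<float><<<grid, block, 0, stream>>>(
    input_data, in_h, in_w, output_data, out_h, out_w, nc, ratio_h, ratio_w,
    align_type_value);
```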