未验证 提交 9e1f762c 编写于 作者: L Lijunhui 提交者: GitHub

Optimize bilinear_interp backward (#39423)

* bilinear_bw init

* optimize code

* optimize

* optimize 2

* optimize functions

* modify func name
上级 2c21d240
......@@ -61,13 +61,13 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
template <typename T>
__forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
int* in_img_idx, int* w_id, T* w1lambda, T* w2lambda, T src_w,
const int in_img_w) {
src_w = (src_w > 0) ? src_w : 0.f;
*in_img_idx = static_cast<int>(src_w);
*w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
*w1lambda = src_w - *in_img_idx;
*w2lambda = 1.f - *w1lambda;
int* in_img_idx, int* x_id, T* lambda1, T* lambda2, T src_x,
const int in_img_x) {
src_x = (src_x > 0) ? src_x : 0.f;
*in_img_idx = static_cast<int>(src_x);
*x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
*lambda1 = src_x - *in_img_idx;
*lambda2 = 1.f - *lambda1;
}
struct FastDivModForInterpolate {
......@@ -670,83 +670,102 @@ __global__ void KeBilinearInterpBwShareMemory(
}
}
__device__ __forceinline__ int GetInputIndex(const size_t nc, const int height,
const int width, const int h,
const int w) {
return (nc * height + h) * width + w;
}
template <typename T>
__global__ void KeBilinearInterpNCHWBw(T* in, const int in_h, const int in_w,
const int out_h, const int out_w,
const int n, const int num_channels,
float ratio_h, float ratio_w,
const T* __restrict__ out,
const T align_type_value) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
int stride = blockDim.x * gridDim.x;
int num_out = n * num_channels * out_h * out_w;
int num_in = n * num_channels * in_h * in_w;
for (; index < num_out; index += stride) {
int index_tmp = index;
int w2 = index_tmp % out_w;
index_tmp /= out_w;
int h2 = index_tmp % out_h;
int nc = index_tmp / out_h;
int h1, y_id;
T h1lambda, h0lambda;
T src_y = ratio_h * (h2 + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&h1, &y_id, &h1lambda, &h0lambda,
src_y, in_h);
int w1, x_id;
T w1lambda, w0lambda;
T src_x = ratio_w * (w2 + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&w1, &x_id, &w1lambda, &w0lambda,
src_x, in_w);
T d2val = out[index];
platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
h0lambda * w0lambda * d2val);
platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
h0lambda * w1lambda * d2val);
platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
h1lambda * w0lambda * d2val);
platform::CudaAtomicAdd(
in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
h1lambda * w1lambda * d2val);
}
}
template <typename T>
__global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
const T* __restrict__ out, const int out_h,
const int out_w, const int n,
const int num_channels, float ratio_h,
float ratio_w, const T align_type_value,
bool is_nchw) {
const int out_chw, const int num_channels,
float ratio_h, float ratio_w,
const T align_type_value,
FastDivModForInterpolate divmods) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_chw = in_h * in_w * num_channels;
int out_chw = num_channels * out_h * out_w;
int nthreads = n * out_chw;
if (is_nchw) {
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / out_chw;
int out_id_w = tid % out_chw;
const int in_img_size = in_h * in_w;
const int out_img_size = out_h * out_w;
T value = out[out_id_h * out_chw + out_id_w];
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_w;
int out_img_idx = tid % out_w;
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
&w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_h);
T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size +
in_img_idy * in_w + in_img_idx];
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * value);
platform::CudaAtomicAdd(&in_pos[h_id * in_w],
h1lambda * w2lambda * value);
platform::CudaAtomicAdd(&in_pos[h_id * in_w + w_id],
h1lambda * w1lambda * value);
}
} else {
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / out_chw;
int out_id_w = tid % out_chw;
const int in_img_size = in_h * in_w;
const int out_img_size = out_h * out_w;
T value = out[out_id_h * out_chw + out_id_w];
int out_img_idy = out_id_w / (out_w * num_channels);
int out_img_idx = out_id_w % (out_w * num_channels) / num_channels;
int channel_id = tid % num_channels;
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
&w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_h);
T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
in_img_idx * num_channels + channel_id];
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
h2lambda * w1lambda * value);
platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
h1lambda * w2lambda * value);
platform::CudaAtomicAdd(
&in_pos[h_id * in_w * num_channels + w_id * num_channels],
h1lambda * w1lambda * value);
}
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
&w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
&h2lambda, src_h, in_h);
T value = out[tid];
T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
in_img_idx * num_channels + channel_id];
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
h2lambda * w1lambda * value);
platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
h1lambda * w2lambda * value);
platform::CudaAtomicAdd(
&in_pos[h_id * in_w * num_channels + w_id * num_channels],
h1lambda * w1lambda * value);
}
}
......@@ -1907,11 +1926,23 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
ratio_h, ratio_w, align_type_value, is_nchw);
} else if (!optimize_flag & is_nchw) {
//
const int num_kernels = n * c * out_h * out_w;
const int num_threads =
std::min(ctx.cuda_device_context().GetMaxThreadsPerBlock(), 1024);
KeBilinearInterpNCHWBw<
T><<<platform::DivUp(num_kernels, num_threads), num_threads, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, out_h, out_w, n, c, ratio_h, ratio_w,
output_grad_data, align_type_value);
} else {
int64_t cw = c * out_w;
auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
ratio_h, ratio_w, align_type_value, is_nchw);
input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods);
}
} else if ("bicubic" == interp_method) {
#ifdef __HIPCC__
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册