未验证 提交 df0b4434 编写于 作者: L Lijunhui 提交者: GitHub

Optimize nearest_interp backward (#39067)

* nearest_interp_bw init

* optimize kernel config

* optimize kernel config

* fix struct init

* optimize code

* rm duplicated struct
上级 539fb0d7
...@@ -210,32 +210,66 @@ __global__ void KeNearestNeighbor3DInterpFw( ...@@ -210,32 +210,66 @@ __global__ void KeNearestNeighbor3DInterpFw(
} }
} }
// Backward pass of 2-D nearest-neighbor interpolation for NCHW tensors.
// Expected launch layout: 3-D grid/block where x indexes the output width,
// y indexes the output height, and z strides over the fused N*C dimension.
// Each thread scatters one output-gradient element into the input gradient
// via an atomic add, because several output pixels can map to the same
// nearest input pixel.
template <typename T>
__global__ void KeNearestNeighborInterpNCHWBw(
    T* in, const size_t in_img_h, const size_t in_img_w, const T* out,
    const size_t out_img_h, const size_t out_img_w, const size_t nc,
    const float ratio_h, const float ratio_w, const bool align_corners) {
  const int out_img_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int out_img_idy = blockIdx.y * blockDim.y + threadIdx.y;

  // Threads that fall outside the output image contribute nothing; leave
  // early so they never touch memory (no __syncthreads below, so an early
  // return is safe).
  if (out_img_idx >= out_img_w || out_img_idy >= out_img_h) {
    return;
  }

  // Nearest source pixel: round to nearest when align_corners, truncate
  // otherwise (matches the forward kernel's sampling rule).
  const int in_img_idx = align_corners
                             ? static_cast<int>(ratio_w * out_img_idx + 0.5)
                             : static_cast<int>(ratio_w * out_img_idx);
  const int in_img_idy = align_corners
                             ? static_cast<int>(ratio_h * out_img_idy + 0.5)
                             : static_cast<int>(ratio_h * out_img_idy);

  // Grid-stride traversal of the fused batch*channel dimension: each step
  // advances both indices by one whole image plane per covered nc slot.
  // NOTE(review): index arithmetic is 32-bit, as upstream — may overflow for
  // extremely large tensors.
  const int nc_stride = blockDim.z * gridDim.z;
  const int in_index_stride = nc_stride * in_img_h * in_img_w;
  const int out_index_stride = nc_stride * out_img_h * out_img_w;

  int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
  int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
  int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;

  for (; nc_id < nc; nc_id += nc_stride) {
    // Multiple output positions can target the same input cell, so the
    // accumulation must be atomic.
    platform::CudaAtomicAdd(&in[in_index], out[out_index]);
    in_index += in_index_stride;
    out_index += out_index_stride;
  }
}
template <typename T> template <typename T>
__global__ void KeNearestNeighborInterpBw( __global__ void KeNearestNeighborInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h, const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w, const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w, const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) { const bool align_corners, FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w; int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x; int stride = blockDim.x * gridDim.x;
int in_img_size = in_img_h * in_img_w;
int out_img_size = out_img_h * out_img_w;
for (; tid < nthreads; tid += stride) { for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w; auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_w = tid % output_w; int out_id_h = out_id_divmod.val[0];
int in_img_size = input_w / num_channels; int out_id_w = out_id_divmod.val[1];
int out_img_size = output_w / num_channels;
int channel_id, out_img_idy, out_img_idx; int channel_id = divmods.channels_div.Divmod(tid).val[1];
if (data_layout == DataLayout::kNCHW) { auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
channel_id = out_id_w / out_img_size; int out_img_idy = outimg_id_divmod.val[0];
out_img_idy = (out_id_w % out_img_size) / out_img_w; int out_img_idx =
out_img_idx = tid % out_img_w; divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int in_img_idy = (align_corners) int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5) ? static_cast<int>(ratio_h * out_img_idy + 0.5)
...@@ -244,15 +278,10 @@ __global__ void KeNearestNeighborInterpBw( ...@@ -244,15 +278,10 @@ __global__ void KeNearestNeighborInterpBw(
? static_cast<int>(ratio_w * out_img_idx + 0.5) ? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx); : static_cast<int>(ratio_w * out_img_idx);
T* in_pos; T* in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
if (data_layout == DataLayout::kNCHW) { in_img_idx * num_channels + channel_id];
in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx]; const T out_pos = out[tid];
} else {
in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
const T out_pos = out[out_id_h * output_w + out_id_w];
platform::CudaAtomicAdd(in_pos, out_pos); platform::CudaAtomicAdd(in_pos, out_pos);
} }
} }
...@@ -1842,11 +1871,26 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, ...@@ -1842,11 +1871,26 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum); platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);
if ("nearest" == interp_method) { if ("nearest" == interp_method) {
KeNearestNeighborInterpBw< if (data_layout == DataLayout::kNCHW) {
T><<<config.block_per_grid, config.thread_per_block, 0, // get launch 3D config
ctx.cuda_device_context().stream()>>>( int nc = n * c;
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, platform::GpuLaunchConfig config_3d =
n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout); GetGpuLaunchConfig3D(ctx.cuda_device_context(), nc, out_h, out_w);
KeNearestNeighborInterpNCHWBw<
T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, nc,
ratio_h, ratio_w, align_corners);
} else {
int64_t cw = c * out_w;
auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
KeNearestNeighborInterpBw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w, align_corners,
interp_divmods);
}
} else if ("bilinear" == interp_method) { } else if ("bilinear" == interp_method) {
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0; const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false; bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册