提交 7f024072 编写于 作者: M Megvii Engine Team

perf(dnn): speed up pad kernel

GitOrigin-RevId: 33db700687c04ee4d813e3b3fa24b6e35de4036c
上级 2886245b
......@@ -7,6 +7,10 @@
namespace megdnn {
namespace cuda {
bool is_conv_pad(size_t offsets[MEGDNN_MAX_NDIM * 2]) {
    // True when padding is applied only to the spatial dims, i.e. the batch
    // (dim 0) and channel (dim 1) dimensions receive no padding at all. The
    // 4D fast path (pad4d_forward_proxy) assumes exactly this layout.
    // Offsets are stored as interleaved (front, back) pairs per dimension:
    // [0]/[1] -> dim 0, [2]/[3] -> dim 1. Both the front AND the back offset
    // of each of the first two dims must be zero — checking only the front
    // offsets would let a back-padded N/C tensor slip into the fast path and
    // produce wrong results.
    return offsets[0] == 0 && offsets[1] == 0 && offsets[2] == 0 &&
           offsets[3] == 0;
}
void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
forward_check_exec(src.layout, dst.layout);
SmallVector<size_t> offsets(get_offsets());
......@@ -16,6 +20,18 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
offsets[5], offsets[6], offsets[7], offsets[8], offsets[9],
offsets[10], offsets[11], offsets[12], offsets[13]};
auto stream = cuda_stream(this->handle());
if (src.layout.ndim == 4 && is_conv_pad(param_offsets)) {
#define cb(DType) \
if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
padding::pad4d_forward_proxy<ctype>( \
src, dst, param_offsets, uint32_t(param().padding_mode), \
param().padding_val, stream); \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
} else {
#define cb(DType) \
if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
using ctype = typename DTypeTrait<DType>::ctype; \
......@@ -26,6 +42,7 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
}
}
void PaddingBackwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
......
#include <thrust/pair.h>
#include <algorithm>
#include <cstring>
#include <iostream>
......@@ -94,6 +95,55 @@ __global__ void paddingReflect_kernel(
}
}
// Maps one output pixel of a 2D reflection pad to its source input pixel.
//
// Computes, for the flattened output coordinate `output_xy` within a single
// (plane, batch) slice, the pair (input linear index, output linear index)
// into the full tensors. The slice is selected by blockIdx.y/blockIdx.z,
// offset by y_shift/z_shift so that grids larger than the 65535 y/z limit
// can be processed in chunks by the caller.
//
// The input_x / input_y expressions are a branchless closed form of
// reflection ("mirror without repeating the border"): for an output
// coordinate left of the pad the index walks back into the input, inside
// the valid region it is the identity shift, and past the right edge it
// reflects off input_dim - 1. NOTE(review): this closed form assumes
// pad < input_dim in each direction (standard reflect-pad precondition) —
// presumably validated by the layout check in the caller; confirm.
__device__ inline thrust::pair<int64_t, int64_t> get_index_mapping2d(
        int64_t input_dim_x, int64_t input_dim_y, int64_t output_dim_x,
        int64_t output_dim_y, int64_t pad_l, int64_t pad_t, int64_t output_xy,
        int y_shift, int z_shift, int nplane) {
    // 3D grid of 1D blocks
    // Base offset of this (plane, batch) slice in the flat input/output.
    auto input_offset = ((blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane) *
                        input_dim_x * input_dim_y;
    auto output_offset = ((blockIdx.y + y_shift) + (blockIdx.z + z_shift) * nplane) *
                         output_dim_x * output_dim_y;
    // Unflatten the per-slice output coordinate.
    auto output_x = output_xy % output_dim_x;
    auto output_y = output_xy / output_dim_x;
    // Start of the copied region in input/output space (handles the
    // negative-pad i.e. cropping case as well as positive padding).
    auto i_start_x = ::max(int64_t(0), -pad_l);
    auto i_start_y = ::max(int64_t(0), -pad_t);
    auto o_start_x = ::max(int64_t(0), pad_l);
    auto o_start_y = ::max(int64_t(0), pad_t);
    // Branchless reflection: |x - pad| - |x - (dim + pad - 1)| folds the
    // three cases (left pad / interior / right pad) into one expression.
    auto input_x = ::abs(output_x - pad_l) -
                   ::abs(output_x - (input_dim_x + pad_l - 1)) - output_x + 2 * pad_l +
                   input_dim_x - 1 - o_start_x + i_start_x;
    auto input_y = ::abs(output_y - pad_t) -
                   ::abs(output_y - (input_dim_y + pad_t - 1)) - output_y + 2 * pad_t +
                   input_dim_y - 1 - o_start_y + i_start_y;
    return thrust::make_pair<int64_t, int64_t>(
            input_offset + input_y * input_dim_x + input_x,
            output_offset + output_y * output_dim_x + output_x);
}
// Reflection padding for one (plane, batch) chunk of a 4D tensor.
//
// Launch layout: 1D blocks over the flattened output plane (gridDim.x),
// gridDim.y selects the plane (offset by y_shift) and gridDim.z the batch
// (offset by z_shift), so the host can tile launches past the 65535 y/z
// grid limit. One thread copies exactly one output element.
template <typename T>
__global__ void reflection_pad4d_kernel(
        const T* const input, T* const output, int64_t input_dim_x, int64_t input_dim_y,
        int pad_t, int pad_b, int pad_l, int pad_r, int y_shift, int z_shift,
        int nplane) {
    auto out_xy = threadIdx.x + blockIdx.x * blockDim.x;
    auto out_w = input_dim_x + pad_l + pad_r;
    auto out_h = input_dim_y + pad_t + pad_b;
    // Tail guard: the last block may extend past the output plane.
    if (out_xy >= out_w * out_h)
        return;
    auto mapping = get_index_mapping2d(
            input_dim_x, input_dim_y, out_w, out_h, pad_l, pad_t, out_xy,
            y_shift, z_shift, nplane);
    output[mapping.second] = input[mapping.first];
}
template <typename T>
__global__ void paddingConstBackward_kernel(
const size_t ndim, const size_t total_in_nr, const T* const src, T* const dst,
......@@ -198,6 +248,44 @@ void padding_forward_proxy(
after_kernel_launch();
}
// Forward padding fast path for 4D (NCHW) tensors.
//
// For REFLECT mode, launches reflection_pad4d_kernel with a 3D grid:
// x covers the flattened output plane (H_out * W_out), y the channel
// planes, z the batches. gridDim.y/z are hardware-limited to 65535, so
// both are tiled in chunks with the chunk origin passed to the kernel as
// y_shift/z_shift. All other modes fall back to the generic N-D proxy.
//
// Preconditions (guaranteed by the caller's is_conv_pad check): only the
// H and W dims are padded; offsets layout is (front, back) pairs per dim,
// so offsets[4..7] are top/bottom/left/right.
template <typename T>
void pad4d_forward_proxy(
        const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2],
        uint32_t mode, const float_t padding_val, cudaStream_t stream) {
    if (mode == param_enumv::Padding::PaddingMode::REFLECT) {
        // Named launch constants so the block size and the grid-size divisor
        // cannot silently drift apart (they must be the same value).
        constexpr size_t max_threads = 256;    // threads per block (x dim)
        constexpr size_t max_grid_yz = 65535;  // HW limit of gridDim.y / gridDim.z
        size_t pad_t = offsets[4];  // H front (top)
        size_t pad_b = offsets[5];  // H back  (bottom)
        size_t pad_l = offsets[6];  // W front (left)
        size_t pad_r = offsets[7];  // W back  (right)
        size_t nbatch = src.layout.shape[0];
        size_t nplane = src.layout.shape[1];
        size_t input_h = src.layout.shape[2];
        size_t input_w = src.layout.shape[3];
        size_t output_plane_size = dst.layout.shape[2] * dst.layout.shape[3];
        dim3 block_size(
                output_plane_size > max_threads ? max_threads : output_plane_size);
        for (size_t block_y = 0; block_y < nplane; block_y += max_grid_yz) {
            size_t block_y_size = std::min(nplane - block_y, max_grid_yz);
            for (size_t block_z = 0; block_z < nbatch; block_z += max_grid_yz) {
                size_t block_z_size = std::min(nbatch - block_z, max_grid_yz);
                dim3 grid_size(
                        DIVUP(output_plane_size, max_threads), block_y_size,
                        block_z_size);
                reflection_pad4d_kernel<<<grid_size, block_size, 0, stream>>>(
                        src.ptr<T>(), dst.ptr<T>(), input_w, input_h, pad_t, pad_b,
                        pad_l, pad_r, block_y, block_z, nplane);
            }
        }
        after_kernel_launch();
    } else {
        // CONSTANT / REPLICATE (and anything else) use the generic N-D path.
        padding_forward_proxy<T>(src, dst, offsets, mode, padding_val, stream);
    }
}
template <typename T>
void padding_backward_proxy(
const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2],
......@@ -250,6 +338,17 @@ MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
#undef INST
// Explicit instantiations of pad4d_forward_proxy for every computing and
// quantized dtype, so the template definitions above are emitted in this
// translation unit for use from the opr implementation.
#define INST(T)                                                \
    template void pad4d_forward_proxy<T>(                      \
            const TensorND& src, const TensorND& dst,          \
            size_t offsets[MEGDNN_MAX_NDIM * 2], uint32_t mode, \
            const float_t padding_val, cudaStream_t stream);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
#undef INST
#define INST(T) \
template void padding_backward_proxy<T>( \
const TensorND& src, const TensorND& dst, \
......
......@@ -13,6 +13,11 @@ void padding_forward_proxy(
const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2],
uint32_t mode, const float_t padding_val, cudaStream_t stream);
// Fast-path forward padding for 4D (NCHW) tensors: dispatches REFLECT mode
// to a dedicated per-plane kernel and falls back to the generic proxy for
// all other modes. Offsets are (front, back) pairs per dimension; only the
// spatial dims (offsets[4..7]) may be non-zero.
template <typename T>
void pad4d_forward_proxy(
        const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2],
        uint32_t mode, const float_t padding_val, cudaStream_t stream);
template <typename T>
void padding_backward_proxy(
const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册