Commit 1e601943 authored by Megvii Engine Team

feat(dnn/cuda): add nhwc int4 pooling

GitOrigin-RevId: 9cf14cde4e78d71bd8220bb6612e6acc5734d111
Parent 0fb9cc41
@@ -203,6 +203,35 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
relayout_opr->exec(dst, sdst, {});
}
return;
} else if (param().format == Format::NHWC &&
(src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
src.layout.dtype.enumv() == DTypeEnum::QuantizedS4)) {
megdnn_assert(src.layout.dtype.enumv() == dst.layout.dtype.enumv(),
"src and dst dtype must equal");
pooling2d::Param kern_param;
size_t n = src.layout[0], hi = src.layout[1], wi = src.layout[2],
c = src.layout[3], ho = dst.layout[1], wo = dst.layout[2];
size_t ph = param().pad_h, pw = param().pad_w;
size_t window_h = param().window_h, window_w = param().window_w;
size_t sh = param().stride_h, sw = param().stride_w;
kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
kern_param.ph = ph, kern_param.pw = pw,
kern_param.window_h = window_h, kern_param.window_w = window_w,
kern_param.sh = sh, kern_param.sw = sw;
bool uint_case = false;
int zero_point = 0;
if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
uint_case = true;
zero_point = src.layout.dtype.param<dtype::Quantized4Asymm>()
.zero_point;
}
auto&& stream = cuda_stream(handle());
pooling2d::do_pooling2d_int4_nhwc(
(int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param,
stream, static_cast<uint32_t>(param().mode), uint_case,
zero_point);
return;
}
auto handle = cudnn_handle(this->handle());
setup_descs(src.layout, dst.layout);
......
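Before the kernel itself, it helps to spell out how the NHWC int4 data is addressed: two 4-bit values share one byte, so the channel row of a spatial position occupies c / 2 bytes, and the launcher below asserts c % 8 == 0 so that every load moves one aligned 4-byte pack of eight elements. A minimal host-side sketch of that offset computation follows; the helper name is illustrative only and not part of the patch.

// Illustrative only: byte offset of the 8-element int4 pack starting at
// channel `c_start` of position (n, h, w) in an NHWC int4 tensor.
// Mirrors w_stride = c * pack_byte / pack_size in the kernel below.
#include <cstddef>

static inline size_t nhwc_int4_pack_offset(size_t n, size_t h, size_t w,
                                           size_t c_start, size_t H, size_t W,
                                           size_t C) {
    const size_t row_bytes = C / 2;  // bytes per (n, h, w) position
    return ((n * H + h) * W + w) * row_bytes + c_start / 2;
}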
@@ -399,6 +399,59 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src,
*(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res;
}
template <typename Pooler, int pack_size, int pack_byte,
int ldg_width_assert = 4>
__global__ void pooling2d_device_template_nhwc(const int8_t* __restrict__ src,
int8_t* __restrict__ dst,
Param param, int zero_point) {
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
using ldg_type = typename Pooler::feed_type;
static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t);
static int constexpr ldg_width_bytes = sizeof(ldg_type);
    MEGDNN_STATIC_ASSERT(
            ldg_width == ldg_width_assert,
            "pooling2d (NHWC) kernel requires ldg_width == ldg_width_assert");
const int c_packed = param.c / pack_size;
const int batch = tid / (param.ho * param.wo * c_packed);
const int batch_residual = tid - batch * param.ho * param.wo * c_packed;
const int oh = batch_residual / (param.wo * c_packed);
const int oh_residual = batch_residual - oh * param.wo * c_packed;
const int ow = oh_residual / c_packed;
const int ow_residual = oh_residual - ow * c_packed;
const int sec = ow_residual;
if (batch >= param.n || oh >= param.ho || ow >= param.wo)
return;
const int in_batch_stride =
param.hi * param.wi * param.c * pack_byte / pack_size;
const int out_batch_stride =
param.ho * param.wo * param.c * pack_byte / pack_size;
const int w_stride = param.c * pack_byte / pack_size;
const int8_t* __restrict__ g_src_ptr =
src + (batch * in_batch_stride + sec * ldg_width_bytes);
int8_t* __restrict__ g_dst_ptr =
dst + (batch * out_batch_stride + (oh * param.wo + ow) * w_stride +
sec * ldg_width_bytes);
Pooler pooler(param.window_h * param.window_w, zero_point);
pooler.init();
    for (int fh = 0; fh < param.window_h; fh++) {
        // padded positions yield negative coordinates, which wrap to large
        // unsigned values here, so the single `ih < param.hi` / `iw < param.wi`
        // check below rejects both out-of-range and padded locations
        uint32_t ih = oh * param.sh + fh - param.ph;
for (int fw = 0; fw < param.window_w; fw++) {
uint32_t iw = ow * param.sw + fw - param.pw;
if (ih < param.hi && iw < param.wi) {
const int8_t* __restrict__ cur_src_ptr =
g_src_ptr + (ih * param.wi + iw) * w_stride;
ldg_type sval =
__ldg(reinterpret_cast<const ldg_type*>(cur_src_ptr));
pooler.feed(sval);
}
}
}
ldg_type res = pooler.get_ans();
*(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res;
}
}; // namespace
void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src,
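The kernel above assigns one thread per 8-element output pack and decomposes the linear thread id into (batch, oh, ow, sec), where sec selects the 8-channel pack within a row. A host-side mirror of that arithmetic can be handy when checking which pack a given tid writes; the struct and function names here are illustrative and not part of the patch.

// Illustrative host-side mirror of the kernel's index decomposition,
// with c_packed = c / pack_size and pack_size == 8 for the int4 NHWC path.
struct NhwcPackIndex {
    int batch, oh, ow, sec;
};

static inline NhwcPackIndex decompose_tid(int tid, int ho, int wo,
                                          int c_packed) {
    NhwcPackIndex idx;
    idx.batch = tid / (ho * wo * c_packed);
    int r = tid - idx.batch * ho * wo * c_packed;
    idx.oh = r / (wo * c_packed);
    r -= idx.oh * wo * c_packed;
    idx.ow = r / c_packed;
    idx.sec = r - idx.ow * c_packed;  // which 8-channel pack within the row
    return idx;
}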
@@ -588,4 +641,68 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(
after_kernel_launch();
}
void megdnn::cuda::pooling2d::do_pooling2d_int4_nhwc(
const int8_t* d_src, int8_t* d_dst, const Param& param,
cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) {
using Mode = megdnn::param_enumv::Pooling::Mode;
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param,
int zero_point);
megdnn_assert(param.c % 8 == 0);
constexpr int ldg_byte = 4;
constexpr int elem_per_byte = 2;
constexpr int ldg_width_assert = 1;
constexpr int pack_size = ldg_byte * elem_per_byte;
constexpr int pack_byte = pack_size / elem_per_byte;
constexpr int elem_per_thread = ldg_byte * elem_per_byte;
uint32_t vthreads =
param.n * param.c * param.ho * param.wo / elem_per_thread;
if (uint_case) {
switch (mode) {
case Mode::MAX:
kern = pooling2d_device_template_nhwc<
MaxPooler<dt_quint4, int32_t>, pack_size, pack_byte,
ldg_width_assert>;
break;
case Mode::AVERAGE:
kern = pooling2d_device_template_nhwc<
MeanIncludeRoundedPooler<dt_quint4, int32_t, int32_t>,
pack_size, pack_byte, ldg_width_assert>;
break;
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
kern = pooling2d_device_template_nhwc<
MeanExcludeRoundedPooler<dt_quint4, int32_t, int32_t>,
pack_size, pack_byte, ldg_width_assert>;
break;
default:
megdnn_assert(false, "invalid pooling mode");
}
} else {
switch (mode) {
case Mode::MAX:
kern = pooling2d_device_template_nhwc<
MaxPooler<dt_qint4, int32_t>, pack_size, pack_byte,
ldg_width_assert>;
break;
case Mode::AVERAGE:
kern = pooling2d_device_template_nhwc<
MeanIncludeRoundedPooler<dt_qint4, int32_t, int32_t>,
pack_size, pack_byte, ldg_width_assert>;
break;
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
kern = pooling2d_device_template_nhwc<
MeanExcludeRoundedPooler<dt_qint4, int32_t, int32_t>,
pack_size, pack_byte, ldg_width_assert>;
break;
default:
megdnn_assert(false, "invalid pooling mode");
}
}
uint32_t nr_threads = query_blocksize_for_kernel(kern);
nr_threads = std::min(nr_threads, vthreads);
uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param, zero_point);
after_kernel_launch();
}
// vim: syntax=cuda.doxygen
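To make the launch configuration concrete: each thread consumes elem_per_thread = 8 int4 values through a single 4-byte load. Assuming, purely for illustration, an NHWC output of shape {2, 14, 14, 16} and a block size of 256 from query_blocksize_for_kernel (the real value is device- and occupancy-dependent):

// vthreads   = 2 * 16 * 14 * 14 / 8 = 784   (one thread per 8-element pack)
// nr_threads = min(256, 784)        = 256
// nr_blocks  = DIVUP(784, 256)      = 4     (1024 threads launched in total)
// The 240 surplus threads of the last block exit early through the kernel's
// batch/oh/ow bounds check.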
@@ -40,6 +40,11 @@ void do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, int8_t* d_dst,
uint32_t mode, bool uint_case = false,
int zero_point = 0);
void do_pooling2d_int4_nhwc(const int8_t* d_src, int8_t* d_dst,
const Param& param, cudaStream_t stream,
uint32_t mode, bool uint_case = false,
int zero_point = 0);
} // namespace pooling2d
} // namespace cuda
} // namespace megdnn
......
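For completeness, a hedged sketch of calling the new entry point directly with the header above included. The Param fields follow the kern_param fill in PoolingForwardImpl::exec above; the device pointers, stream, wrapper name, and concrete shape are placeholders, and in practice the operator path is what sets this up.

// Illustrative only: max pooling over a {2, 28, 28, 16} QuantizedS4 NHWC
// input with 2x2 windows, stride 2 and no padding (output {2, 14, 14, 16}).
using namespace megdnn::cuda;

void run_nhwc_int4_max_pooling(const int8_t* d_src, int8_t* d_dst,
                               cudaStream_t stream) {
    pooling2d::Param p;
    p.n = 2, p.c = 16, p.hi = 28, p.wi = 28, p.ho = 14, p.wo = 14;
    p.ph = 0, p.pw = 0, p.window_h = 2, p.window_w = 2, p.sh = 2, p.sw = 2;
    pooling2d::do_pooling2d_int4_nhwc(
            d_src, d_dst, p, stream,
            static_cast<uint32_t>(megdnn::param_enumv::Pooling::Mode::MAX),
            /*uint_case=*/false, /*zero_point=*/0);
}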
@@ -330,6 +330,40 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW64_U4) {
checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}});
}
TEST_F(CUDA, POOLING_FORWARD_NHWC_Q4) {
require_compute_capability(7, 5);
using Param = param::Pooling;
Checker<Pooling> checker(handle_cuda());
Param param{Param::Mode::MAX, 1, 1, 2, 2, 2, 2};
UniformIntRNG int_rng{-8, 7};
checker.set_dtype(0, dtype::QuantizedS4(1.f));
param.format = Param::Format::NHWC;
checker.set_epsilon(1e-3).set_rng(0, &int_rng);
checker.set_param(param).exec({{2, 28, 28, 16}, {}});
checker.set_param(param).exec({{2, 177, 233, 16}, {}});
param.mode = Param::Mode::AVERAGE;
checker.set_param(param).exec({{3, 13, 28, 32}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{4, 29, 28, 64}, {}});
}
TEST_F(CUDA, POOLING_FORWARD_NHWC_U4) {
require_compute_capability(7, 5);
using Param = param::Pooling;
Checker<Pooling> checker(handle_cuda());
Param param{Param::Mode::MAX, 1, 1, 2, 2, 2, 2};
UniformIntRNG int_rng{0, 15};
checker.set_dtype(0, dtype::Quantized4Asymm(1.f, 3));
param.format = Param::Format::NHWC;
checker.set_epsilon(1e-3).set_rng(0, &int_rng);
checker.set_param(param).exec({{2, 28, 28, 16}, {}});
checker.set_param(param).exec({{2, 177, 233, 16}, {}});
param.mode = Param::Mode::AVERAGE;
checker.set_param(param).exec({{3, 13, 28, 32}, {}});
param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
checker.set_param(param).exec({{4, 29, 28, 64}, {}});
}
TEST_F(CUDA, POOLING_FORWARD_CHWN4) {
require_compute_capability(6, 1);
using Param = param::Pooling;
......
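Both new cases are gated on require_compute_capability(7, 5), i.e. they only run on Turing (SM 7.5) or newer GPUs. They can be run in isolation with a gtest filter; the binary name below is the usual megdnn test target and may differ per build setup:

./megdnn_test --gtest_filter='CUDA.POOLING_FORWARD_NHWC_*'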