From 1e6019436cac40935ddc18fd595eca130a1f9c2c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 25 May 2021 19:08:46 +0800 Subject: [PATCH] feat(dnn/cuda): add nhwc int4 pooling GitOrigin-RevId: 9cf14cde4e78d71bd8220bb6612e6acc5734d111 --- dnn/src/cuda/pooling/opr_impl.cpp | 29 ++++++ dnn/src/cuda/pooling/pooling2d_qint.cu | 117 ++++++++++++++++++++++++ dnn/src/cuda/pooling/pooling2d_qint.cuh | 5 + dnn/test/cuda/pooling.cpp | 34 +++++++ 4 files changed, 185 insertions(+) diff --git a/dnn/src/cuda/pooling/opr_impl.cpp b/dnn/src/cuda/pooling/opr_impl.cpp index f3755bf36..37ff5ee9c 100644 --- a/dnn/src/cuda/pooling/opr_impl.cpp +++ b/dnn/src/cuda/pooling/opr_impl.cpp @@ -203,6 +203,35 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, relayout_opr->exec(dst, sdst, {}); } return; + } else if (param().format == Format::NHWC && + (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm || + src.layout.dtype.enumv() == DTypeEnum::QuantizedS4)) { + megdnn_assert(src.layout.dtype.enumv() == dst.layout.dtype.enumv(), + "src and dst dtype must equal"); + pooling2d::Param kern_param; + size_t n = src.layout[0], hi = src.layout[1], wi = src.layout[2], + c = src.layout[3], ho = dst.layout[1], wo = dst.layout[2]; + size_t ph = param().pad_h, pw = param().pad_w; + size_t window_h = param().window_h, window_w = param().window_w; + size_t sh = param().stride_h, sw = param().stride_w; + kern_param.n = n, kern_param.c = c, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, + kern_param.ph = ph, kern_param.pw = pw, + kern_param.window_h = window_h, kern_param.window_w = window_w, + kern_param.sh = sh, kern_param.sw = sw; + bool uint_case = false; + int zero_point = 0; + if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { + uint_case = true; + zero_point = src.layout.dtype.param() + .zero_point; + } + auto&& stream = cuda_stream(handle()); + pooling2d::do_pooling2d_int4_nhwc( + (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param, + stream, static_cast(param().mode), uint_case, + zero_point); + return; } auto handle = cudnn_handle(this->handle()); setup_descs(src.layout, dst.layout); diff --git a/dnn/src/cuda/pooling/pooling2d_qint.cu b/dnn/src/cuda/pooling/pooling2d_qint.cu index 10a623a98..ddd970ea7 100644 --- a/dnn/src/cuda/pooling/pooling2d_qint.cu +++ b/dnn/src/cuda/pooling/pooling2d_qint.cu @@ -399,6 +399,59 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, *(reinterpret_cast(g_dst_ptr)) = res; } +template +__global__ void pooling2d_device_template_nhwc(const int8_t* __restrict__ src, + int8_t* __restrict__ dst, + Param param, int zero_point) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + using ldg_type = typename Pooler::feed_type; + static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); + static int constexpr ldg_width_bytes = sizeof(ldg_type); + MEGDNN_STATIC_ASSERT( + ldg_width == ldg_width_assert, + "pooling2d (NHWC) kernel must ldg_width == ldg_width_assert"); + const int c_packed = param.c / pack_size; + const int batch = tid / (param.ho * param.wo * c_packed); + const int batch_residual = tid - batch * param.ho * param.wo * c_packed; + const int oh = batch_residual / (param.wo * c_packed); + const int oh_residual = batch_residual - oh * param.wo * c_packed; + const int ow = oh_residual / c_packed; + const int ow_residual = oh_residual - ow * c_packed; + const int sec = ow_residual; + if (batch >= param.n || oh >= param.ho || ow >= param.wo) + return; + + const int in_batch_stride = + param.hi * param.wi * param.c * pack_byte / pack_size; + const int out_batch_stride = + param.ho * param.wo * param.c * pack_byte / pack_size; + const int w_stride = param.c * pack_byte / pack_size; + const int8_t* __restrict__ g_src_ptr = + src + (batch * in_batch_stride + sec * ldg_width_bytes); + int8_t* __restrict__ g_dst_ptr = + dst + (batch * out_batch_stride + (oh * param.wo + ow) * w_stride + + sec * ldg_width_bytes); + + Pooler pooler(param.window_h * param.window_w, zero_point); + pooler.init(); + for (int fh = 0; fh < param.window_h; fh++) { + uint32_t ih = oh * param.sh + fh - param.ph; + for (int fw = 0; fw < param.window_w; fw++) { + uint32_t iw = ow * param.sw + fw - param.pw; + if (ih < param.hi && iw < param.wi) { + const int8_t* __restrict__ cur_src_ptr = + g_src_ptr + (ih * param.wi + iw) * w_stride; + ldg_type sval = + __ldg(reinterpret_cast(cur_src_ptr)); + pooler.feed(sval); + } + } + } + ldg_type res = pooler.get_ans(); + *(reinterpret_cast(g_dst_ptr)) = res; +} + }; // namespace void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, @@ -588,4 +641,68 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64( after_kernel_launch(); } +void megdnn::cuda::pooling2d::do_pooling2d_int4_nhwc( + const int8_t* d_src, int8_t* d_dst, const Param& param, + cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { + using Mode = megdnn::param_enumv::Pooling::Mode; + void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, + int zero_point); + + megdnn_assert(param.c % 8 == 0); + constexpr int ldg_byte = 4; + constexpr int elem_per_byte = 2; + constexpr int ldg_width_assert = 1; + constexpr int pack_size = ldg_byte * elem_per_byte; + constexpr int pack_byte = pack_size / elem_per_byte; + constexpr int elem_per_thread = ldg_byte * elem_per_byte; + uint32_t vthreads = + param.n * param.c * param.ho * param.wo / elem_per_thread; + if (uint_case) { + switch (mode) { + case Mode::MAX: + kern = pooling2d_device_template_nhwc< + MaxPooler, pack_size, pack_byte, + ldg_width_assert>; + break; + case Mode::AVERAGE: + kern = pooling2d_device_template_nhwc< + MeanIncludeRoundedPooler, + pack_size, pack_byte, ldg_width_assert>; + break; + case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: + kern = pooling2d_device_template_nhwc< + MeanExcludeRoundedPooler, + pack_size, pack_byte, ldg_width_assert>; + break; + default: + megdnn_assert(false, "invalid pooling mode"); + } + + } else { + switch (mode) { + case Mode::MAX: + kern = pooling2d_device_template_nhwc< + MaxPooler, pack_size, pack_byte, + ldg_width_assert>; + break; + case Mode::AVERAGE: + kern = pooling2d_device_template_nhwc< + MeanIncludeRoundedPooler, + pack_size, pack_byte, ldg_width_assert>; + break; + case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: + kern = pooling2d_device_template_nhwc< + MeanExcludeRoundedPooler, + pack_size, pack_byte, ldg_width_assert>; + break; + default: + megdnn_assert(false, "invalid pooling mode"); + } + } + uint32_t nr_threads = query_blocksize_for_kernel(kern); + nr_threads = std::min(nr_threads, vthreads); + uint32_t nr_blocks = DIVUP(vthreads, nr_threads); + kern<<>>(d_src, d_dst, param, zero_point); + after_kernel_launch(); +} // vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/pooling/pooling2d_qint.cuh b/dnn/src/cuda/pooling/pooling2d_qint.cuh index 1518b4a84..5ad2ef6e8 100644 --- a/dnn/src/cuda/pooling/pooling2d_qint.cuh +++ b/dnn/src/cuda/pooling/pooling2d_qint.cuh @@ -40,6 +40,11 @@ void do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, int8_t* d_dst, uint32_t mode, bool uint_case = false, int zero_point = 0); +void do_pooling2d_int4_nhwc(const int8_t* d_src, int8_t* d_dst, + const Param& param, cudaStream_t stream, + uint32_t mode, bool uint_case = false, + int zero_point = 0); + } // namespace pooling2d } // namespace cuda } // namespace megdnn diff --git a/dnn/test/cuda/pooling.cpp b/dnn/test/cuda/pooling.cpp index 8d01e53fe..d7dc1faed 100644 --- a/dnn/test/cuda/pooling.cpp +++ b/dnn/test/cuda/pooling.cpp @@ -330,6 +330,40 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW64_U4) { checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}}); } +TEST_F(CUDA, POOLING_FORWARD_NHWC_Q4) { + require_compute_capability(7, 5); + using Param = param::Pooling; + Checker checker(handle_cuda()); + Param param{Param::Mode::MAX, 1, 1, 2, 2, 2, 2}; + UniformIntRNG int_rng{-8, 7}; + checker.set_dtype(0, dtype::QuantizedS4(1.f)); + param.format = Param::Format::NHWC; + checker.set_epsilon(1e-3).set_rng(0, &int_rng); + checker.set_param(param).exec({{2, 28, 28, 16}, {}}); + checker.set_param(param).exec({{2, 177, 233, 16}, {}}); + param.mode = Param::Mode::AVERAGE; + checker.set_param(param).exec({{3, 13, 28, 32}, {}}); + param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING; + checker.set_param(param).exec({{4, 29, 28, 64}, {}}); +} + +TEST_F(CUDA, POOLING_FORWARD_NHWC_U4) { + require_compute_capability(7, 5); + using Param = param::Pooling; + Checker checker(handle_cuda()); + Param param{Param::Mode::MAX, 1, 1, 2, 2, 2, 2}; + UniformIntRNG int_rng{0, 15}; + checker.set_dtype(0, dtype::Quantized4Asymm(1.f, 3)); + param.format = Param::Format::NHWC; + checker.set_epsilon(1e-3).set_rng(0, &int_rng); + checker.set_param(param).exec({{2, 28, 28, 16}, {}}); + checker.set_param(param).exec({{2, 177, 233, 16}, {}}); + param.mode = Param::Mode::AVERAGE; + checker.set_param(param).exec({{3, 13, 28, 32}, {}}); + param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING; + checker.set_param(param).exec({{4, 29, 28, 64}, {}}); +} + TEST_F(CUDA, POOLING_FORWARD_CHWN4) { require_compute_capability(6, 1); using Param = param::Pooling; -- GitLab