diff --git a/dnn/src/cuda/pooling/opr_impl.cpp b/dnn/src/cuda/pooling/opr_impl.cpp index e88dfc61382607a3993885c2661e18e9f4ce1c29..f3755bf367b0d41fe5372c9bd6dca296a1e5c598 100644 --- a/dnn/src/cuda/pooling/opr_impl.cpp +++ b/dnn/src/cuda/pooling/opr_impl.cpp @@ -40,8 +40,7 @@ void get_inner_layout(const TensorLayout& src, const TensorLayout& dst, Handle* handle, PoolingForwardImpl::Param::Format format) { bool is_nchw = format == PoolingForwardImpl::Param::Format::NCHW; - if (src.dtype.enumv() == DTypeEnum::QuantizedS4 && - dst.dtype.enumv() == DTypeEnum::QuantizedS4 && is_nchw) { + if (is_nchw) { auto relayout_opr = handle->create_operator(); deduce_reformat_layout(relayout_opr, src, inner_src, RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1); @@ -66,8 +65,11 @@ WorkspaceBundle PoolingForwardImpl::get_workspace_bundle( TensorLayout fsrc = src; TensorLayout fdst = dst; bool is_nchw = param().format == Param::Format::NCHW; - if (src.dtype.enumv() == DTypeEnum::QuantizedS4 && - dst.dtype.enumv() == DTypeEnum::QuantizedS4 && is_nchw) { + if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 || + src.dtype.enumv() == DTypeEnum::Quantized4Asymm) && + (dst.dtype.enumv() == DTypeEnum::QuantizedS4 || + dst.dtype.enumv() == DTypeEnum::Quantized4Asymm) && + is_nchw) { get_inner_layout(src, dst, fsrc, fdst, handle(), param().format); sizes.push_back(fsrc.span().dist_byte()); sizes.push_back(fdst.span().dist_byte()); @@ -97,8 +99,11 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, bool is_nchw = param().format == Param::Format::NCHW; if (ssrc.layout.dtype.enumv() == DTypeTrait::enumv) { ctypecvt.src_to_comp_type(ssrc, src).src_to_comp_type(sdst, dst); - } else if (ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && - sdst.layout.dtype.enumv() == DTypeEnum::QuantizedS4 && is_nchw) { + } else if ((ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 || + ssrc.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) && + (sdst.layout.dtype.enumv() == DTypeEnum::QuantizedS4 || + sdst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) && + is_nchw) { auto handle_ptr = handle(); get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout, handle_ptr, param().format); @@ -166,8 +171,6 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, kern_param, stream, static_cast(param().mode)); } else if (param().format == Format::NCHW64 || inner_format == Format::NCHW64) { - megdnn_assert(src.layout.dtype.enumv() == DTypeEnum::QuantizedS4, - "but %s", src.layout.dtype.name()); pooling2d::Param kern_param; size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3], c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3]; @@ -180,16 +183,24 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, kern_param.ph = ph, kern_param.pw = pw, kern_param.window_h = window_h, kern_param.window_w = window_w, kern_param.sh = sh, kern_param.sw = sw; + bool uint_case = false; + int zero_point = 0; + if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { + uint_case = true; + zero_point = src.layout.dtype.param() + .zero_point; + } auto&& stream = cuda_stream(handle()); pooling2d::do_pooling2d_int4_ncdiv64hw64( (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param, - stream, static_cast(param().mode)); - if (sdst.layout.ndim == 4) { - auto relayout_opr = handle()->create_operator(); - RelayoutFormat::Param trans_param; - trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW; - relayout_opr->param() = trans_param; - relayout_opr->exec(dst, sdst,{}); + stream, static_cast(param().mode), uint_case, + zero_point); + if (sdst.layout.ndim == 4) { + auto relayout_opr = handle()->create_operator(); + RelayoutFormat::Param trans_param; + trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW; + relayout_opr->param() = trans_param; + relayout_opr->exec(dst, sdst, {}); } return; } diff --git a/dnn/src/cuda/pooling/pooling2d_qint.cu b/dnn/src/cuda/pooling/pooling2d_qint.cu index 5a648951fb0857c22ce499797e7ef7d034e432b2..10a623a986ab99157cfd3dde1988a246f7ff7d15 100644 --- a/dnn/src/cuda/pooling/pooling2d_qint.cu +++ b/dnn/src/cuda/pooling/pooling2d_qint.cu @@ -29,53 +29,51 @@ __device__ __forceinline__ int pack_int8_to_int8x4(int8_t x, int8_t y, int8_t z, return ix; } -template +template __device__ __forceinline__ OutDtype pack_int8(int8_t (&x)[regs]); template <> -__device__ __forceinline__ int pack_int8<4, int8_t, int>(int8_t (&x)[4]) { +__device__ __forceinline__ int pack_int8<4, 8, int>(int8_t (&x)[4]) { return pack_int8_to_int8x4(x[0], x[1], x[2], x[3]); } template <> -__device__ __forceinline__ int2 pack_int8<8, int8_t, int2>(int8_t (&x)[8]) { +__device__ __forceinline__ int2 pack_int8<8, 8, int2>(int8_t (&x)[8]) { int8_t x0[4]{x[0], x[1], x[2], x[3]}; int8_t x1[4]{x[4], x[5], x[6], x[7]}; - return ::make_int2(pack_int8<4, int8_t, int>(x0), - pack_int8<4, int8_t, int>(x1)); + return ::make_int2(pack_int8<4, 8, int>(x0), pack_int8<4, 8, int>(x1)); } template <> -__device__ __forceinline__ int4 pack_int8<16, int8_t, int4>(int8_t (&x)[16]) { +__device__ __forceinline__ int4 pack_int8<16, 8, int4>(int8_t (&x)[16]) { int8_t x0[4]{x[0], x[1], x[2], x[3]}; int8_t x1[4]{x[4], x[5], x[6], x[7]}; int8_t x2[4]{x[8], x[9], x[10], x[11]}; int8_t x3[4]{x[12], x[13], x[14], x[15]}; - return ::make_int4( - pack_int8<4, int8_t, int>(x0), pack_int8<4, int8_t, int>(x1), - pack_int8<4, int8_t, int>(x2), pack_int8<4, int8_t, int>(x3)); + return ::make_int4(pack_int8<4, 8, int>(x0), pack_int8<4, 8, int>(x1), + pack_int8<4, 8, int>(x2), pack_int8<4, 8, int>(x3)); } __device__ __forceinline__ int8_t pack_int8_to_int4x2(int8_t x0, int8_t x1) { return (x0 & 0xf) | (x1 << 4); } template <> -__device__ __forceinline__ int pack_int8<8, dt_qint4, int>(int8_t (&x)[8]) { +__device__ __forceinline__ int pack_int8<8, 4, int>(int8_t (&x)[8]) { int8_t x0 = pack_int8_to_int4x2(x[0], x[1]); int8_t x1 = pack_int8_to_int4x2(x[2], x[3]); int8_t x2 = pack_int8_to_int4x2(x[4], x[5]); int8_t x3 = pack_int8_to_int4x2(x[6], x[7]); return pack_int8_to_int8x4(x0, x1, x2, x3); } + template <> -__device__ __forceinline__ int4 pack_int8<32, dt_qint4, int4>(int8_t (&x)[32]) { +__device__ __forceinline__ int4 pack_int8<32, 4, int4>(int8_t (&x)[32]) { int8_t x0[8]{x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]}; int8_t x1[8]{x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]}; int8_t x2[8]{x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23]}; int8_t x3[8]{x[24], x[25], x[26], x[27], x[28], x[29], x[30], x[31]}; - return ::make_int4( - pack_int8<8, dt_qint4, int>(x0), pack_int8<8, dt_qint4, int>(x1), - pack_int8<8, dt_qint4, int>(x2), pack_int8<8, dt_qint4, int>(x3)); + return ::make_int4(pack_int8<8, 4, int>(x0), pack_int8<8, 4, int>(x1), + pack_int8<8, 4, int>(x2), pack_int8<8, 4, int>(x3)); } template @@ -88,6 +86,7 @@ struct TypeTrait { static constexpr int8_t min = -128; static constexpr int elem_per_32bit = 32 / bit_width; static constexpr int shift_fix_sign = 0; + static constexpr bool need_zero_pad = false; }; template <> @@ -97,6 +96,16 @@ struct TypeTrait { static constexpr int8_t min = -8; static constexpr int elem_per_32bit = 32 / bit_width; static constexpr int shift_fix_sign = 4; + static constexpr bool need_zero_pad = false; +}; +template <> +struct TypeTrait { + static constexpr int bit_width = 4; + static constexpr int mask = 0xf; + static constexpr int8_t min = 0; + static constexpr int elem_per_32bit = 32 / bit_width; + static constexpr int shift_fix_sign = 0; + static constexpr bool need_zero_pad = true; }; template @@ -108,7 +117,7 @@ struct MaxPooler { static constexpr int shift_fix_sign = TypeTrait::shift_fix_sign; int8_t res[nr_results]; - __device__ MaxPooler(int) {} + __device__ MaxPooler(int, int) {} __device__ __forceinline__ void init() { #pragma unroll for (int i = 0; i < nr_results; ++i) { @@ -137,7 +146,7 @@ struct MaxPooler { } __device__ __forceinline__ feed_type get_ans() { feed_type ans; - ans = pack_int8(res); + ans = pack_int8(res); return ans; } }; @@ -149,21 +158,27 @@ struct MeanIncludeRoundedPooler { static constexpr int nr_results = sizeof(feed_type) * 8 / bit_width; static constexpr int elem_per_32bit = TypeTrait::elem_per_32bit; static constexpr int shift_fix_sign = TypeTrait::shift_fix_sign; + static constexpr bool need_zero_pad = TypeTrait::need_zero_pad; int32_t res[nr_results]; const int count; const float fi_count; + int real_fi_count; + const int zero_pad; - __device__ MeanIncludeRoundedPooler(int count) - : count{count}, fi_count{1.f / count} {} + __device__ MeanIncludeRoundedPooler(int count, int zero_point) + : count{count}, fi_count{1.f / count}, zero_pad{zero_point} {} __device__ __forceinline__ void init() { #pragma unroll for (int i = 0; i < nr_results; ++i) { res[i] = 0; } + if (need_zero_pad) { + real_fi_count = 0; + } } - __device__ __forceinline__ void feed(int x, int idx = 0) { + __device__ __forceinline__ void feed(int x, int idx) { constexpr int unroll_n = sizeof(int) * 8 / bit_width; #pragma unroll for (int i = 0; i < unroll_n; i++) { @@ -173,15 +188,27 @@ struct MeanIncludeRoundedPooler { res[idx + i] += static_cast(temp); } } + __device__ __forceinline__ void feed(int x) { + feed(x, 0); + if (need_zero_pad) { + real_fi_count++; + } + } __device__ __forceinline__ void feed(int2 x) { feed(x.x, 0 * elem_per_32bit); feed(x.y, 1 * elem_per_32bit); + if (need_zero_pad) { + real_fi_count++; + } } __device__ __forceinline__ void feed(int4 x) { feed(x.x, 0 * elem_per_32bit); feed(x.y, 1 * elem_per_32bit); feed(x.z, 2 * elem_per_32bit); feed(x.w, 3 * elem_per_32bit); + if (need_zero_pad) { + real_fi_count++; + } } __device__ __forceinline__ feed_type get_ans() { feed_type ans; @@ -189,13 +216,18 @@ struct MeanIncludeRoundedPooler { #pragma unroll for (int i = 0; i < nr_results; i++) { float f32_res = roundf(static_cast(res[i]) * fi_count); + if (need_zero_pad) { + f32_res = roundf((static_cast(res[i]) + + (count - real_fi_count) * zero_pad) * + fi_count); + } int i8_res; asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(i8_res) : "f"(f32_res)); out_res[i] = i8_res; } - ans = pack_int8(out_res); + ans = pack_int8(out_res); return ans; } }; @@ -209,7 +241,7 @@ struct MeanExcludeRoundedPooler { static constexpr int shift_fix_sign = TypeTrait::shift_fix_sign; int32_t res[nr_results]; int count; - __device__ MeanExcludeRoundedPooler(int) {} + __device__ MeanExcludeRoundedPooler(int, int) {} __device__ __forceinline__ void init() { #pragma unroll @@ -257,7 +289,7 @@ struct MeanExcludeRoundedPooler { : "f"(f32_res)); out_res[i] = i8_res; } - ans = pack_int8(out_res); + ans = pack_int8(out_res); return ans; } }; @@ -290,7 +322,7 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4( packed_ch * output_pixels * npack + (ho * param.wo + wo) * npack; - Pooler pooler(param.window_h * param.window_w); + Pooler pooler(param.window_h * param.window_w, 0); pooler.init(); for (int fh = 0; fh < param.window_h; fh++) { uint32_t ih = ho * param.sh + fh - param.ph; @@ -313,7 +345,7 @@ template __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, int8_t* __restrict__ dst, - Param param) { + Param param, int zero_point) { const int tid = blockIdx.x * blockDim.x + threadIdx.x; using ldg_type = typename Pooler::feed_type; static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); @@ -348,7 +380,7 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, dst + (batch * out_batch_stride + oc * out_channel_stride + (oh * param.wo + ow) * pack_byte + sec * ldg_width_bytes); - Pooler pooler(param.window_h * param.window_w); + Pooler pooler(param.window_h * param.window_w, zero_point); pooler.init(); for (int fh = 0; fh < param.window_h; fh++) { uint32_t ih = oh * param.sh + fh - param.ph; @@ -418,13 +450,12 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, after_kernel_launch(); } -void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, - int8_t* d_dst, - const Param& param, - cudaStream_t stream, - uint32_t mode) { +void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4( + const int8_t* d_src, int8_t* d_dst, const Param& param, + cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { using Mode = megdnn::param_enumv::Pooling::Mode; - void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); + void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, + int zero_point); constexpr int ldg_byte = 4; constexpr int elem_per_byte = 1; constexpr int pack_size = 4; @@ -455,17 +486,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, uint32_t nr_threads = query_blocksize_for_kernel(kern); nr_threads = std::min(nr_threads, vthreads); uint32_t nr_blocks = DIVUP(vthreads, nr_threads); - kern<<>>(d_src, d_dst, param); + kern<<>>(d_src, d_dst, param, zero_point); after_kernel_launch(); } -void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, - int8_t* d_dst, - const Param& param, - cudaStream_t stream, - uint32_t mode) { +void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32( + const int8_t* d_src, int8_t* d_dst, const Param& param, + cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { using Mode = megdnn::param_enumv::Pooling::Mode; - void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); + void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, + int zero_point); constexpr int ldg_byte = 16; constexpr int elem_per_byte = 1; constexpr int pack_size = 32; @@ -494,17 +524,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, uint32_t nr_threads = query_blocksize_for_kernel(kern); nr_threads = std::min(nr_threads, vthreads); uint32_t nr_blocks = DIVUP(vthreads, nr_threads); - kern<<>>(d_src, d_dst, param); + kern<<>>(d_src, d_dst, param, zero_point); after_kernel_launch(); } -void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, - int8_t* d_dst, - const Param& param, - cudaStream_t stream, - uint32_t mode) { +void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64( + const int8_t* d_src, int8_t* d_dst, const Param& param, + cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { using Mode = megdnn::param_enumv::Pooling::Mode; - void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); + void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, + int zero_point); constexpr int ldg_byte = 16; constexpr int elem_per_byte = 2; constexpr int pack_size = 64; @@ -512,28 +541,50 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, constexpr int elem_per_thread = ldg_byte * elem_per_byte; uint32_t vthreads = param.n * param.c * param.ho * param.wo / elem_per_thread; - switch (mode) { - case Mode::MAX: - kern = pooling2d_device_template_nchwc, - pack_size, pack_byte>; - break; - case Mode::AVERAGE: - kern = pooling2d_device_template_nchwc< - MeanIncludeRoundedPooler, - pack_size, pack_byte>; - break; - case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: - kern = pooling2d_device_template_nchwc< - MeanExcludeRoundedPooler, - pack_size, pack_byte>; - break; - default: - megdnn_assert(false, "invalid pooling mode"); + if (uint_case) { + switch (mode) { + case Mode::MAX: + kern = pooling2d_device_template_nchwc< + MaxPooler, pack_size, pack_byte>; + break; + case Mode::AVERAGE: + kern = pooling2d_device_template_nchwc< + MeanIncludeRoundedPooler, + pack_size, pack_byte>; + break; + case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: + kern = pooling2d_device_template_nchwc< + MeanExcludeRoundedPooler, + pack_size, pack_byte>; + break; + default: + megdnn_assert(false, "invalid pooling mode"); + } + + } else { + switch (mode) { + case Mode::MAX: + kern = pooling2d_device_template_nchwc, + pack_size, pack_byte>; + break; + case Mode::AVERAGE: + kern = pooling2d_device_template_nchwc< + MeanIncludeRoundedPooler, + pack_size, pack_byte>; + break; + case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: + kern = pooling2d_device_template_nchwc< + MeanExcludeRoundedPooler, + pack_size, pack_byte>; + break; + default: + megdnn_assert(false, "invalid pooling mode"); + } } uint32_t nr_threads = query_blocksize_for_kernel(kern); nr_threads = std::min(nr_threads, vthreads); uint32_t nr_blocks = DIVUP(vthreads, nr_threads); - kern<<>>(d_src, d_dst, param); + kern<<>>(d_src, d_dst, param, zero_point); after_kernel_launch(); } diff --git a/dnn/src/cuda/pooling/pooling2d_qint.cuh b/dnn/src/cuda/pooling/pooling2d_qint.cuh index ed24e526e2c817193cdac320e0d0a1dad10b1f7a..1518b4a847b30afc100d827b88b0a0c58e1a2f66 100644 --- a/dnn/src/cuda/pooling/pooling2d_qint.cuh +++ b/dnn/src/cuda/pooling/pooling2d_qint.cuh @@ -27,15 +27,18 @@ void do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, - uint32_t mode); + uint32_t mode, bool uint_case = false, + int zero_point = 0); void do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, - uint32_t mode); + uint32_t mode, bool uint_case = false, + int zero_point = 0); void do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, - uint32_t mode); + uint32_t mode, bool uint_case = false, + int zero_point = 0); } // namespace pooling2d } // namespace cuda diff --git a/dnn/test/cuda/pooling.cpp b/dnn/test/cuda/pooling.cpp index 033be709b1bbed141a38a3f9146b50c216acd7c8..8d01e53fe7d104471bf1e9a05f5f79173fcb6a42 100644 --- a/dnn/test/cuda/pooling.cpp +++ b/dnn/test/cuda/pooling.cpp @@ -254,6 +254,13 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW_Q4) { checker.set_param(param).exec({{20, 96, 22, 33}, {}}); param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING; checker.set_param(param).exec({{20, 24, 22, 33}, {}}); + checker.set_dtype(0, dtype::Quantized4Asymm(3.1415926f, 3)); + param.format = Param::Format::NCHW; + checker.set_param(param).exec({{20, 64, 22, 33}, {}}); + param.mode = Param::Mode::AVERAGE; + checker.set_param(param).exec({{20, 96, 22, 33}, {}}); + param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING; + checker.set_param(param).exec({{20, 24, 22, 33}, {}}); } TEST_F(CUDA, POOLING_FORWARD_NCHW4) { @@ -291,20 +298,36 @@ TEST_F(CUDA, POOLING_FORWARD_NCHW32) { } #endif -TEST_F(CUDA, POOLING_FORWARD_NCHW64) { +TEST_F(CUDA, POOLING_FORWARD_NCHW64_Q4) { require_compute_capability(7, 5); using Param = param::Pooling; Checker checker(handle_cuda()); - Param param{Param::Mode::MAX, 0, 0, 2, 2, 2, 2}; + Param param{Param::Mode::MAX, 1, 1, 2, 2, 2, 2}; UniformIntRNG int_rng{-8, 7}; checker.set_dtype(0, dtype::QuantizedS4(1.f)); param.format = Param::Format::NCHW64; checker.set_epsilon(1e-3).set_rng(0, &int_rng); - checker.set_param(param).exec({{64, 8, 28, 28, 64}, {}}); + checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}}); param.mode = Param::Mode::AVERAGE; - checker.set_param(param).exec({{64, 8, 28, 28, 64}, {}}); + checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}}); param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING; - checker.set_param(param).exec({{64, 8, 28, 28, 64}, {}}); + checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}}); +} + +TEST_F(CUDA, POOLING_FORWARD_NCHW64_U4) { + require_compute_capability(7, 5); + using Param = param::Pooling; + Checker checker(handle_cuda()); + Param param{Param::Mode::MAX, 1, 1, 2, 2, 2, 2}; + UniformIntRNG int_rng{0, 15}; + checker.set_dtype(0, dtype::Quantized4Asymm(1.f, 3)); + param.format = Param::Format::NCHW64; + checker.set_epsilon(1e-3).set_rng(0, &int_rng); + checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}}); + param.mode = Param::Mode::AVERAGE; + checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}}); + param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING; + checker.set_param(param).exec({{4, 8, 28, 28, 64}, {}}); } TEST_F(CUDA, POOLING_FORWARD_CHWN4) { diff --git a/dnn/test/naive/pooling.cpp b/dnn/test/naive/pooling.cpp index e8208ab901c66cc19bd54656dfc5d8a22246b32f..59da2d652cd48c506b9bbf8baecc7a57c0d67ed9 100644 --- a/dnn/test/naive/pooling.cpp +++ b/dnn/test/naive/pooling.cpp @@ -84,12 +84,12 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) { } { - auto u4_dt = dtype::Quantized4Asymm(1.f, 0); + auto u4_dt = dtype::Quantized4Asymm(0.1f, 3); std::vector u8_src_vec{1, 2, 3, 4, 5, 6, 7, 8, 9}; std::vector u8_max_dst_vec{1, 3, 7, 9}; - std::vector u8_avg_dst_vec{0, 1, 3, 7}; + std::vector u8_avg_dst_vec{3, 3, 4, 7}; std::vector u8_avg_exclu_dst_vec{1, 3, 6, 7}; Pooling::Param param{Mode::MAX, 1, 1, 2, 2, 2, 2}; Testcase input{TensorValueLowbit4({1, 1, 3, 3}, u4_dt, u8_src_vec), {}};