diff --git a/dnn/test/common/conv_bias.cpp b/dnn/test/common/conv_bias.cpp index 0deba21ecec147a37afa12d3fb80c2552b84a7ed..5cb3248446ac0d2b9cd4b61ec3f2dd7217651553 100644 --- a/dnn/test/common/conv_bias.cpp +++ b/dnn/test/common/conv_bias.cpp @@ -524,6 +524,42 @@ std::vector get_int8_nchw4_args_check_bounds(size_t kernel_size) { return args; } +std::vector get_int4_nchw64_args_ptx(size_t kernel_size, bool is_uint4) { + std::vector args; + param::ConvBias cur_param; + + using NLMode = param::ConvBias::NonlineMode; + + // clang-format off + for (auto nlmode : {NLMode::RELU, NLMode::IDENTITY}) {//{NLMode::H_SWISH} are not currently supported + for (auto mode : {param::ConvBias::Mode::CROSS_CORRELATION}) { + for (size_t b : {3, 7}) { + for (size_t ic : {64, 128}) { + for (size_t oc : {64, 320}) { + for (size_t h : {13}) { + for (size_t w : {28}) { + for (int p : {0, static_cast(kernel_size / 2)}) { + for (size_t s : {1, 2}) { + if (is_uint4 && nlmode == NLMode::H_SWISH) continue; + size_t f = kernel_size; + cur_param.mode = mode; + cur_param.nonlineMode = nlmode; + cur_param.format = param::ConvBias::Format::NCHW64; + cur_param.sparse = param::ConvBias::Sparse::DENSE; + cur_param.pad_h = cur_param.pad_w = p; + cur_param.stride_h = cur_param.stride_w = s; + + //! bias channel + args.emplace_back(cur_param, TensorShape{b, ic / 64, h, w, 64}, + TensorShape{oc, ic / 64, f, f, 64}, + TensorShape{1, oc / 64, 1, 1, 64}); + + } } } } } } } } } + // clang-format on + + return args; +} + std::vector get_int8_nchw4_args_small_batch(size_t kernel_size) { std::vector args; param::ConvBias cur_param; diff --git a/dnn/test/common/conv_bias.h b/dnn/test/common/conv_bias.h index b0c093fca2fe61d4bc801da666d650e24ed26240..4a3c5fc348bb63d02eaabcb4ae03a1e9985bb17d 100644 --- a/dnn/test/common/conv_bias.h +++ b/dnn/test/common/conv_bias.h @@ -30,6 +30,7 @@ std::vector get_quantized_winograd_mk_packed_args( std::vector get_quantized_args_with_nlmode( param::ConvBias::NonlineMode nlmode); std::vector get_quantized_args(); +std::vector get_int4_nchw64_args_ptx(size_t kernel_size, bool is_uint4); std::vector get_int8_nchw4_args(size_t kernel_size); std::vector get_int8_nchw4_args_check_bounds(size_t kernel_size); std::vector get_int8_nchw4_small_channel_args(size_t kernel_size);