提交 1a2ed8c4 编写于 作者: M Megvii Engine Team

feat(cuda): add convbias ptx algo testcase

GitOrigin-RevId: 9ad6d4561fb9da95708d35dc114c58bea77debf7
上级 64551105
......@@ -524,6 +524,42 @@ std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size) {
return args;
}
std::vector<TestArg> get_int4_nchw64_args_ptx(size_t kernel_size, bool is_uint4) {
std::vector<TestArg> args;
param::ConvBias cur_param;
using NLMode = param::ConvBias::NonlineMode;
// clang-format off
for (auto nlmode : {NLMode::RELU, NLMode::IDENTITY}) {//{NLMode::H_SWISH} are not currently supported
for (auto mode : {param::ConvBias::Mode::CROSS_CORRELATION}) {
for (size_t b : {3, 7}) {
for (size_t ic : {64, 128}) {
for (size_t oc : {64, 320}) {
for (size_t h : {13}) {
for (size_t w : {28}) {
for (int p : {0, static_cast<int>(kernel_size / 2)}) {
for (size_t s : {1, 2}) {
if (is_uint4 && nlmode == NLMode::H_SWISH) continue;
size_t f = kernel_size;
cur_param.mode = mode;
cur_param.nonlineMode = nlmode;
cur_param.format = param::ConvBias::Format::NCHW64;
cur_param.sparse = param::ConvBias::Sparse::DENSE;
cur_param.pad_h = cur_param.pad_w = p;
cur_param.stride_h = cur_param.stride_w = s;
//! bias channel
args.emplace_back(cur_param, TensorShape{b, ic / 64, h, w, 64},
TensorShape{oc, ic / 64, f, f, 64},
TensorShape{1, oc / 64, 1, 1, 64});
} } } } } } } } }
// clang-format on
return args;
}
std::vector<TestArg> get_int8_nchw4_args_small_batch(size_t kernel_size) {
std::vector<TestArg> args;
param::ConvBias cur_param;
......
......@@ -30,6 +30,7 @@ std::vector<TestArg> get_quantized_winograd_mk_packed_args(
std::vector<TestArg> get_quantized_args_with_nlmode(
param::ConvBias::NonlineMode nlmode);
std::vector<TestArg> get_quantized_args();
std::vector<TestArg> get_int4_nchw64_args_ptx(size_t kernel_size, bool is_uint4);
std::vector<TestArg> get_int8_nchw4_args(size_t kernel_size);
std::vector<TestArg> get_int8_nchw4_args_check_bounds(size_t kernel_size);
std::vector<TestArg> get_int8_nchw4_small_channel_args(size_t kernel_size);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册