Commit 8d248a6a — authored by Megvii Engine Team, committed by huangxinda

fix(dnn/cuda): fix testcase for fallback nchw qs8 conv

GitOrigin-RevId: 646440db59f0157a3fdbd8061167f9ac04dbd422
Parent commit: 894a2407
...@@ -353,7 +353,8 @@ bool megdnn::check_bias_share_in_channel(const TensorLayout& bias, ...@@ -353,7 +353,8 @@ bool megdnn::check_bias_share_in_channel(const TensorLayout& bias,
format == param::ConvBias::Format::NCHW4_NCHW) { format == param::ConvBias::Format::NCHW4_NCHW) {
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 && share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
bias[3] == 1); bias[3] == 1);
} else if (format == param::ConvBias::Format::NHWC) { } else if (format == param::ConvBias::Format::NHWC ||
format == param::ConvBias::Format::NCHW4_NHWC) {
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 && share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 &&
bias[2] == 1); bias[2] == 1);
} else if (format == param::ConvBias::Format::NCHW4 || } else if (format == param::ConvBias::Format::NCHW4 ||
......
...@@ -84,8 +84,12 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list( ...@@ -84,8 +84,12 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list(
inner_dst_layout, inner_bias_layout, inner_z_layout); inner_dst_layout, inner_bias_layout, inner_z_layout);
Param inner_conv_param = o->param(); Param inner_conv_param = o->param();
inner_conv_param.format = Param::Format::NCHW4; if (layouts[4].dtype.enumv() == DTypeEnum::Float32) {
inner_conv_param.format = Param::Format::NCHW4_NCHW;
} else {
inner_conv_param.format = Param::Format::NCHW4;
}
std::string param_str; std::string param_str;
Algorithm::serialize_write_pod(inner_conv_param, param_str); Algorithm::serialize_write_pod(inner_conv_param, param_str);
...@@ -192,9 +196,9 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec( ...@@ -192,9 +196,9 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec(
inner_conv_param.format = inner_conv_param.format =
dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4; dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4;
auto inner_opr = args.handle->create_operator<ConvBiasForward>(); auto inner_opr = args.handle->create_operator<ConvBiasForward>();
inner_opr->param() = inner_conv_param;
set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr, set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr,
inner_opr.get()); inner_opr.get());
inner_opr->param() = inner_conv_param;
relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {}); relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {});
relayout_weight->exec(*args.filter_tensor, inner_weight, {}); relayout_weight->exec(*args.filter_tensor, inner_weight, {});
......
...@@ -701,9 +701,11 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { ...@@ -701,9 +701,11 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) {
TEST_F(CUDA, FALLBACK_CONV_QS8) { TEST_F(CUDA, FALLBACK_CONV_QS8) {
require_compute_capability_eq(7, 5); require_compute_capability_eq(7, 5);
Checker<ConvBiasForward> checker(handle_cuda()); Checker<ConvBiasForward> checker(handle_cuda());
auto check = [&checker](const std::string&& algo) { auto check = [&checker](const std::string&& algo,
const std::string&& sub_algo) {
checker.set_before_exec_callback( checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo.c_str())); conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
{algo.c_str(), {sub_algo.c_str()}}));
UniformIntRNG rng{-3, 3}; UniformIntRNG rng{-3, 3};
UniformIntRNG bias_rng{-50, 50}; UniformIntRNG bias_rng{-50, 50};
checker.set_rng(0, &rng) checker.set_rng(0, &rng)
...@@ -733,15 +735,17 @@ TEST_F(CUDA, FALLBACK_CONV_QS8) { ...@@ -733,15 +735,17 @@ TEST_F(CUDA, FALLBACK_CONV_QS8) {
{}, {},
{}}); {}});
}; };
check("FALLBACK_CONV_NCHW_QS8"); check("FALLBACK_CONV_NCHW_QS8", "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM");
} }
TEST_F(CUDA, FALLBACK_CONV_QS8_F32) { TEST_F(CUDA, FALLBACK_CONV_QS8_F32) {
require_compute_capability_eq(7, 5); require_compute_capability_eq(7, 5);
Checker<ConvBiasForward> checker(handle_cuda()); Checker<ConvBiasForward> checker(handle_cuda());
auto check = [&checker](const std::string&& algo) { auto check = [&checker](const std::string&& algo,
const std::string&& sub_algo) {
checker.set_before_exec_callback( checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo.c_str())); conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
{algo.c_str(), {sub_algo.c_str()}}));
UniformIntRNG rng{-3, 3}; UniformIntRNG rng{-3, 3};
UniformFloatRNG bias_rng{-50.f, 50.f}; UniformFloatRNG bias_rng{-50.f, 50.f};
checker.set_rng(0, &rng) checker.set_rng(0, &rng)
...@@ -771,7 +775,7 @@ TEST_F(CUDA, FALLBACK_CONV_QS8_F32) { ...@@ -771,7 +775,7 @@ TEST_F(CUDA, FALLBACK_CONV_QS8_F32) {
{}, {},
{}}); {}});
}; };
check("FALLBACK_CONV_NCHW_QS8"); check("FALLBACK_CONV_NCHW_QS8", "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM");
} }
TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_WEIGHT_PREPROCESS) { TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_WEIGHT_PREPROCESS) {
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.