From 5a355138568b84d04a1c5e282e8c8df05e55d39c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 14 Jun 2022 13:19:07 +0800 Subject: [PATCH] fix(mgb): fix profile skip condition GitOrigin-RevId: f196eabc9810b5ac32744a10a1473a6b445f7f4a --- .../src/impl/transformations/dtype_promote.cpp | 18 ++++++++++++++++++ src/rdnn/impl/algo_chooser.cpp | 4 +++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp index 274ddf0e7..c4f4547a7 100644 --- a/imperative/src/impl/transformations/dtype_promote.cpp +++ b/imperative/src/impl/transformations/dtype_promote.cpp @@ -186,6 +186,15 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) { ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) { auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>()); SmallVector<DType> dtypes = get_value_dtypes(inputs); + + // skip dtype promotion when inputs are quantized + if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) { + mgb_assert( + dtypes[0].category() == dtypes[1].category(), + "inputs of matmul should have same quantized dtype."); + return imperative::apply(op, inputs); + } + mgb::DType target_dtype; if (DTypePromoteCfg::amp_dtype_autocast_enabled) { @@ -212,6 +221,15 @@ ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) { auto&& conv_op = const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>()); SmallVector<DType> dtypes = get_value_dtypes(inputs); + + // skip dtype promotion when inputs are quantized + if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) { + mgb_assert( + dtypes[0].category() == dtypes[1].category(), + "inputs of batched matmul should have same quantized dtype."); + return imperative::apply(op, inputs); + } + mgb::DType target_dtype; if (DTypePromoteCfg::amp_dtype_autocast_enabled) { diff --git a/src/rdnn/impl/algo_chooser.cpp b/src/rdnn/impl/algo_chooser.cpp index 3a496f8ba..9630b421a 100644 --- a/src/rdnn/impl/algo_chooser.cpp +++ 
b/src/rdnn/impl/algo_chooser.cpp @@ -600,7 +600,9 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn); // skip different sub opr, for example: // skip matmul algo when profiling convolution - if (m_dnn_opr->get_opr_type() != megdnn_opr->get_opr_type()) + if ((m_cn.device_type() == mgb::CompNode::DeviceType::CUDA || + m_cn.device_type() == mgb::CompNode::DeviceType::ROCM) && + m_dnn_opr->get_opr_type() != megdnn_opr->get_opr_type()) continue; megdnn_opr->param() = Algorithm::deserialize_read_pod<typename _Opr::Param>(_item.param); -- GitLab