diff --git a/imperative/src/impl/transformations/dtype_promote.cpp b/imperative/src/impl/transformations/dtype_promote.cpp
index 274ddf0e764166f38a38905b1317f0e3dda4e653..c4f4547a7d3f697b8df239266fa36ea648a75130 100644
--- a/imperative/src/impl/transformations/dtype_promote.cpp
+++ b/imperative/src/impl/transformations/dtype_promote.cpp
@@ -186,6 +186,15 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
 ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
     auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>());
     SmallVector<DType> dtypes = get_value_dtypes(inputs);
+
+    // skip dtype promotion when inputs are quantized
+    if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) {
+        mgb_assert(
+                dtypes[0].category() == dtypes[1].category(),
+                "inputs of matmul should have same quantized dtype.");
+        return imperative::apply(op, inputs);
+    }
+
     mgb::DType target_dtype;
 
     if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
@@ -212,6 +221,15 @@ ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
     auto&& conv_op = const_cast<BatchedMatrixMul&>(
             op.cast_final_safe<BatchedMatrixMul>());
     SmallVector<DType> dtypes = get_value_dtypes(inputs);
+
+    // skip dtype promotion when inputs are quantized
+    if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) {
+        mgb_assert(
+                dtypes[0].category() == dtypes[1].category(),
+                "inputs of batched matmul should have same quantized dtype.");
+        return imperative::apply(op, inputs);
+    }
+
     mgb::DType target_dtype;
 
     if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
diff --git a/src/rdnn/impl/algo_chooser.cpp b/src/rdnn/impl/algo_chooser.cpp
index 3a496f8ba7c5183eaa216f8a218ab19b494bdd28..9630b421afcea4b40db080dbbe018243ea36153d 100644
--- a/src/rdnn/impl/algo_chooser.cpp
+++ b/src/rdnn/impl/algo_chooser.cpp
@@ -600,7 +600,9 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp
         auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn);
         // skip different sub opr, for example:
         // skip matmul algo when profiling convolution
-        if (m_dnn_opr->get_opr_type() != megdnn_opr->get_opr_type())
+        if ((m_cn.device_type() == mgb::CompNode::DeviceType::CUDA ||
+             m_cn.device_type() == mgb::CompNode::DeviceType::ROCM) &&
+            m_dnn_opr->get_opr_type() != megdnn_opr->get_opr_type())
             continue;
         megdnn_opr->param() =
                 Algorithm::deserialize_read_pod<typename _Opr::Param>(_item.param);