提交 5a355138 编写于 作者: M Megvii Engine Team

fix(mgb): fix profile skip condition

GitOrigin-RevId: f196eabc9810b5ac32744a10a1473a6b445f7f4a
上级 5bdc430e
...@@ -186,6 +186,15 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) { ...@@ -186,6 +186,15 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) {
ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) { ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>()); auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>());
SmallVector<DType> dtypes = get_value_dtypes(inputs); SmallVector<DType> dtypes = get_value_dtypes(inputs);
// skip dtype promotion when inputs are quantized
if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) {
mgb_assert(
dtypes[0].category() == dtypes[1].category(),
"inputs of matmul should have same quantized dtype.");
return imperative::apply(op, inputs);
}
mgb::DType target_dtype; mgb::DType target_dtype;
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
...@@ -212,6 +221,15 @@ ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) { ...@@ -212,6 +221,15 @@ ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) {
auto&& conv_op = auto&& conv_op =
const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>()); const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>());
SmallVector<DType> dtypes = get_value_dtypes(inputs); SmallVector<DType> dtypes = get_value_dtypes(inputs);
// skip dtype promotion when inputs are quantized
if (dtypes[0].category() == megdnn::DTypeCategory::QUANTIZED) {
mgb_assert(
dtypes[0].category() == dtypes[1].category(),
"inputs of batched matmul should have same quantized dtype.");
return imperative::apply(op, inputs);
}
mgb::DType target_dtype; mgb::DType target_dtype;
if (DTypePromoteCfg::amp_dtype_autocast_enabled) { if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
......
...@@ -600,7 +600,9 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp ...@@ -600,7 +600,9 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::AlgoChooserHelp
auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn); auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn);
// skip different sub opr, for example: // skip different sub opr, for example:
// skip matmul algo when profiling convolution // skip matmul algo when profiling convolution
if (m_dnn_opr->get_opr_type() != megdnn_opr->get_opr_type()) if ((m_cn.device_type() == mgb::CompNode::DeviceType::CUDA ||
m_cn.device_type() == mgb::CompNode::DeviceType::ROCM) &&
m_dnn_opr->get_opr_type() != megdnn_opr->get_opr_type())
continue; continue;
megdnn_opr->param() = megdnn_opr->param() =
Algorithm::deserialize_read_pod<typename _Opr::Param>(_item.param); Algorithm::deserialize_read_pod<typename _Opr::Param>(_item.param);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册