From 169fa53d547f0a975be3399ea1c284c0bb1da78f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 19 Jan 2021 11:32:39 +0800 Subject: [PATCH] fix(mgb): fix execution_policy set of matmul GitOrigin-RevId: 90f539b0bed2e1bb5da0103eeb2b46e4e9690c6f --- src/opr/impl/blas.cpp | 10 ++++++++++ src/opr/impl/search_policy/algo_chooser.cpp | 1 - 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/opr/impl/blas.cpp b/src/opr/impl/blas.cpp index 5a5c1e56..bb89a0a1 100644 --- a/src/opr/impl/blas.cpp +++ b/src/opr/impl/blas.cpp @@ -98,6 +98,7 @@ size_t MatrixMul::get_workspace_size_bytes( param ^= 1; }; MGB_TRY { + megdnn_opr()->execution_policy() = {}; a = AlgoChooser::setup_algo({i0, i1, out}, megdnn_opr(), this); //! Here we just want to save the execution policy got from setup_algo, @@ -106,24 +107,28 @@ size_t MatrixMul::get_workspace_size_bytes( const_cast(this) ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = megdnn_opr()->execution_policy(); + megdnn_opr()->execution_policy() = {}; transpose(i0, tparam.transposeA); b = AlgoChooser::setup_algo({i0, i1, out}, megdnn_opr(), this); const_cast(this) ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = megdnn_opr()->execution_policy(); + megdnn_opr()->execution_policy() = {}; transpose(i1, tparam.transposeB); c = AlgoChooser::setup_algo({i0, i1, out}, megdnn_opr(), this); const_cast(this) ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = megdnn_opr()->execution_policy(); + megdnn_opr()->execution_policy() = {}; transpose(i0, tparam.transposeA); d = AlgoChooser::setup_algo({i0, i1, out}, megdnn_opr(), this); const_cast(this) ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = megdnn_opr()->execution_policy(); + megdnn_opr()->execution_policy() = {}; } MGB_FINALLY({ tparam = this->param(); }); return std::max(std::max(a, b), std::max(c, d)); @@ -252,29 +257,34 @@ size_t BatchedMatrixMul::get_workspace_size_bytes( param ^= 1; }; MGB_TRY { + megdnn_opr()->execution_policy() = {}; a = AlgoChooser::setup_algo( {i0, i1, out}, megdnn_opr(), this); const_cast(this) ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = megdnn_opr()->execution_policy(); + megdnn_opr()->execution_policy() = {}; transpose(i0, tparam.transposeA); b = AlgoChooser::setup_algo( {i0, i1, out}, megdnn_opr(), this); const_cast(this) ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = megdnn_opr()->execution_policy(); + megdnn_opr()->execution_policy() = {}; transpose(i1, tparam.transposeB); c = AlgoChooser::setup_algo( {i0, i1, out}, megdnn_opr(), this); const_cast(this) ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = megdnn_opr()->execution_policy(); + megdnn_opr()->execution_policy() = {}; transpose(i0, tparam.transposeA); d = AlgoChooser::setup_algo( {i0, i1, out}, megdnn_opr(), this); const_cast(this) ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] = megdnn_opr()->execution_policy(); + megdnn_opr()->execution_policy() = {}; } MGB_FINALLY({ tparam = this->param(); }); return std::max(std::max(a, b), std::max(c, d)); diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index b6e07a2c..f9d72994 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -266,7 +266,6 @@ AlgoChooser::ExeContext::choose_by_heuristic(bool reproducible) const { auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( opr->owner_graph(), opr->comp_node(), opr->execution_policy().workspace_limit); - m_megdnn_opr->execution_policy() = {}; return APPLY(m_megdnn_opr->get_algorithm_info_heuristic( args..., workspace_limit, reproducible), m_layouts); -- GitLab