提交 169fa53d 编写于 作者: M Megvii Engine Team

fix(mgb): fix execution_policy set of matmul

GitOrigin-RevId: 90f539b0bed2e1bb5da0103eeb2b46e4e9690c6f
上级 abe3c165
......@@ -98,6 +98,7 @@ size_t MatrixMul::get_workspace_size_bytes(
param ^= 1;
};
MGB_TRY {
megdnn_opr()->execution_policy() = {};
a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
//! Here we just want to save the execution policy got from setup_algo,
......@@ -106,24 +107,28 @@ size_t MatrixMul::get_workspace_size_bytes(
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i0, tparam.transposeA);
b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i1, tparam.transposeB);
c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i0, tparam.transposeA);
d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
}
MGB_FINALLY({ tparam = this->param(); });
return std::max(std::max(a, b), std::max(c, d));
......@@ -252,29 +257,34 @@ size_t BatchedMatrixMul::get_workspace_size_bytes(
param ^= 1;
};
MGB_TRY {
megdnn_opr()->execution_policy() = {};
a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i0, tparam.transposeA);
b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i1, tparam.transposeB);
c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
transpose(i0, tparam.transposeA);
d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
}
MGB_FINALLY({ tparam = this->param(); });
return std::max(std::max(a, b), std::max(c, d));
......
......@@ -266,7 +266,6 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(bool reproducible) const {
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
opr->owner_graph(), opr->comp_node(),
opr->execution_policy().workspace_limit);
m_megdnn_opr->execution_policy() = {};
return APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, reproducible),
m_layouts);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册