Commit 15d3b3b9 authored by Megvii Engine Team

fix(mgb): fix mgb still profiling matmul even when no-profiling-on-shape-change

GitOrigin-RevId: d24f73193eaf3a428e466c10e79b29e3fc1e71d6
Parent 31e4bf2c
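Before the hunks, a minimal standalone sketch of the selection flow this commit introduces, for readers skimming the diff. All names here (DummyPolicy, choose, kUsableOnShapeChange) are placeholders for illustration, not MegEngine identifiers: with no_profiling_on_shape_change set, a still-valid cached algorithm is reused, operators whose profiled algorithm may not survive a shape change fall back to a heuristic choice, and everything else still goes through profiling.

#include <cstdio>

struct DummyPolicy {
    bool valid = false;  // stands in for policy.algo.valid()
};

template <bool kUsableOnShapeChange>
DummyPolicy choose(bool no_profiling_on_shape_change, DummyPolicy cached) {
    if (no_profiling_on_shape_change) {
        if (cached.valid)
            return cached;  // reuse the previously selected algorithm
        if (!kUsableOnShapeChange) {
            std::printf("choose algo by heuristic, which may cause "
                        "performance regression\n");
            return DummyPolicy{true};  // heuristic pick, no profiling run
        }
    }
    std::printf("profile and cache the fastest algorithm\n");
    return DummyPolicy{true};
}

int main() {
    // MatrixMul-like operator, option enabled, nothing cached: heuristic fallback.
    choose</*kUsableOnShapeChange=*/false>(true, DummyPolicy{});
    // Any other operator in the same situation is still profiled.
    choose</*kUsableOnShapeChange=*/true>(true, DummyPolicy{});
    return 0;
}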
......@@ -384,8 +384,15 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
    if (ctx.owner_graph()->options().no_profiling_on_shape_change) {
        auto policy = ctx.megdnn_opr()->execution_policy();
        if (policy.algo.valid())
        if (policy.algo.valid()) {
            return policy;
        }
        if (!algo_usable_on_shape_change<Opr>()) {
            mgb_log_warn(
                    "choose algo by heuristic, which may cause performance "
                    "regression.");
            return ctx.choose_by_heuristic(selected_strategy);
        }
    }
    if (enable_update) {
......
......@@ -89,6 +89,13 @@ constexpr bool opr_contain_bias() {
    return std::is_same<Opr, megdnn::ConvBias>::value;
}

//! matmul and batchedMatrixMul may not be usable once shape changed
template <typename Opr>
constexpr bool algo_usable_on_shape_change() {
    return !(std::is_same<Opr, megdnn::MatrixMul>::value ||
             std::is_same<Opr, megdnn::BatchedMatrixMul>::value);
}

template <typename Opr, bool has_prep>
struct PreprocessFilterImpl {
    using T = union {};
......
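The helper added in the hunk above is a constexpr trait, so its result can be checked at compile time. A hedged sketch of what it evaluates to for the operator types named in this diff, assuming the megdnn operator headers already included by the surrounding file; the assertions below are illustrative only and not part of the patch:

static_assert(!algo_usable_on_shape_change<megdnn::MatrixMul>(),
              "a profiled matmul algorithm may become invalid after a shape change");
static_assert(!algo_usable_on_shape_change<megdnn::BatchedMatrixMul>(),
              "batched matmul is treated the same way");
static_assert(algo_usable_on_shape_change<megdnn::ConvBias>(),
              "other operators keep using their profiled algorithm");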
......@@ -885,5 +885,44 @@ TEST(TestOprBlas, SingularValueDecompositionZeroGrad) {
    run_svd_empty_grad_test<1, 1, 1>();
}

#if MGB_ENABLE_FASTRUN
TEST(TestOprBlas, MatrixMulExePolicy) {
    using Param = opr::MatrixMul::Param;
    Param param;
    using Policy = opr::MatrixMul::ExecutionPolicy;
    using S = Policy::Strategy;
    Policy policy;
    policy.strategy = S::PROFILE;

    auto cn = CompNode::load("cpux");

    int nr_get = 0;
    auto on_get = [&nr_get](const std::string&, const void*, size_t,
                            const void*, size_t) { ++nr_get; };
    PersistentCacheHook cache_hook{on_get};

    auto graph = ComputingGraph::make();
    HostTensorGenerator<> gen;
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name);
    };

    auto a = mkvar("a", {20, 50});
    auto b = mkvar("b", {50, 40});
    auto matmul = opr::MatrixMul::make(a, b, param, policy, {});

    HostTensorND host_y;
    graph->options().no_profiling_on_shape_change = true;
    auto func = graph->compile({make_callback_copy(matmul, host_y)});
    func->execute();
    ASSERT_EQ(nr_get, 0);

    graph->options().no_profiling_on_shape_change = false;
    func = graph->compile({make_callback_copy(matmul, host_y)});
    func->execute();
    ASSERT_GT(nr_get, 0);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}