From 15d3b3b941743aad9e02c87cd27f41cf0a997df6 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 22 Mar 2021 18:02:39 +0800
Subject: [PATCH] fix(mgb): fix mgb still profiling matmul even when no-profiling-on-shape-change

GitOrigin-RevId: d24f73193eaf3a428e466c10e79b29e3fc1e71d6
---
 src/opr/impl/search_policy/algo_chooser.cpp |  9 ++++-
 .../megbrain/opr/search_policy/profiler.h   |  7 ++++
 src/opr/test/blas.cpp                       | 39 +++++++++++++++++++
 3 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp
index 9f712c40..a77a2af7 100644
--- a/src/opr/impl/search_policy/algo_chooser.cpp
+++ b/src/opr/impl/search_policy/algo_chooser.cpp
@@ -384,8 +384,15 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx,
     MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
     if (ctx.owner_graph()->options().no_profiling_on_shape_change) {
         auto policy = ctx.megdnn_opr()->execution_policy();
-        if (policy.algo.valid())
+        if (policy.algo.valid()){
             return policy;
+        }
+        if (!algo_usable_on_shape_change<Opr>()) {
+            mgb_log_warn(
+                    "choose algo by heuristic, which may cause performance "
+                    "regression.");
+            return ctx.choose_by_heuristic(selected_strategy);
+        }
     }
 
     if (enable_update) {
diff --git a/src/opr/include/megbrain/opr/search_policy/profiler.h b/src/opr/include/megbrain/opr/search_policy/profiler.h
index da91abab..486563cb 100644
--- a/src/opr/include/megbrain/opr/search_policy/profiler.h
+++ b/src/opr/include/megbrain/opr/search_policy/profiler.h
@@ -89,6 +89,13 @@ constexpr bool opr_contain_bias() {
     return std::is_same<Opr, megdnn::ConvBias>::value;
 }
 
+//! matmul and batchedMatrixMul may not be usable once shape changed
+template <typename Opr>
+constexpr bool algo_usable_on_shape_change() {
+    return !(std::is_same<Opr, megdnn::MatrixMul>::value ||
+             std::is_same<Opr, megdnn::BatchedMatrixMul>::value);
+}
+
 template <typename Opr, bool has_prep>
 struct PreprocessFilterImpl {
     using T = union {};
diff --git a/src/opr/test/blas.cpp b/src/opr/test/blas.cpp
index e02f2b97..c0198042 100644
--- a/src/opr/test/blas.cpp
+++ b/src/opr/test/blas.cpp
@@ -885,5 +885,44 @@ TEST(TestOprBlas, SingularValueDecompositionZeroGrad) {
     run_svd_empty_grad_test<1, 1, 1>();
 }
 
+#if MGB_ENABLE_FASTRUN
+TEST(TestOprBlas, MatrixMulExePolicy) {
+    using Param = opr::MatrixMul::Param;
+    Param param;
+    using Policy = opr::MatrixMul::ExecutionPolicy;
+    using S = Policy::Strategy;
+    Policy policy;
+    policy.strategy = S::PROFILE;
+
+    auto cn = CompNode::load("cpux");
+
+    int nr_get = 0;
+    auto on_get = [&nr_get](const std::string&, const void*, size_t,
+                            const void*, size_t) { ++nr_get; };
+    PersistentCacheHook cache_hook{on_get};
+
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<> gen;
+
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp), cn).rename(name);
+    };
+
+    auto a = mkvar("a", {20, 50});
+    auto b = mkvar("b", {50, 40});
+    auto matmul = opr::MatrixMul::make(a, b, param, policy, {});
+
+    HostTensorND host_y;
+    graph->options().no_profiling_on_shape_change = true;
+    auto func = graph->compile({make_callback_copy(matmul, host_y)});
+    func->execute();
+    ASSERT_EQ(nr_get, 0);
+    graph->options().no_profiling_on_shape_change = false;
+    func = graph->compile({make_callback_copy(matmul, host_y)});
+    func->execute();
+    ASSERT_GT(nr_get, 0);
+}
+#endif
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
 //
-- 
GitLab