diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index 25d906fd3deffa3d18807df907302194f53d4266..03e39fd8fe695bacd73641cb842e811441de3e95 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -508,15 +508,6 @@ AlgoChooser::AlgoChooserHelper::AlgoChooserHelper( m_fastrun_layouts, m_dnn_opr->param(), fastrun_batch_size); } - if (owner_graph()->options().no_profiling_on_shape_change) { - for (size_t i = 0; i < m_incache_layouts.size(); i++) { - for (size_t j = 0; j < m_incache_layouts.at(i).ndim; j++) { - m_incache_layouts.at(i)[j] = 0; - m_incache_layouts.at(i).stride[j] = 0; - } - } - } - mgb_assert(m_fastrun_layouts.size() == layouts.size()); static_assert( @@ -571,6 +562,12 @@ typename AlgoChooser::ImplExecutionPolicy AlgoChooser::AlgoChooserHelp if (policy.algo.valid()) { return policy; } + if (is_matmul()) { + mgb_log_warn( + "choose algo by heuristic, which may cause performance " + "regression."); + return choose_by_heuristic(selected_strategy); + } } typename AlgoChooser::ImplExecutionPolicy tmp_policy; @@ -1016,8 +1013,6 @@ std::pair AlgoChooser::AlgoChooserHelper:: } //! from graph option - // FIXME: no_profiling_on_shape_change extract USABLE_DEPEND_ON_SHAPE attribute when - // fixed usable if (owner_graph()->options().fast_run_config.shared_batch_size) { ret.second |= AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } diff --git a/src/opr/include/megbrain/opr/search_policy/profiler.h b/src/opr/include/megbrain/opr/search_policy/profiler.h index 775f33077c35682e79a48f9c646acd38f00472f3..59c465515a057af7700c7ad997e9f5392b321fb5 100644 --- a/src/opr/include/megbrain/opr/search_policy/profiler.h +++ b/src/opr/include/megbrain/opr/search_policy/profiler.h @@ -58,6 +58,13 @@ constexpr bool opr_contain_bias() { return std::is_same::value; } +//! matmul and batchedMatrixMul +template +constexpr bool is_matmul() { + return std::is_same::value || + std::is_same::value; +} + template struct PreprocessFilterImpl { using T = union {}; diff --git a/src/opr/test/algo_chooser.cpp b/src/opr/test/algo_chooser.cpp index 54d1f7789c404336f371c58f5c51eef196ccfda2..2a319b94a5581f88c8b412d6c1475c12aed7e038 100644 --- a/src/opr/test/algo_chooser.cpp +++ b/src/opr/test/algo_chooser.cpp @@ -292,56 +292,6 @@ TEST(TestOprDNN, FastrunIgnoreBatchSizeBatchedMatrixMul) { {TensorShape{4, 6, 8}, TensorShape{4, 8, 4}}); } -template -void test_no_profiling_on_shape_change( - const TensorShapeArray& inps0, const TensorShapeArray& inps1) { - using Policy = typename MgbOpr::ExecutionPolicy; - - int nr_set = 0; - auto on_get = [](const std::string&, const void*, size_t, const void*, size_t) {}; - auto on_set = [&nr_set]( - const std::string&, const void*, size_t, const void*, - size_t) { nr_set++; }; - PersistentCacheHook cache_hook{on_get, on_set}; - - auto cn = CompNode::load("xpu0"); - auto run = [&cn](const TensorShapeArray& shapes) { - auto graph = ComputingGraph::make(); - graph->options().no_profiling_on_shape_change = true; - - HostTensorGenerator<> gen; - auto host_a = gen(shapes[0], cn); - auto host_b = gen(shapes[1], cn); - HostTensorND host_out; - auto a = opr::Host2DeviceCopy::make(*graph, host_a), - b = opr::Host2DeviceCopy::make(*graph, host_b); - - Policy policy; - policy.strategy = Policy::Strategy::PROFILE; - auto out = MgbOpr::make(a, b, {}, policy, {}); - - std::unique_ptr func = graph->compile({{out, {}}}); - func->execute(); - }; - - run(inps0); - int nr = nr_set; - ASSERT_GT(nr, 0); - run(inps1); - ASSERT_EQ(nr, nr_set); -} - -TEST(TestOprDNN, FastrunNoProfilingOnShapeChange) { - REQUIRE_GPU(1); - megdnn::HeuristicCache::instance().clear(); - - test_no_profiling_on_shape_change( - {{12, 3, 36, 36}, {4, 3, 3, 3}}, {{32, 3, 28, 28}, {4, 3, 3, 3}}); - - test_no_profiling_on_shape_change( - {{20, 30}, {30, 40}}, {{30, 40}, {40, 60}}); -} - #endif // MGB_ENABLE_FASTRUN #endif // MGB_CUDA diff --git a/src/opr/test/blas.cpp b/src/opr/test/blas.cpp index c3828c2b20bdeef03d8b58bed2d1bb92ad1bd20e..7be6a410eea765ea2d98e7a051ba7e9a1e4bdd29 100644 --- a/src/opr/test/blas.cpp +++ b/src/opr/test/blas.cpp @@ -899,12 +899,11 @@ TEST(TestOprBlas, MatrixMulExePolicy) { graph->options().no_profiling_on_shape_change = true; auto func = graph->compile({make_callback_copy(matmul, host_y)}); func->execute(); - ASSERT_GT(nr_get, 0); - int nr = nr_get; + ASSERT_EQ(nr_get, 0); graph->options().no_profiling_on_shape_change = false; func = graph->compile({make_callback_copy(matmul, host_y)}); func->execute(); - ASSERT_GT(nr_get, nr); + ASSERT_GT(nr_get, 0); } #endif