From 594fa722bd12a7b3827241f305aad82b57ffaee5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 11 Mar 2021 15:23:18 +0800 Subject: [PATCH] feat(mge/graph): add `modify_opr_algo_strategy_inplace` for fast-run GitOrigin-RevId: 034cf58b2a599b8541dad2a1b6d488f14d4295c4 --- .../megengine/core/tensor/megbrain_graph.py | 13 +++++++++++++ imperative/python/src/graph_rt.cpp | 16 ++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 9bd3aae5d..3f9a89730 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -288,6 +288,19 @@ def optimize_for_inference(dest_vars, **kwargs): return _wrap(res_vars) +def modify_opr_algo_strategy_inplace(dest_vars, strategy: str): + """ + C++ graph version of :func:`~.set_execution_strategy`. Used to inplacely modify + dumped graph's fast-run strategy. + + :param dest_vars: list of output vars in the computing graph. + :param strategy: fast-run algorithms strategy. + + """ + dest_vars = _unwrap(dest_vars) + _imperative_rt.modify_opr_algo_strategy_inplace(dest_vars, strategy) + + CompGraphDumpResult = collections.namedtuple( "CompGraphDumpResult", [ diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 1483ed3f8..db6f10a2e 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -32,6 +32,7 @@ namespace ser = mgb::serialization; using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions; using _LayoutTransform = _OptimizeForInferenceOptions::LayoutTransform; +using _AlgoStrategy = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; namespace { class _CompGraphProfilerImpl { @@ -257,6 +258,21 @@ void init_graph_rt(py::module m) { return vars; }); + m.def("modify_opr_algo_strategy_inplace", [](const VarNodeArray& dest_vars, const std::string& strategy) { + _AlgoStrategy stg; + const std::unordered_map> m{ + {"HEURISTIC", [&](){ stg = _AlgoStrategy::HEURISTIC; }}, + {"HEURISTIC_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::HEURISTIC_REPRODUCIBLE; }}, + {"PROFILE", [&](){ stg = _AlgoStrategy::PROFILE; }}, + {"PROFILE_REPRODUCIBLE", [&](){ stg = _AlgoStrategy::PROFILE_REPRODUCIBLE; }}, + {"PROFILE_HEURISTIC", [&](){ stg = _AlgoStrategy::PROFILE_HEURISTIC; }}, + }; + auto it = m.find(strategy); + mgb_assert(it != m.end(), "Invalid strategy string!"); + it->second(); + mgb::gopt::modify_opr_algo_strategy_inplace(dest_vars, stg); + }); + m.def("get_info_for_strip", [](const std::vector& dest_vars) { std::unordered_set opr_types, dtype_names, elemwise_modes; auto on_opr = [&](cg::OperatorNodeBase *opr) { -- GitLab