From 1bdb26f9e831e03026c405fa648c770142f343ed Mon Sep 17 00:00:00 2001 From: BowenK Date: Thu, 3 Sep 2020 09:38:43 +0800 Subject: [PATCH] Warming up python pass by adding inline passes before it --- .../ccsrc/frontend/optimizer/ad/kprim.cc | 10 ++- mindspore/ccsrc/frontend/optimizer/py_pass.cc | 66 +++++++++++-------- mindspore/ccsrc/frontend/optimizer/py_pass.h | 3 +- .../frontend/optimizer/py_pass_manager.cc | 28 ++++---- .../frontend/optimizer/py_pass_manager.h | 6 +- mindspore/ccsrc/pipeline/jit/action.cc | 15 ++++- mindspore/ccsrc/pipeline/jit/pass.cc | 10 +++ mindspore/ccsrc/pipeline/jit/pass.h | 1 + mindspore/graph_utils/python_pass/__init__.py | 8 +-- .../python_pass/python_pass_register.py | 21 +++--- tests/ut/python/optimizer/test_python_pass.py | 40 +++++------ 11 files changed, 129 insertions(+), 79 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index 4b3c1fd74..028a97243 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -49,7 +49,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { auto scope = std::make_shared(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + grad_op_child_scope_prefix + prim->name()); ScopeGuard scope_guard(scope); - py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast()->GetBpropFunction(); + py::function fn; + if (prim->is_base()) { + fn = GetBpropFunction(prim->name()); + } else { + fn = prim->cast()->GetBpropFunction(); + if (py::isinstance(fn)) { + fn = GetBpropFunction(prim->name()); + } + } if (!fn || py::isinstance(fn)) { MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; return nullptr; diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc index 877da5f5a..a0477223e 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc @@ -35,8 +35,10 @@ namespace internal { const char PARAMETER_MODULE[] = "mindspore.common.parameter"; const char PARAMETER_CLASS[] = "Parameter"; const char SET_PARAM[] = "__setattr__"; -AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph); -AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res); +AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph, + const FuncGraphPtr &top_graph); +AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, + const MatchResultPtr &res); void ReflectParamBackToPython(const AnfNodePtr ¶m, string param_name, tensor::TensorPtr default_input, bool requires_grad, bool layerwise_parallel); @@ -72,7 +74,8 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res) return std::make_shared(input_tensor); } -AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) { +AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg, + const FuncGraphPtr &top_graph) { auto call_pattern = pattern->cast(); MS_EXCEPTION_IF_NULL(call_pattern); auto prim = call_pattern->prim_value(); @@ -81,20 +84,20 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP } auto prim_pattern = call_pattern->prim_pattern(); MS_EXCEPTION_IF_NULL(prim_pattern); - return ProcessSinglePattern(prim_pattern, res, fg); + return ProcessSinglePattern(prim_pattern, res, fg, top_graph); } -AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) { +AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) { auto new_para_pattern = pattern->cast(); MS_EXCEPTION_IF_NULL(new_para_pattern); if (!new_para_pattern->built()) { static int parameter_id = 0; auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++); - auto para_node = std::make_shared(func_graph); + auto para_node = std::make_shared(top_graph); MS_EXCEPTION_IF_NULL(para_node); para_node->set_name(para_name); // Set function graph - para_node->set_func_graph(func_graph); + para_node->set_func_graph(top_graph); // Set Debug Info auto debug_info = std::make_shared(para_name); para_node->set_debug_info(debug_info); @@ -103,7 +106,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re MS_EXCEPTION_IF_NULL(default_value); para_node->set_abstract(default_value->ToAbstract()->Broaden()); res->add_entry(pattern, para_node); - func_graph->add_parameter(para_node); + top_graph->add_parameter(para_node); // Reflect back to Cell._params internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(), new_para_pattern->layerwise_parallel()); @@ -126,7 +129,8 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) { return std::make_shared(scalar_value_ptr); } -AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) { +AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph, + const FuncGraphPtr &top_graph) { auto target_node = res->get_node(pattern); if (target_node != nullptr) { // If pattern is NewParameter, check whether it shouldn't last and is not built @@ -141,9 +145,10 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr } else if (pattern->isa()) { return BuildNewTensor(pattern, res); } else if (pattern->isa()) { - return BuildPrimitiveValueNode(pattern, res, func_graph); + return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph); } else if (pattern->isa()) { - return BuildNewParameter(pattern, res, func_graph); + // Add new parameter to top graph instead of current graph + return BuildNewParameter(pattern, res, top_graph); } else if (pattern->isa()) { return BuildImmNode(pattern, res); } else { @@ -154,17 +159,18 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr } AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res, - const FuncGraphPtr &func_graph) { + const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph) { if (pattern->isa()) { - return BuildPrimitiveValueNode(pattern, res, func_graph); + return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph); } return nullptr; } -AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) { +AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, + const MatchResultPtr &res) { auto target_inputs = pattern->inputs(); if (target_inputs.size() == 0) { - auto new_node = ProcessSinglePattern(pattern, res, func_graph); + auto new_node = ProcessSinglePattern(pattern, res, func_graph, top_graph); if (new_node != nullptr) { res->add_entry(pattern, new_node); } @@ -172,14 +178,14 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph } // Build up the AnfNode in a recursive manner std::vector new_inputs; - auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph); + auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph, top_graph); MS_EXCEPTION_IF_NULL(prim_value_node); new_inputs.push_back(prim_value_node); for (auto &iter : target_inputs) { if (iter == pattern) { MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n"; } - auto input_node = BuildTarget(iter, func_graph, res); + auto input_node = BuildTarget(iter, func_graph, top_graph, res); if (input_node == nullptr) { MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n"; } @@ -240,11 +246,12 @@ void Reset(PatternPtr pattern) { } // namespace internal -AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) { +AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node, + const MatchResultPtr &res) { auto match_res = src_pattern_->match(node); if (match_res != nullptr) { res->merge(match_res); - auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res); + auto new_node = internal::BuildTarget(dst_pattern_, func_graph, top_graph, res); internal::Reset(dst_pattern()); return new_node; } @@ -284,16 +291,19 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) } FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - auto graph_nodes_sorted = TopoSort(func_graph->output()); + auto func_graphs = manager->func_graphs(); bool changes = false; - - // Traverse once - for (auto &node : graph_nodes_sorted) { - AnfNodePtr new_node = Run(func_graph, node, res); - if (new_node != nullptr && new_node != node) { - (void)manager->Replace(node, new_node); - changes = true; + for (auto &fg : func_graphs) { + manager->AddFuncGraph(fg); + auto graph_nodes_sorted = TopoSort(fg->output()); + // Traverse once + for (auto &node : graph_nodes_sorted) { + AnfNodePtr new_node = Run(fg, func_graph, node, res); + if (new_node != nullptr && new_node != node) { + MS_LOG(WARNING) << "Matched"; + (void)manager->Replace(node, new_node); + changes = true; + } } } return changes; diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.h b/mindspore/ccsrc/frontend/optimizer/py_pass.h index 6e693c0e4..145f86c47 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass.h +++ b/mindspore/ccsrc/frontend/optimizer/py_pass.h @@ -39,7 +39,8 @@ class PythonPass { ~PythonPass() = default; bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res); std::string name() const { return name_; } - AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res); + AnfNodePtr Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node, + const MatchResultPtr &res); PatternPtr src_pattern() { return src_pattern_; } PatternPtr dst_pattern() { return dst_pattern_; } diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc index 4540d5bbc..89e0015a3 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc @@ -43,15 +43,19 @@ PyPassManagerPtr PyPassManager::GetInstance() { } PyPassManager::PyPassManager() { - phase_to_group_[Phase::RESOLVE] = std::make_shared(); - phase_to_group_[Phase::OPT] = std::make_shared(); + phase_to_group_[Phase::PREAD] = std::make_shared("Pre_AD_PassGroup"); + phase_to_group_[Phase::OPT] = std::make_shared("After_OPT_PassGroup"); res_ = std::make_shared(); } void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, - bool run_only_once) { - // NOTE: remove phase option to avoid unnecessary confusion. - auto cur_pg = GetPassGroup(Phase::OPT); + bool requires_grad, bool run_only_once) { + PassGroupPtr cur_pg; + if (requires_grad) { + cur_pg = GetPassGroup(Phase::PREAD); + } else { + cur_pg = GetPassGroup(Phase::OPT); + } MS_EXCEPTION_IF_NULL(cur_pg); cur_pg->SetRunOnlyOnce(run_only_once); MS_EXCEPTION_IF_NULL(pattern); @@ -62,11 +66,13 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt } void PyPassManager::Unregiste(const std::string &pass_name) { - // NOTE: remove phase option to avoid unnecessary confusion. - auto cur_pm = GetPassGroup(Phase::OPT); - MS_EXCEPTION_IF_NULL(cur_pm); - if (!cur_pm->DeletePass(pass_name)) { - MS_LOG(WARNING) << "No such pass : " + pass_name + "\n"; + auto opt_pm = GetPassGroup(Phase::OPT); + if (!opt_pm->DeletePass(pass_name)) { + MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n"; + } + auto pre_ad_pm = GetPassGroup(Phase::PREAD); + if (!pre_ad_pm->DeletePass(pass_name)) { + MS_LOG(WARNING) << "Pre_AD has no such pass : " + pass_name + "\n"; } } @@ -92,7 +98,7 @@ void PyPassManager::ClearRes() { REGISTER_PYBIND_DEFINE( PyPassManager_, ([](const py::module *m) { - (void)py::enum_(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT); + (void)py::enum_(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT); (void)py::class_>(*m, "PyPassManager_") .def(py::init([]() { return PyPassManager::GetInstance(); })) .def("registe", &PyPassManager::Registe, "Registe python pass") diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h index c892d4685..590c63af0 100644 --- a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h +++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h @@ -38,7 +38,7 @@ namespace python_pass { class PyPassManager; using PyPassManagerPtr = std::shared_ptr; -enum Phase { RESOLVE, OPT }; +enum Phase { PREAD, OPT }; class PyPassManager { protected: @@ -52,8 +52,8 @@ class PyPassManager { // Access the only global instance static PyPassManagerPtr GetInstance(); virtual ~PyPassManager() = default; - void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, - bool run_only_once = false); + void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad, + bool run_only_once); void Unregiste(const std::string &pass_name); void GenNewParameter(const PatternPtr ¶meter); PassGroupPtr GetPassGroup(Phase phase); diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index b2434d5a1..029cf2104 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -288,6 +288,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) return true; } +bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); } + bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } @@ -460,7 +462,12 @@ bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { return ppm->GetPassGroup(phase)->Run(res->func_graph()); } -bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); } +bool PreAdActionPyStub(const ResourcePtr &res) { + if (!ActionPyStub(res, opt::python_pass::Phase::PREAD)) { + MS_LOG(DEBUG) << "No Match."; + } + return true; +} bool OptActionVmPyStub(const ResourcePtr &res) { if (ActionPyStub(res, opt::python_pass::Phase::OPT)) { @@ -516,12 +523,14 @@ static std::vector CommonPipeline() { if (!multi_graphs) { actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); } - // Add resolve-stage python pass stub - actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); // Evaluate type and shape, and specialize actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); + // Do data structure simplifications and inline + actions.emplace_back(std::make_pair("inline", OptInlineAction)); + // Add pre-ad, post-inline python pass stub + actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub)); return actions; } diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 6465f0f89..e50f18b07 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -165,6 +165,12 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { return map_a; } +OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) { + auto opt_a = GetOptPassesA(irpass); + OptPassGroupMap a1_a2({opt_a[0], opt_a[1]}); + return a1_a2; +} + OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig c_1 = opt::OptPassConfig({ // Safe inlining, @@ -270,6 +276,7 @@ static std::unordered_map> g_pass_opts = void InitOpt(const ResourcePtr &res) { if (g_pass_opts.size() == 0) { opt::irpass::OptimizeIRPassLib irpass; + g_pass_opts["a1a2"] = Optimizer::MakeOptimizer("a1a2", res, GetA1A2(irpass)); g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); g_pass_opts["opt_after_cconv"] = @@ -318,6 +325,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) { return true; } +bool OptPassA1A2(const ResourcePtr &res) { return OptPassGroup(res, "a1a2"); } bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); } @@ -440,5 +448,7 @@ std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, {"cconv", CconvPass}, {"transform_top", TransformTopGraphPass}, {"transform_graph", OptPassTransformGraphGroup}}; + +std::vector kInlinePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"a1a2", OptPassA1A2}}; } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.h b/mindspore/ccsrc/pipeline/jit/pass.h index 6176113a1..e187a4df6 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.h +++ b/mindspore/ccsrc/pipeline/jit/pass.h @@ -29,6 +29,7 @@ using PassItem = std::pair>; extern std::vector kGePasses; extern std::vector kVmPasses; +extern std::vector kInlinePasses; extern std::vector kPynativePasses; bool CconvPass(const ResourcePtr &res); diff --git a/mindspore/graph_utils/python_pass/__init__.py b/mindspore/graph_utils/python_pass/__init__.py index d9fe61c87..5fa33e350 100644 --- a/mindspore/graph_utils/python_pass/__init__.py +++ b/mindspore/graph_utils/python_pass/__init__.py @@ -13,14 +13,14 @@ # limitations under the License. # ============================================================================ """Reference for python pass registration.""" -from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\ - set_reopt +from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\ + _set_reopt __all__ = [ "registe_pass", "unregiste_pass", "gen_new_parameter", "cancel_new_parameter", - "set_renorm", - "set_reopt" + "_set_renorm", + "_set_reopt" ] diff --git a/mindspore/graph_utils/python_pass/python_pass_register.py b/mindspore/graph_utils/python_pass/python_pass_register.py index 8e37c44c7..37427f177 100644 --- a/mindspore/graph_utils/python_pass/python_pass_register.py +++ b/mindspore/graph_utils/python_pass/python_pass_register.py @@ -23,22 +23,26 @@ __all__ = [ "unregiste_pass", "gen_new_parameter", "cancel_new_parameter", - "set_renorm", - "set_reopt" + "_set_renorm", + "_set_reopt" ] class PyPassManager(PyPassManager_): r""" Used to registe and unregiste python passes which can be used to alter graphs. Args: + requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True run_only_once (bool): Specify whether or not to run pass only once. Default: False. Raises: TypeError: If argument has invalid type. """ - def __init__(self, run_only_once=False): + def __init__(self, requires_grad=True, run_only_once=False): + if not isinstance(requires_grad, bool): + raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}") if not isinstance(run_only_once, bool): raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}") + self.requires_grad = requires_grad self.run_only_once_ = run_only_once PyPassManager_.__init__(self) @@ -51,7 +55,7 @@ class PyPassManager(PyPassManager_): raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") if not isinstance(target, Pattern): raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") - super().registe(pass_name, pattern, target, self.run_only_once_) + super().registe(pass_name, pattern, target, self.requires_grad, self.run_only_once_) def unregiste(self, py_pass): if isinstance(py_pass, str): @@ -81,11 +85,12 @@ class PyPassManager(PyPassManager_): raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") super().set_reopt(do_reopt) -def registe_pass(run_only_once=False): +def registe_pass(requires_grad=True, run_only_once=False): """ Registe python pass to specified pipeline phase which would be used in compilation. Args: + requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False. Returns: @@ -99,7 +104,7 @@ def registe_pass(run_only_once=False): >>> target = IsPrimTypeOf("ReLU6") >>> return pattern, target """ - return PyPassManager(run_only_once) + return PyPassManager(requires_grad, run_only_once) def unregiste_pass(py_pass): """ @@ -157,7 +162,7 @@ def cancel_new_parameter(pattern): ppm = PyPassManager() ppm.unregiste(pattern.para_name) -def set_renorm(should_renorm): +def _set_renorm(should_renorm): """ Set whether or not to do renormalization after modified graph in python pass(es). @@ -171,7 +176,7 @@ def set_renorm(should_renorm): ppm = PyPassManager() ppm.set_renorm(should_renorm) -def set_reopt(do_reopt): +def _set_reopt(do_reopt): """ Set whether or not to do optimization after modified graph in python pass(es). diff --git a/tests/ut/python/optimizer/test_python_pass.py b/tests/ut/python/optimizer/test_python_pass.py index 9038229fc..379f83f5f 100644 --- a/tests/ut/python/optimizer/test_python_pass.py +++ b/tests/ut/python/optimizer/test_python_pass.py @@ -19,8 +19,8 @@ import mindspore.nn as nn from mindspore import context from mindspore.common.tensor import Tensor from mindspore.ops import operations as P -from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\ - cancel_new_parameter, set_reopt +from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\ + cancel_new_parameter, _set_reopt from mindspore.common.api import _generate_pip_args from mindspore._c_expression import generate_key, Executor_ from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm @@ -157,8 +157,8 @@ def test_isnot_pattern_0(): Test IsNot pattern which expresses the IsNot semantics. Case: IsNot pass failed to match """ - set_renorm(False) - set_reopt(False) + _set_renorm(False) + _set_reopt(False) class ConvBN(nn.Cell): def __init__(self): super(ConvBN, self).__init__() @@ -176,7 +176,7 @@ def test_isnot_pattern_0(): inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32) conv_bn_model = ConvBN() - @registe_pass(run_only_once=True) + @registe_pass(requires_grad=False, run_only_once=True) def single_bn_pass(): """ Sub a BN which does NOT take Conv as inputs to ReLU6. @@ -188,7 +188,7 @@ def test_isnot_pattern_0(): target = Call(P.ReLU6(), [pattern_0]) return pattern, target - @registe_pass(run_only_once=True) + @registe_pass(requires_grad=False, run_only_once=True) def bn_pass(): """ Sub a BN to Softmax. @@ -202,7 +202,7 @@ def test_isnot_pattern_0(): unregiste_pass(bn_pass) assert "ReLU6" not in transformed_repr assert "Softmax" in transformed_repr - set_renorm(True) + _set_renorm(True) def test_isnot_pattern_1(): """ @@ -234,12 +234,12 @@ def test_newtensor_pattern(): """ Test NewTensor pattern in the target """ - set_renorm(False) - set_reopt(False) + _set_renorm(False) + _set_reopt(False) inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() - @registe_pass(run_only_once=True) + @registe_pass(requires_grad=False, run_only_once=True) def softmax_addn_pass(): x = Any() pattern = Call(P.Softmax(), [x]) @@ -252,7 +252,7 @@ def test_newtensor_pattern(): unregiste_pass(softmax_addn_pass) assert "AddN" in transformed_repr assert "Softmax" not in transformed_repr - set_renorm(True) + _set_renorm(True) def test_newparameter_pattern(): """ @@ -261,9 +261,9 @@ def test_newparameter_pattern(): inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() - set_renorm(False) - set_reopt(False) - @registe_pass(run_only_once=True) + _set_renorm(False) + _set_reopt(False) + @registe_pass(requires_grad=False, run_only_once=True) def softmax_addn_pass(): x = Any() pattern = Call(P.Softmax(), [x]) @@ -288,9 +288,9 @@ def test_imm_target(): inputs = Tensor(np.ones([42]), mindspore.float16) softmax_model = nn.Softmax() - set_renorm(False) - set_reopt(False) - @registe_pass(run_only_once=True) + _set_renorm(False) + _set_reopt(False) + @registe_pass(requires_grad=False, run_only_once=True) def softmax_pass(): x = Any() pattern = Call(P.Softmax(), [x]) @@ -313,10 +313,10 @@ def test_gen_new_parameter(): default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) new_para = NewParameter("Merlin", default_tensor) - set_renorm(False) - set_reopt(False) + _set_renorm(False) + _set_reopt(False) gen_new_parameter(new_para) - @registe_pass(run_only_once=True) + @registe_pass(requires_grad=False, run_only_once=True) def softmax_make_tuple_pass(): x = Any() softmax = P.Softmax() -- GitLab