From 2a6d346d2f68a9b5b04143ed5f7d22dabc3e189e Mon Sep 17 00:00:00 2001 From: huangdongrun Date: Thu, 9 Jul 2020 08:29:07 +0800 Subject: [PATCH] support if by if grad parameter add join for ref adjust env eliminate to eliminate all env ops add partial app cache resolve while endless fix env eliminate support for "for while" cases fix join shape error --- mindspore/ccsrc/debug/anf_ir_utils.cc | 3 +- .../ccsrc/frontend/optimizer/ad/dfunctor.cc | 7 +- mindspore/ccsrc/frontend/optimizer/irpass.cc | 7 +- mindspore/ccsrc/frontend/optimizer/irpass.h | 12 + .../optimizer/irpass/env_item_eliminate.h | 64 +- .../ccsrc/frontend/optimizer/irpass/inline.h | 54 +- .../frontend/optimizer/irpass/ref_eliminate.h | 36 +- mindspore/ccsrc/pipeline/jit/action.cc | 9 +- mindspore/ccsrc/pipeline/jit/pass.cc | 60 +- .../pipeline/jit/remove_value_node_dup.cc | 104 +++ .../pipeline/jit/remove_value_node_dup.h | 4 + .../pipeline/jit/static_analysis/evaluator.cc | 44 +- .../jit/static_analysis/static_analysis.cc | 35 +- .../jit/static_analysis/static_analysis.h | 15 +- mindspore/ccsrc/pipeline/jit/validator.cc | 6 +- mindspore/ccsrc/utils/convert_utils.cc | 9 +- mindspore/ccsrc/vm/segment_runner.cc | 11 +- mindspore/core/abstract/abstract_value.cc | 4 + mindspore/core/ir/graph_utils.cc | 2 +- mindspore/core/ir/scalar.h | 31 +- mindspore/core/utils/trace_info.h | 10 + tests/st/control/test_cont_grad.py | 816 ++++++++++++++++++ tests/ut/cpp/CMakeLists.txt | 1 + tests/ut/python/ops/test_ops.py | 1 - tests/ut/python/runtest.sh | 1 - tests/vm_impl/array_ops_vm_impl.py | 13 +- tests/vm_impl/math_ops_vm_impl.py | 12 + 27 files changed, 1259 insertions(+), 112 deletions(-) create mode 100644 tests/st/control/test_cont_grad.py diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 4d6edd18c..4f8493ca7 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -330,7 +330,8 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_gra } oss << "SymInst(%para" << idx << ")"; } else { - MS_LOG(EXCEPTION) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString(); + MS_LOG(WARNING) << "SymbolicKeyInstance does not embed a parameter: " << sym_node->ToString(); + oss << "SymInst(cnode_" << sym_node->ToString() << ")"; } return oss.str(); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index d4fe20171..452f1800f 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -189,6 +189,11 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { if (!morph->isa()) { return nullptr; } + // for free variable, which may be handled in MapValueObject, just return it + auto node_adjoint_found = anfnode_to_adjoin_.find(morph); + if (node_adjoint_found != anfnode_to_adjoin_.end()) { + return node_adjoint_found->second; + } ScopeGuard scope_guard(morph->scope()); auto cnode_morph = morph->cast(); @@ -502,7 +507,7 @@ void DFunctor::MapFvObject() { if (parent_adjoint != nullptr) { adjoint = std::make_shared(node, parent_adjoint->k(), tape_); } else { - if (is_top_ || node->isa() || !IsInScope(node)) { + if (is_top_ || node->isa()) { // Out of ad scope, add adjoint for free variables. adjoint = std::make_shared(node, node, tape_); UpdateAdjoint(adjoint); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index b41c3081b..dfef764b8 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -87,10 +87,12 @@ OptimizeIRPassLib::OptimizeIRPassLib() { env_get_item_eliminate_ = MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_ = - MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_bypass_recursive_ = + MakeSubstitution(std::make_shared(true), "incorporate_env_get_item", prim::kPrimEnvGetItem); incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); + incorporate_env_getitem_ = + MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); // Ref eliminate make_ref_eliminate_ = @@ -122,6 +124,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // inline inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); + inline_without_move_ = MakeSubstitution(std::make_shared(false), "inline", IsCNodeGraph); replace_applicator_ = MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); specialize_transform_ = diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 5a0f2ed5b..9a9a1e7a7 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -55,6 +55,7 @@ class OptimizeIRPassLib { SubstitutionPtr env_get_item_eliminate_; SubstitutionPtr new_env_get_item_; SubstitutionPtr incorporate_env_getitem_; + SubstitutionPtr incorporate_env_getitem_bypass_recursive_; SubstitutionPtr incorporate_env_getitem_switch_; // Ref eliminate @@ -80,6 +81,7 @@ class OptimizeIRPassLib { // inline SubstitutionPtr inline_; + SubstitutionPtr inline_without_move_; SubstitutionPtr replace_applicator_; SubstitutionPtr specialize_transform_; @@ -193,6 +195,16 @@ inline bool IsCNodeDup(const AnfNodePtr &node) { auto inp0 = node->cast()->input(0); return (inp0 != nullptr) && inp0->isa(); } + +// check if the cnode is a switch cnode +inline bool IsCNodeSwitch(const AnfNodePtr &node) { + if (node != nullptr) { + if (node->isa()) { + return IsPrimitiveCNode(node, prim::kPrimSwitch); + } + } + return false; +} } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h index 1fee007a8..6fa1304dd 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h @@ -29,6 +29,7 @@ #include "frontend/optimizer/anf_visitor.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/inline.h" #include "frontend/optimizer/optimizer.h" #include "utils/symbolic.h" @@ -59,8 +60,13 @@ class EnvGetitemTransform { while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { // {prim::kPrimEnvSetItem, env, symbolickey, value} auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4 || !IsValueNode(inputs[2])) { - MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance."; + if (inputs.size() != 4) { + MS_LOG(WARNING) << "Input size should be 4"; + return nullptr; + } + if (!IsValueNode(inputs[2])) { + MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?"; + return nullptr; } env = inputs[1]; @@ -91,33 +97,12 @@ class EnvGetitemTransform { class NewEnvGetItem : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - auto gety = [this](const AnfNodePtr &node) -> bool { - this->y_ = node; - return true; - }; - - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode, IsVNode, gety})(node); - if (env_ != nullptr && env_->Len() == 0) { - return y_; - } + PatternNode c1, c2, y; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimEnvGetItem, c1, c2, y), y, + (IsValueNode(c1.GetNode(node)) && IsVNode(c2.GetNode(node)) && + (GetValueNode(c1.GetNode(node)))->Len() == 0)); return nullptr; } - - void Visit(const ValueNodePtr &vnode) override { - if (env_ == nullptr) { - env_ = GetValueNode(vnode); - } - } - - void Reset() { - y_ = nullptr; - env_ = nullptr; - } - - private: - AnfNodePtr y_{nullptr}; - EnvInstancePtr env_{nullptr}; }; // {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> @@ -205,8 +190,13 @@ class EnvGetSetItem : public AnfVisitor { while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { // {prim::kPrimEnvSetItem, env, symbolickey, value} auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4 || !IsValueNode(inputs[2])) { - MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance."; + if (inputs.size() != 4) { + MS_LOG(WARNING) << "Input size should be 4"; + return nullptr; + } + if (!IsValueNode(inputs[2])) { + MS_LOG(DEBUG) << "Input 2 is not a SymbolicKeyInstance?"; + return nullptr; } env = inputs[1]; @@ -257,7 +247,8 @@ class EnvGetItemEliminater : public OptimizerCaller { // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} class IncorporateEnvGetitem : public AnfVisitor { public: - IncorporateEnvGetitem() : env_get_item_transform_() {} + explicit IncorporateEnvGetitem(bool bypass_recursive = false) + : env_get_item_transform_(), bypass_recursive_(bypass_recursive) {} ~IncorporateEnvGetitem() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { @@ -285,7 +276,13 @@ class IncorporateEnvGetitem : public AnfVisitor { auto inputs = inp1->inputs(); auto fg = GetValueNode(inputs[0]); auto new_fg = env_get_item_transform_(fg, key, default_v); - + if (fg->recursive() && bypass_recursive_) { + MS_LOG(DEBUG) << "Bypass env_get_item transform for recursive fg=" << fg->ToString(); + return nullptr; + } + if (new_fg == nullptr) { + return nullptr; + } std::vector args; args.push_back(NewValueNode(new_fg)); (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); @@ -298,6 +295,7 @@ class IncorporateEnvGetitem : public AnfVisitor { private: bool is_match_{false}; internal::EnvGetitemTransform env_get_item_transform_; + bool bypass_recursive_; }; // {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} @@ -342,7 +340,9 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor { auto g2 = GetValueNode(sw->input(3)); auto new_g1 = env_get_item_transform_(g1, key, default_v); auto new_g2 = env_get_item_transform_(g2, key, default_v); - + if (new_g1 == nullptr || new_g2 == nullptr) { + return nullptr; + } auto fg = node->func_graph(); auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)}); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h index 0be228f44..ebe4cbe5e 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h @@ -93,10 +93,22 @@ bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } +bool IsDirectParentCall(FuncGraphPtr fg, AnfNodePtr node) { + bool unique_use = IsUniqueUse(fg, nullptr); + bool is_recursive = fg->recursive(); + if (fg->parent() != nullptr && is_recursive) { + if (fg->parent() == node->func_graph() && unique_use) { + return true; + } + } + return false; +} + // {G, Xs} class InlinerBase : public AnfVisitor { public: - explicit InlinerBase(std::vector> criterions) : criterions_(criterions) {} + explicit InlinerBase(std::vector> criterions, bool use_move = true) + : use_move_(use_move), criterions_(criterions) {} ~InlinerBase() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!node->isa()) { @@ -113,6 +125,7 @@ class InlinerBase : public AnfVisitor { if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { return nullptr; } + // Do not inline GraphKernel to Cell. if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { // If the GraphKernel only contains a return node, we make it inlined. @@ -142,8 +155,12 @@ class InlinerBase : public AnfVisitor { std::vector params; (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params)); - - if (IsUniqueUse(fg, nullptr)) { + // compare size to avoid the case that the function has default value after grad. + // for which after renormalize, the function default value will be an input + if (fg->parameters().size() != params.size()) { + return nullptr; + } + if (use_move_ && IsUniqueUse(fg, nullptr)) { auto mng = fg->manager(); MS_EXCEPTION_IF_NULL(mng); ReplaceParams(mng, params, fg); @@ -183,21 +200,36 @@ class InlinerBase : public AnfVisitor { private: bool is_checked_{false}, is_recursive_{false}; + bool use_move_; std::vector> criterions_; }; class Inliner : public InlinerBase { public: - Inliner() - : InlinerBase({ - {IsUniqueUse, true}, - {IsTrivial, false}, - {IsInside, false}, - {IsCore, false}, - {NoCriterion, true}, - }) {} + explicit Inliner(bool use_move = true) + : InlinerBase( + { + {IsUniqueUse, true}, + {IsTrivial, false}, + {IsInside, false}, + {IsCore, false}, + {IsDirectParentCall, false}, + {NoCriterion, true}, + }, + use_move) {} ~Inliner() override = default; }; + +class DirectInliner : public InlinerBase { + public: + explicit DirectInliner(bool use_move = true) + : InlinerBase( + { + {IsDirectParentCall, false}, + }, + use_move) {} + ~DirectInliner() override = default; +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h index fc859b213..ab72c1947 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h @@ -26,6 +26,30 @@ namespace mindspore { namespace opt { namespace irpass { +namespace internal { +class GetRefValueTransform { + public: + GetRefValueTransform() {} + ~GetRefValueTransform() = default; + + AnfNodePtr operator()(const AnfNodePtr &node) { + CNodePtr cnode = node->cast(); + auto inputs = cnode->inputs(); + auto fg = GetValueNode(inputs[0])->cast(); + if (fg->recursive()) { + MS_LOG(DEBUG) << "Get refvalue by pass recursive:" << fg->ToString(); + return node; + } + auto new_fg = TransformableClone(fg, std::make_shared("GetRefValue")); + auto output = new_fg->output(); + new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimGetRefValue), output})); + inputs[0] = NewValueNode(new_fg); + auto ret_node = cnode->func_graph()->NewCNode(inputs); + return ret_node; + } +}; +} // namespace internal + // {prim::kPrimMakeRef, X, Y, Z} -> Y class MakeRefEliminater : public OptimizerCaller { public: @@ -48,13 +72,23 @@ class GetRefParamEliminater : public OptimizerCaller { // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y +// {prim::kPrimGetRefValue, {prim::switch, cond, t, f}} -> {prim::switch, cond, t, f} class GetMakeRefEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode x, y, z; MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); - + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsCNodeSwitch, node)); + internal::GetRefValueTransform trans; + auto GetRefLambda = [&trans, &x, &node]() -> AnfNodePtr { + auto rep = trans(x.GetNode(node)); + if (rep != nullptr) { + return rep; + } + return nullptr; + }; + MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetRefValue, x), GetRefLambda, x.CheckFunc(IsCNodeGraph, node)); return nullptr; } }; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index c5b38fe82..c556b3399 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -314,6 +314,7 @@ bool TaskEmitAction(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); auto bc_ptr = res->results()[kBackend].cast(); auto context_ptr = MsContext::GetInstance(); + std::string backend = MsContext::GetInstance()->backend_policy(); MS_EXCEPTION_IF_NULL(context_ptr); if (CompileGraphs::ContainMixedTarget(func_graph)) { bc_ptr->set_is_multi_graph_sink(false); @@ -321,13 +322,13 @@ bool TaskEmitAction(const ResourcePtr &res) { context_ptr->set_loop_sink_flag(false); } else if (context_ptr->execution_mode() != kPynativeMode) { std::string device_target = context_ptr->device_target(); - if (device_target == kAscendDevice) { + if (device_target == kAscendDevice && backend != kMsVm) { bc_ptr->set_is_multi_graph_sink(true); context_ptr->set_is_multi_graph_sink(true); } } - if (IsCtrlSink()) { + if (IsCtrlSink() && backend == kMsConvert) { res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); return true; } @@ -344,8 +345,8 @@ bool ExecuteAction(const ResourcePtr &res) { if (res->results().count(kOutput) == 0) { MS_LOG(EXCEPTION) << "Execute args error"; } - - if (IsCtrlSink()) { + std::string backend = MsContext::GetInstance()->backend_policy(); + if (IsCtrlSink() && backend == kMsConvert) { if (!res->results()[kOutput].is()) { MS_LOG(EXCEPTION) << "Execute args error"; } diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 0c27ba7c4..c047e1133 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -30,6 +30,7 @@ #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/resource.h" #include "pipeline/jit/validator.h" +#include "pipeline/jit/remove_value_node_dup.h" #include "frontend/optimizer/optimizer.h" #include "frontend/optimizer/cse.h" #include "frontend/optimizer/graph_kernel_reuse.h" @@ -127,11 +128,14 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.incorporate_getitem_set_, irpass.incorporate_call_, irpass.incorporate_call_switch_, - irpass.incorporate_env_getitem_, + irpass.incorporate_env_getitem_bypass_recursive_, irpass.incorporate_env_getitem_switch_, irpass.new_env_get_item_, irpass.depend_value_elim_, }); + opt::OptPassConfig a_after_grad = opt::OptPassConfig({ + irpass.inline_without_move_, + }); opt::OptPassConfig a_3 = opt::OptPassConfig({ irpass.arithmetic_simplify2_, irpass.same_eliminate_, @@ -154,6 +158,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { {"virtual_dataset", virtual_dataset}, {"grad", grad}, {"resolve", resolve_pass}, + {"a_after_grad", a_after_grad}, {"renormalize", opt::OptPassConfig::Renormalize()}, {"cse", opt::OptPassConfig(opt::CSE(false))}, {"a_3", a_3}}); @@ -161,11 +166,24 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { return map_a; } +OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig c_1 = opt::OptPassConfig({ + // Safe inlining + irpass.inline_, + irpass.partial_eliminate_, + }); + + OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); + + return map_a; +} + OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig b_1 = - opt::OptPassConfig({irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, - irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, - irpass.get_make_ref_eliminate_, irpass.value_based_eliminate_}); + opt::OptPassConfig b_1 = opt::OptPassConfig( + {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_, + irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, + irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, + irpass.value_based_eliminate_}); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, irpass.make_ref_eliminate_, @@ -244,6 +262,8 @@ void InitOpt(const ResourcePtr &res) { opt::irpass::OptimizeIRPassLib irpass; g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); + g_pass_opts["opt_after_cconv"] = + Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true); g_pass_opts["opt_graph_kernel_a"] = Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); g_pass_opts["opt_graph_kernel_b"] = @@ -288,6 +308,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) { bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } +bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); } bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } @@ -311,6 +332,33 @@ bool AddControlDependPass(const ResourcePtr &res) { return true; } +bool MergeDupGraphPass(const ResourcePtr &res) { + FuncGraphPtr func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(res->manager()); + if (res->manager()->func_graphs().size() <= 1) { + return true; + } + return MergeDuplicateGraphs(res->manager()); +} + +bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) { + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "Remove value node duplications error."; + } + auto manager = res->manager(); + HashCache hash_cache; + HashValue hashes; + // Remove duplicated value nodes across all graphs in manager + for (auto &fg : manager->func_graphs()) { + auto value_nodes = fg->value_nodes(); + for (const auto &value_pair : value_nodes) { + TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes); + } + } + return true; +} + bool CconvPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); @@ -340,6 +388,8 @@ std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStru {"clean_after_opta", CleanAfterOptAPass}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}, + {"opt_after_cconv", OptPassAfterCconvGroup}, + {"remove_dup_value", RemoveValueNodeDuplicationsPass}, {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, {"add_control_depend", AddControlDependPass}}; diff --git a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc index e9467e4ae..2d390c46a 100644 --- a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc +++ b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include "pipeline/jit/remove_value_node_dup.h" #include "ir/anf.h" @@ -70,5 +71,108 @@ void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, Has // Meet for the first time, append node to bucket. bucket.emplace_back(node); } + +size_t HashOfGraph(const FuncGraphPtr &fg) { + std::vector toposet = TopoSort(fg->get_return()); + MS_LOG(DEBUG) << "TopSort for:" << fg->ToString(); + std::unordered_map hashes; + auto ¶ms = fg->parameters(); + for (size_t i = 0; i < params.size(); i++) { + hashes[params[i]] = std::hash{}("param" + std::to_string(i)); + } + for (auto node : toposet) { + MS_EXCEPTION_IF_NULL(node); + if (hashes.find(node) != hashes.end()) { + continue; + } + + std::size_t h = 0; + if (node->isa()) { + ValueNodePtr value_node = node->cast(); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + if (IsValueNode(value_node)) { + auto v_fg = value->cast(); + h = value->hash(); + } else if (IsValueNode(value_node)) { + // the tensor has same value has been replaced in duplicate value pass, + // so we use the value pointer here as an identifier + h = hash_combine(value->hash(), std::hash{}(value.get())); + } else { + h = hash_combine(value->hash(), (opt::AbsOf(value_node)->hash())); + } + } else if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + size_t init = 0; + h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) { + return hash_combine(hash, hashes[node_in]); + }); + } else if (node->isa()) { + h = node->hash(); + } else { + MS_LOG(ERROR) << "Unknow node type"; + } + hashes[node] = h; + } + return hashes[fg->get_return()]; +} + +bool IsCNodeGraph(const AnfNodePtr &node) { + if (node == nullptr || !node->isa()) { + return false; + } + + auto inp0 = node->cast()->input(0); + return IsValueNode(inp0); +} + +bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager) { + std::unordered_map> hash_graphs; + std::unordered_map graph_hash; + for (auto fg : manager->func_graphs()) { + size_t h = HashOfGraph(fg); + graph_hash[fg] = h; + if (hash_graphs.find(h) == hash_graphs.end()) { + hash_graphs[h] = {fg}; + } else { + hash_graphs[h].push_back(fg); + } + } + FuncGraphPairMapEquiv equiv_graph; + NodeMapEquiv equiv_node; + for (auto &fg : manager->func_graphs()) { + MS_LOG(DEBUG) << "Try Merge Graph:" << fg->ToString(); + for (auto &item : fg->nodes()) { + if (!item->isa()) { + continue; + } + auto &inputs = item->cast()->inputs(); + for (size_t i = 0; i < inputs.size(); i++) { + if (!inputs[i]->isa()) { + continue; + } + auto value_ptr = GetValueNode(inputs[i]); + auto v_fg = value_ptr->cast(); + if (v_fg == nullptr) { + continue; + } + auto &fg_vec = hash_graphs[graph_hash[v_fg]]; + if (fg_vec.size() > 1) { + if (v_fg != fg_vec[0]) { + bool is_morphic = Isomorphic(v_fg, fg_vec[0], &equiv_graph, &equiv_node); + if (is_morphic) { + auto new_node = NewValueNode(fg_vec[0]); + MS_LOG(DEBUG) << "Replace graph node :" << inputs[i]->ToString() << " with:" << new_node->ToString(); + manager->Replace(inputs[i], new_node); + } + } + } + } + } + } + return true; +} + } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h index fd52924d5..39fcd4472 100644 --- a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h +++ b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h @@ -28,6 +28,10 @@ using HashCache = std::unordered_map>; using HashValue = std::unordered_map; void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value); +size_t HashOfGraph(const FuncGraphPtr &fg); +bool IsCNodeGraph(const AnfNodePtr &node); +bool MergeDuplicateGraphs(const FuncGraphManagerPtr manager); + } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index 424a057bc..f6ffda863 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -113,17 +113,18 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr } const AnfNodePtr &func_node = fg->get_return(); - MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString() + MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString() << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); AbstractBasePtr ret_base = nullptr; std::vector nodes = FastShadowSort(func_node); for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { const auto &node = *it; AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); - MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); + MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString() + << ", node_conf: " << node_conf->ToString(); ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); - MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() - << ", abstract: " << ret_base->ToString(); + MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString() + << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); } MS_EXCEPTION_IF_NULL(ret_base); @@ -142,16 +143,17 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list), [](const AbstractBasePtr &arg) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(arg); - return arg->Broaden(); + if (arg->GetValueTrack() != kAnyValue) { + return arg->Broaden(); + } + return arg; }); - if (func_graph_->joined_shapes_.size() != broaded_list.size()) { - MS_EXCEPTION(ValueError) << "Number of input arguments " << broaded_list.size() - << " does not equal to number of original buffer arguments " - << func_graph_->joined_shapes_.size(); - } - for (size_t i = 0; i < broaded_list.size(); ++i) { - broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]); + if (func_graph_->joined_shapes_.size() == broaded_list.size()) { + for (size_t i = 0; i < broaded_list.size(); ++i) { + broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]); + } } + MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) << ", broaded: " << mindspore::ToString(broaded_list); return broaded_list; @@ -181,8 +183,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->joined_shapes_.clear(); std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), - std::back_inserter(func_graph_->joined_shapes_), - [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); + std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { + if (arg_spec->isa()) { + return arg_spec->cast()->ref()->GetShapeTrack(); + } + return arg_spec->GetShapeTrack(); + }); + joined_args_spec_list = NormalizeArgs(joined_args_spec_list); MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; } return joined_args_spec_list; @@ -199,8 +206,13 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->joined_shapes_.clear(); std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), - std::back_inserter(func_graph_->joined_shapes_), - [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); + std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { + if (arg_spec->isa()) { + return arg_spec->cast()->ref()->GetShapeTrack(); + } + return arg_spec->GetShapeTrack(); + }); + joined_args_spec_list = NormalizeArgs(joined_args_spec_list); MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; } MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 338743b1d..e5f9cdb6b 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -188,6 +188,7 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { trace::TraceEvalCNodeLeave(); } else { MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() + << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph") << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); } @@ -301,6 +302,8 @@ void AnalysisEngine::Clear() { anfnode_config_map_.clear(); eval_trace_.clear(); constructors_.clear(); + constructors_app_.clear(); + continued_evals_.clear(); } namespace { @@ -426,8 +429,14 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptrfn(); EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); + auto part_pair = std::make_pair(func_orig, func->args()); + auto itr = constructors_app_.find(part_pair); + if (itr != constructors_app_.end()) { + return itr->second; + } std::shared_ptr partial_evaluator = std::make_shared(evaluator_orig, func->args()); + constructors_app_[part_pair] = partial_evaluator; return partial_evaluator; } @@ -504,9 +513,10 @@ void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { if (fg_eval == nullptr) { return; } + auto fg = fg_eval->func_graph(); MS_EXCEPTION_IF_NULL(fg); - auto undetermined_fgs = fg->recursive_graphs(); + auto undetermined_fgs = fg->recursive(); if (undetermined_fgs) { auto fg_parent = fg->parent(); MS_EXCEPTION_IF_NULL(fg_parent); @@ -546,15 +556,19 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vectorToString() << " check undetermined."; - if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; + MS_LOG(DEBUG) << u_eval.first->ToString() << "check undetermined."; + auto &alternate_evaluator = multi_poss_[u_eval.first]; + auto &eval_cache = alternate_evaluator->cache(); + if ((!undetermined_evals.count(std::make_pair(alternate_evaluator, args_spec_list))) && + (((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) || + (eval_cache->find(args_spec_list) == eval_cache->end()))) { + MS_LOG(DEBUG) << u_eval.first->ToString() << "has undetermined."; has_undetermined = true; break; } } if (has_undetermined == false) { - MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; + MS_LOG(DEBUG) << eval->ToString() << "has no undetermined."; *continue_flag = true; return latest_entry; } @@ -597,34 +611,33 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorToString(); - // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); if (it == eval_trace_.rend()) { eval_trace_.push_back(current_inf); - MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); MS_EXCEPTION_IF_NULL(eval); auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf); MS_EXCEPTION_IF_NULL(eval_result->abstract()); - MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); out_specs.push_back(eval_result->abstract()); eval_trace_.pop_back(); if (eval_trace_.empty()) { multi_poss_.clear(); } - } else if (it != eval_trace_.rbegin()) { + } else { bool continue_flag = false; auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag); if (continue_flag) { + MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.first.get() << current_inf.first->ToString(); + continued_evals_.insert(current_inf); continue; } // Try to travel the latest undetermined. if (latest_entry != eval_trace_.rbegin()->first) { - MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); + MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString(); auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); MS_EXCEPTION_IF_NULL(eval_result->abstract()); - MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() + MS_LOG(DEBUG) << "end Direct Evaluator " << latest_entry->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); return eval_result; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 0ebd9a0af..701893289 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -26,6 +26,7 @@ #include #include #include +#include #ifdef DEBUG #include @@ -113,7 +114,8 @@ class AnfNodeConfig : public Config { std::string ToString() const override { std::ostringstream buffer; - buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString(); + buffer << "Node: " << node_->DebugString() << "-uid(" << node_->UniqueId() + << "), Context: " << context_->ToString(); return buffer.str(); } @@ -173,7 +175,13 @@ struct AnalysisResult { }; using EvalTraceRevIter = std::list>::reverse_iterator; - +struct PartialAppHasher { + std::size_t operator()(const std::pair &p) const { + auto h1 = std::hash{}(p.first); + auto h2 = AbstractBasePtrListHash(p.second); + return h1 ^ h2; + } +}; class AnalysisEngine : public std::enable_shared_from_this { public: AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) @@ -233,10 +241,13 @@ class AnalysisEngine : public std::enable_shared_from_this { const PrimEvaluatorMap &prim_constructors_; FuncGraphManagerPtr func_graph_manager_; std::unordered_map constructors_; + std::unordered_map, EvaluatorPtr, PartialAppHasher> + constructors_app_; AnfNodeConfigMap anfnode_config_map_; // Use a list to trace multiple evaluators. std::list> eval_trace_; std::map multi_poss_; + std::set> continued_evals_; AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, const ConfigPtrList &args_conf_list); diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 9655f7a65..95a54eebb 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -34,6 +34,7 @@ using mindspore::abstract::AbstractError; using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractJTagged; using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractRef; using mindspore::abstract::AbstractRowTensor; using mindspore::abstract::AbstractScalar; using mindspore::abstract::AbstractSparseTensor; @@ -83,7 +84,8 @@ void ValidateAbstract(const AnfNodePtr &node) { // only send string in external if (!IsValueNode(node)) { // Validate a type. - MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() + << " for node=" << node->DebugString(); } } return; @@ -96,7 +98,7 @@ void ValidateAbstract(const AnfNodePtr &node) { if (ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa() || ptrBase->isa()) { + ptrBase->isa() || ptrBase->isa() || ptrBase->isa()) { return; } diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index 70590da75..0d3ab3b65 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -481,8 +481,10 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple } // Isomorphism -static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, - NodeMapEquiv *const equiv_node) { +static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node); +bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { if (equiv_node == nullptr) { MS_LOG(ERROR) << "Invalid equiv_node"; return false; @@ -514,6 +516,9 @@ static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, Fu MS_LOG(DEBUG) << "two parameters are not equal."; return false; } + if (node1->isa() && node2->isa()) { + return SameNode(node1, node2, equiv_func_graph, equiv_node); + } MS_LOG(ERROR) << "type error"; return false; } diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index 141adc1bf..33197876a 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -116,12 +116,15 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo } // namespace std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { - auto fg = std::make_shared(); - AnfNodePtrList inputs; - AnfNodePtrToAnfNodePtrMap eqv; if (lst.empty()) { MS_LOG(EXCEPTION) << "Input anf node list is empty"; } + TraceManager::DebugTrace( + std::make_shared(lst[0]->cast()->func_graph()->debug_info())); + auto fg = std::make_shared(); + TraceManager::EndTrace(); + AnfNodePtrList inputs; + AnfNodePtrToAnfNodePtrMap eqv; // Merge CNodes into a AnfGraph that represents a linear instruction segment for (auto n : lst) { if (!n->isa()) { @@ -154,7 +157,9 @@ std::tuple TransformSegmentToAnfGr (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); } + TraceManager::DebugTrace(std::make_shared(n->debug_info())); eqv[n] = fg->NewCNode(args); + TraceManager::EndTrace(); eqv[n]->set_abstract(n->abstract()); eqv[n]->set_kernel_info(n->kernel_info_ptr()); } diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index dab262bc8..0fb6759d9 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -452,6 +452,10 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { } auto other_tensor = dyn_cast(other); if (other_tensor == nullptr) { + auto ref_tensor = dyn_cast(other); + if (ref_tensor != nullptr) { + return this->Join(ref_tensor->ref()); + } MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString(); } if (*this == *other) { diff --git a/mindspore/core/ir/graph_utils.cc b/mindspore/core/ir/graph_utils.cc index cde5eaafb..ccdf8ee1d 100644 --- a/mindspore/core/ir/graph_utils.cc +++ b/mindspore/core/ir/graph_utils.cc @@ -48,7 +48,7 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c continue; } if (rank.find(node) != rank.end() && rank[node] != todo.size()) { - MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(); + MS_LOG(EXCEPTION) << "Graph exists cycle, node " << node->DebugString(2); } rank[node] = todo.size(); bool cont = false; diff --git a/mindspore/core/ir/scalar.h b/mindspore/core/ir/scalar.h index b814a4781..62c5f35ba 100644 --- a/mindspore/core/ir/scalar.h +++ b/mindspore/core/ir/scalar.h @@ -30,6 +30,7 @@ #include "base/base.h" #include "ir/dtype.h" #include "ir/dtype/number.h" +#include "utils/hashing.h" using std::fabs; @@ -51,7 +52,7 @@ using ScalarPtr = std::shared_ptr; class BoolImm : public Scalar { public: - explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash{}(v_); } + explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~BoolImm() override = default; MS_DECLARE_PARENT(BoolImm, Scalar) std::size_t hash() const override { return hash_; } @@ -91,7 +92,7 @@ class IntergerImm : public Scalar { class Int8Imm : public IntergerImm { public: Int8Imm() : IntergerImm(kInt8), v_(0) {} - explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash{}(v_); } + explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~Int8Imm() override = default; MS_DECLARE_PARENT(Int8Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -117,7 +118,7 @@ IMM_TRAITS(Int8ImmPtr, int8_t) class Int16Imm : public IntergerImm { public: Int16Imm() : IntergerImm(kInt16), v_(0) {} - explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash{}(v_); } + explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~Int16Imm() override = default; MS_DECLARE_PARENT(Int16Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -143,7 +144,7 @@ IMM_TRAITS(Int16ImmPtr, int16_t) class Int32Imm : public IntergerImm { public: Int32Imm() : IntergerImm(kInt32), v_(0) {} - explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash{}(v_); } + explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~Int32Imm() override = default; MS_DECLARE_PARENT(Int32Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -169,7 +170,7 @@ IMM_TRAITS(Int32ImmPtr, int32_t) class Int64Imm : public IntergerImm { public: Int64Imm() : IntergerImm(kInt64), v_(0) {} - explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash{}(v_); } + explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~Int64Imm() override = default; MS_DECLARE_PARENT(Int64Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -195,7 +196,9 @@ IMM_TRAITS(Int64ImmPtr, int64_t) class UInt8Imm : public IntergerImm { public: UInt8Imm() : IntergerImm(kUInt8), v_(0) {} - explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash{}(v_); } + explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { + hash_ = hash_combine({tid(), std::hash{}(v_)}); + } ~UInt8Imm() override = default; MS_DECLARE_PARENT(UInt8Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -221,7 +224,9 @@ IMM_TRAITS(UInt8ImmPtr, uint8_t); class UInt16Imm : public IntergerImm { public: UInt16Imm() : IntergerImm(kUInt16), v_(0) {} - explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash{}(v_); } + explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { + hash_ = hash_combine({tid(), std::hash{}(v_)}); + } ~UInt16Imm() override = default; MS_DECLARE_PARENT(UInt16Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -247,7 +252,9 @@ IMM_TRAITS(UInt16ImmPtr, uint16_t); class UInt32Imm : public IntergerImm { public: UInt32Imm() : IntergerImm(kUInt32), v_(0) {} - explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash{}(v_); } + explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { + hash_ = hash_combine({tid(), std::hash{}(v_)}); + } ~UInt32Imm() override = default; MS_DECLARE_PARENT(UInt32Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -273,7 +280,9 @@ IMM_TRAITS(UInt32ImmPtr, uint32_t); class UInt64Imm : public IntergerImm { public: UInt64Imm() : IntergerImm(kUInt64), v_(0) {} - explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash{}(v); } + explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { + hash_ = hash_combine({tid(), std::hash{}(v)}); + } ~UInt64Imm() override = default; MS_DECLARE_PARENT(UInt64Imm, IntergerImm) std::size_t hash() const override { return hash_; } @@ -308,7 +317,7 @@ using FloatImmPtr = std::shared_ptr; class FP32Imm : public FloatImm { public: FP32Imm() : FloatImm(kFloat32), v_(0.0) {} - explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash{}(v_); } + explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~FP32Imm() override = default; MS_DECLARE_PARENT(FP32Imm, FloatImm) std::size_t hash() const override { return hash_; } @@ -334,7 +343,7 @@ IMM_TRAITS(FP32ImmPtr, float) class FP64Imm : public FloatImm { public: FP64Imm() : FloatImm(kFloat64), v_(0.0) {} - explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash{}(v_); } + explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = hash_combine({tid(), std::hash{}(v_)}); } ~FP64Imm() override = default; MS_DECLARE_PARENT(FP64Imm, FloatImm) std::size_t hash() const override { return hash_; } diff --git a/mindspore/core/utils/trace_info.h b/mindspore/core/utils/trace_info.h index fea2cb3ea..5c9160d7c 100644 --- a/mindspore/core/utils/trace_info.h +++ b/mindspore/core/utils/trace_info.h @@ -412,6 +412,16 @@ class TraceCombileLikeGraphs : public TraceInfo { return std::make_shared(*shared_from_base()); } }; + +class TraceSegmentTransform : public TraceInfo { + public: + explicit TraceSegmentTransform(const DebugInfoPtr &info) : TraceInfo(info, "segment_transform", "") {} + MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); + ~TraceSegmentTransform() override = default; + TraceInfoPtr clone() override { + return std::make_shared(*shared_from_base()); + } +}; } // namespace mindspore #endif // MINDSPORE_CORE_UTILS_TRACE_INFO_H_ diff --git a/tests/st/control/test_cont_grad.py b/tests/st/control/test_cont_grad.py new file mode 100644 index 000000000..c3baae1fb --- /dev/null +++ b/tests/st/control/test_cont_grad.py @@ -0,0 +1,816 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test control ops """ +import numpy as np + +from mindspore import dtype as ms +from mindspore import Tensor +from mindspore import context +from mindspore import nn +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.ops import composite as C +from mindspore.ops import operations as P +# from tests.vm_impl.math_ops_vm_impl import * +# from tests.vm_impl.vm_interface import * +# from tests.vm_impl import * +# context.set_context(save_graphs=True) + + +def test_while_forward(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + + def construct(self, idx, end, x): + while idx < end: + part = x[idx, :, :] + max_num = self.max(part) + x[idx, :, 0:2] = max_num + idx = idx + 1 + return x + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + net = MyWhileNet() + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_grad(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + + def construct(self, idx, end, x): + while idx < end: + part = x[idx, :, :] + max_num = self.max(part) + x[idx, :, 0:2] = max_num + idx = idx + 1 + return x + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, *inputs): + return C.grad_all(self.net)(*inputs) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_forward(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + part = x[idx, :, :] + max_num = self.max(part) + x[idx, :, 0:2] = max_num + out = out + x + self.param + idx = idx + 1 + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + net = MyWhileNet() + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_endless_case(): + """endless case when optmization""" + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + part = x[idx, :, :] + out = out + part + idx = idx + 1 + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + net = MyWhileNet() + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_grad(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + part = x[idx, :, :] + max_num = self.max(part) + x[idx, :, 0:2] = max_num + out = out + x + self.param + idx = idx + 1 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_forward_with_const_branch(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + if 2 > 1: + out = out + self.param + else: + out = out + idx + self.param + idx = idx + 1 + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = while_net + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_opt_endless(): + """endless during optimization case""" + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + self.addn = P.AddN() + + def construct(self, idx, end, x): + addn1 = self.addn((x, x, x)) + out = addn1 + while idx < end: + out = self.addn((out, addn1)) + idx = idx + 1 + out = self.addn((out, x)) + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, *inputs): + return C.grad_all(self.net)(*inputs) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32) + net(idx, end, x) + + +def test_no_while_call(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + + def construct(self, idx, end, x): + out = self.zero + if 2 > 1: + out = out + self.param + else: + out = out + idx + self.param + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = while_net + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_grad_with_const_branch(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + if 2 > 1: + out = out + self.param + else: + out = out + idx + self.param + idx = idx + 1 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_for_while_with_param_grad_with_const_branch(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + self.start = Tensor(np.array(0), dtype=ms.int32) + + def construct(self, idx, end, x): + out = self.zero + for _ in range(0, 2): + idx = self.start + while idx < end: + if 2 > 1: + out = out + self.param + else: + out = out + idx + self.param + idx = idx + 1 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_for_while_with_param_grad_basic(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + self.start = Tensor(np.array(0), dtype=ms.int32) + + def construct(self, idx, end, x): + out = self.zero + for _ in range(0, 2): + idx = self.start + while idx < end: + out = out + self.param + idx = idx + 1 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_for_while_with_param_grad_normal(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.reduce = P.ReduceSum() + self.start = Tensor(np.array(0), dtype=ms.int32) + + def construct(self, idx, end, x): + out = x + for _ in range(0, 2): + idx = self.start + while idx < end: + out = out + self.param + idx = idx + 1 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_basic_grad(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out + self.param + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_basic_grad_mul(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out * self.param + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_basic_grad_two(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out + self.param + self.weight + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_basic_grad_three(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss") + self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out + self.param + self.weight + self.key + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_if_with_param_grad(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + self.t2 = Tensor(np.array(2), dtype=ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + if self.max(out) < self.max(x): + out = out + self.param * 2 + else: + out = out + self.param + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(3), dtype=ms.int32) + x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_while_with_param_grad_not_enter_while(): + class MyWhileNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + out = out + self.param * 3 + idx = idx + 1 + return out + self.param + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, a, b, c): + return C.grad_by_list(self.net, self.weights)(a, b, c) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + while_net = MyWhileNet() + net = GradNet(while_net) + idx = Tensor(np.array(3), dtype=ms.int32) + end = Tensor(np.array(0), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_with_param_if_by_if_forward(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, a, b, x): + out = self.zero + if a < b: + out = out + x + self.param + else: + out = out + x + if a == b: + out = out + x*3 + self.param + else: + out = out + x*2 + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(4), dtype=ms.int32) + x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_with_param_if_by_if_grad_inputs(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, a, b, x): + out = self.zero + if a < b: + out = out + x + self.param * 4 + if a == b: + out = out + x*3 + self.param * 3 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, *inputs): + return C.grad_all(self.net)(*inputs) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + if_net = MyIfByIfNet() + net = GradNet(if_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(0), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_with_param_if_by_if_grad_parameter(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, a, b, x): + out = self.zero + if a < b: + out = out + x + self.param * 2 + if a == b: + out = out + x*3 + self.param + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, *inputs): + return C.grad_by_list(self.net, self.weights)(*inputs) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + if_net = MyIfByIfNet() + net = GradNet(if_net) + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(2), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_with_param_if_by_if_grad_param_excute_null(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, a, b, x): + out = self.zero + if a < b: + out = out + x + self.param * 2 + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, *inputs): + return C.grad_by_list(self.net, self.weights)(*inputs) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + if_net = MyIfByIfNet() + net = GradNet(if_net) + idx = Tensor(np.array(4), dtype=ms.int32) + end = Tensor(np.array(0), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_return_inside_grad(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.max = P.ReduceMax() + self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight") + self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32) + + def construct(self, a, b, x): + out = self.zero + if a < b: + return out + x + self.param + if a == b: + return out + self.param * 2 + return out + self.param * 3 + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + + def construct(self, *inputs): + return C.grad_by_list(self.net, self.weights)(*inputs) + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + if_net = MyIfByIfNet() + net = GradNet(if_net) + idx = Tensor(np.array(1), dtype=ms.int32) + end = Tensor(np.array(0), dtype=ms.int32) + x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32) + net(idx, end, x) + + +def test_if_by_if_forward(): + class MyIfByIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.TensorAdd() + self.sub = P.Sub() + self.mul = P.Mul() + self.div = P.RealDiv() + + def construct(self, a, b, x): + if a < b: + a = self.add(a, b) + else: + a = self.sub(a, b) + if a == x: + a = self.mul(a, b) + else: + a = self.div(a, b) + if b == x: + b = self.add(a, b) + else: + b = self.add(a, x) + a = a * b + out = a + b + x + return out + + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + if_net = MyIfByIfNet() + net = if_net + idx = Tensor(np.array(2), dtype=ms.float32) + end = Tensor(np.array(3), dtype=ms.float32) + x = Tensor(np.array(4), dtype=ms.float32) + net(idx, end, x) diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index e43a8272c..71c0f39d3 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -58,6 +58,7 @@ add_subdirectory(serving) file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/core/base/*.cc" + "../../../mindspore/core/gvar/*.cc" "../../../mindspore/core/abstract/*.cc" "../../../mindspore/core/ir/*.cc" "../../../mindspore/core/utils/*.cc" diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 2eb3584c3..4f3c3302b 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -34,7 +34,6 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ from ....mindspore_test_framework.pipeline.gradient.compile_gradient \ import pipeline_for_compile_grad_ge_graph_for_case_by_case_config - class InputBackward(nn.Cell): def __init__(self, network): super(InputBackward, self).__init__() diff --git a/tests/ut/python/runtest.sh b/tests/ut/python/runtest.sh index 6108e0f47..1a687b9b3 100755 --- a/tests/ut/python/runtest.sh +++ b/tests/ut/python/runtest.sh @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - CURRPATH=$(cd $(dirname $0); pwd) IGNORE_EXEC="--ignore=$CURRPATH/exec" PROJECT_PATH=$(cd ${CURRPATH}/../../..; pwd) diff --git a/tests/vm_impl/array_ops_vm_impl.py b/tests/vm_impl/array_ops_vm_impl.py index 921d5c518..9f5453321 100644 --- a/tests/vm_impl/array_ops_vm_impl.py +++ b/tests/vm_impl/array_ops_vm_impl.py @@ -14,7 +14,6 @@ # ============================================================================ """Generate vm_impl function for array ops""" import numpy as np - import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor from mindspore.ops import operations as P @@ -22,7 +21,6 @@ from mindspore.ops.operations import _grad_ops as G from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters from .vm_interface import vm - # pylint: disable=unused-argument @@ -181,8 +179,7 @@ def vm_impl_tile(self): def vm_impl(x, multiples): x = x.asnumpy() - multiples = multiples.asnumpy() - out = vm.Tile(x, multiples) + out = np.tile(x, multiples) return Tensor(out) return vm_impl @@ -255,7 +252,10 @@ def vm_impl_sum(self): def vm_impl(x, axis): x = x.asnumpy() - out = vm.sum(x, axis) + if axis == (): + out = np.sum(x) + else: + out = np.sum(x, axis=axis) return Tensor(np.array(out)) return vm_impl @@ -291,12 +291,14 @@ def vm_impl_square(self): return vm_impl + @vm_impl_getters.register(P.ZerosLike) def vm_impl_zeros_like(self): """Generate vm_impl function for ZerosLike""" def vm_impl(x): return Tensor(np.zeros_like(x.asnumpy())) + @vm_impl_getters.register(P.Partial) def vm_impl_partial(self): """Generate vm_impl function for Partial""" @@ -307,6 +309,7 @@ def vm_impl_partial(self): return vm_impl + @vm_impl_getters.register(P.Depend) def vm_impl_depend(self): """Generate vm_impl function for Depend""" diff --git a/tests/vm_impl/math_ops_vm_impl.py b/tests/vm_impl/math_ops_vm_impl.py index d40961643..9a614c9c9 100644 --- a/tests/vm_impl/math_ops_vm_impl.py +++ b/tests/vm_impl/math_ops_vm_impl.py @@ -196,6 +196,18 @@ def vm_impl_reduce_mean(self): return vm_impl +@vm_impl_getters.register(P.ReduceMax) +def vm_impl_reduce_max(self): + """Generate vm_impl function for ReduceMean.""" + + def vm_impl(x, axis): + x = x.asnumpy() + if axis == (): + axis = None + out = np.amax(x, axis) + return Tensor(out) + + return vm_impl @vm_impl_getters.register(P.Equal) def vm_impl_equal(self): -- GitLab