diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 0d9839f27bead75cd286154c44e04d0dccd8aa1d..ab5e96be2f773ba953f82a5546b23dbf8073f4d5 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -23,10 +23,10 @@ #include #include +#include "base/core_ops.h" #include "ir/func_graph.h" #include "ir/primitive.h" #include "utils/context/ms_context.h" -#include "base/core_ops.h" namespace mindspore { // namespace to support intermediate representation definition @@ -191,6 +191,41 @@ std::string get_id(const AnfNodePtr &node) { void reset_id() { node_ids.clear(); } } // namespace id_generator +namespace { +std::string GetMaketupleNodeTarget(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto func_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto users = manager->node_users()[cnode]; + std::string first_user_target = GetCNodeTarget(users.back().first); + bool is_used_by_different_target = + std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair &u) -> bool { + return GetCNodeTarget(u.first) != first_user_target; + }); + if (!is_used_by_different_target) { + return first_user_target; + } + + auto inputs = cnode->inputs(); + std::vector real_inputs; + std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs)); + std::string first_input_target = GetCNodeTarget(real_inputs[0]); + bool is_from_different_target = + std::any_of(std::begin(real_inputs), std::end(real_inputs), + [&first_input_target](const AnfNodePtr &n) -> bool { return GetCNodeTarget(n) != first_input_target; }); + if (!is_from_different_target) { + return first_input_target; + } + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string default_target = context_ptr->device_target(); + return default_target; +} +} // namespace + std::string GetCNodeTarget(const AnfNodePtr &node) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); @@ -220,10 +255,9 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { if (att_target != nullptr) { if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || - IsPrimitive(attr_input, prim::kPrimMakeTuple) || IsPrimitive(attr_input, prim::kPrimStateSetItem) || - IsPrimitive(attr_input, prim::kPrimDepend) || IsPrimitive(attr_input, prim::kPrimTupleGetItem) || - IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || - IsPrimitive(attr_input, prim::kPrimPartial)) { + IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || + IsPrimitive(attr_input, prim::kPrimTupleGetItem) || IsPrimitive(attr_input, prim::kPrimControlDepend) || + IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { primitive->EraseAttr("primitive_target"); return default_target; } @@ -236,6 +270,9 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { } return target; } + if (IsPrimitive(node, prim::kPrimMakeTuple)) { + return GetMaketupleNodeTarget(cnode); + } return default_target; } } // namespace mindspore