diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index 326590584809f18c982232e31d29f129d78eaf8f..7d56551ff05872f0593463412dfb1e9c99f9680f 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -282,7 +282,7 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } -static bool IsCtrlSink(const FuncGraphPtr &graph) { +static bool IsCtrlSink() { auto ms_ctx = MsContext::GetInstance(); if (ms_ctx->execution_mode() != kGraphMode) { return false; @@ -297,10 +297,9 @@ static bool IsCtrlSink(const FuncGraphPtr &graph) { return false; } - if (graph != nullptr && CompileGraphs::ContainMixedTarget(graph)) { + if (!ms_ctx->is_multi_graph_sink()) { return false; } - return true; } @@ -310,27 +309,29 @@ bool TaskEmitAction(const ResourcePtr &res) { } FuncGraphPtr func_graph = res->func_graph(); auto bc_ptr = res->results()[kBackend].cast(); - if (IsCtrlSink(func_graph)) { - res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); - return true; - } - std::vector cut_list = compile::nonlinear_ops; - if (bc_ptr->name() == kMsConvert) { - cut_list = compile::GetMsNonlinearOps(); - } - - std::shared_ptr compile = std::make_shared(bc_ptr, cut_list); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); if (CompileGraphs::ContainMixedTarget(func_graph)) { bc_ptr->set_is_multi_graph_sink(false); + context_ptr->set_is_multi_graph_sink(false); context_ptr->set_loop_sink_flag(false); } else if (context_ptr->execution_mode() != kPynativeMode) { std::string device_target = context_ptr->device_target(); if (device_target == kAscendDevice) { bc_ptr->set_is_multi_graph_sink(true); + context_ptr->set_is_multi_graph_sink(true); } } + + if (IsCtrlSink()) { + res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); + return true; + } + std::vector cut_list = compile::nonlinear_ops; + if (bc_ptr->name() == kMsConvert) { + cut_list = compile::GetMsNonlinearOps(); + } + std::shared_ptr compile = std::make_shared(bc_ptr, cut_list); res->results()[kOutput] = compile->CompileAndLink(func_graph); return true; } @@ -340,11 +341,10 @@ bool ExecuteAction(const ResourcePtr &res) { MS_LOG(EXCEPTION) << "Execute args error"; } - if (IsCtrlSink(nullptr)) { + if (IsCtrlSink()) { if (!res->results()[kOutput].is()) { MS_LOG(EXCEPTION) << "Execute args error"; } - auto graph_id = res->results()[kOutput].cast(); std::shared_ptr bc_ptr = res->results()[kBackend].cast>(); std::shared_ptr msbc_ptr = std::dynamic_pointer_cast(bc_ptr);