提交 79e82370 编写于 作者: K kswang

fix mix target entry

上级 beb436f4
...@@ -282,7 +282,7 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa ...@@ -282,7 +282,7 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa
bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); }
static bool IsCtrlSink(const FuncGraphPtr &graph) { static bool IsCtrlSink() {
auto ms_ctx = MsContext::GetInstance(); auto ms_ctx = MsContext::GetInstance();
if (ms_ctx->execution_mode() != kGraphMode) { if (ms_ctx->execution_mode() != kGraphMode) {
return false; return false;
...@@ -297,10 +297,9 @@ static bool IsCtrlSink(const FuncGraphPtr &graph) { ...@@ -297,10 +297,9 @@ static bool IsCtrlSink(const FuncGraphPtr &graph) {
return false; return false;
} }
if (graph != nullptr && CompileGraphs::ContainMixedTarget(graph)) { if (!ms_ctx->is_multi_graph_sink()) {
return false; return false;
} }
return true; return true;
} }
...@@ -310,27 +309,29 @@ bool TaskEmitAction(const ResourcePtr &res) { ...@@ -310,27 +309,29 @@ bool TaskEmitAction(const ResourcePtr &res) {
} }
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
if (IsCtrlSink(func_graph)) {
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
return true;
}
std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
if (bc_ptr->name() == kMsConvert) {
cut_list = compile::GetMsNonlinearOps();
}
std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (CompileGraphs::ContainMixedTarget(func_graph)) { if (CompileGraphs::ContainMixedTarget(func_graph)) {
bc_ptr->set_is_multi_graph_sink(false); bc_ptr->set_is_multi_graph_sink(false);
context_ptr->set_is_multi_graph_sink(false);
context_ptr->set_loop_sink_flag(false); context_ptr->set_loop_sink_flag(false);
} else if (context_ptr->execution_mode() != kPynativeMode) { } else if (context_ptr->execution_mode() != kPynativeMode) {
std::string device_target = context_ptr->device_target(); std::string device_target = context_ptr->device_target();
if (device_target == kAscendDevice) { if (device_target == kAscendDevice) {
bc_ptr->set_is_multi_graph_sink(true); bc_ptr->set_is_multi_graph_sink(true);
context_ptr->set_is_multi_graph_sink(true);
} }
} }
if (IsCtrlSink()) {
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
return true;
}
std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
if (bc_ptr->name() == kMsConvert) {
cut_list = compile::GetMsNonlinearOps();
}
std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
res->results()[kOutput] = compile->CompileAndLink(func_graph); res->results()[kOutput] = compile->CompileAndLink(func_graph);
return true; return true;
} }
...@@ -340,11 +341,10 @@ bool ExecuteAction(const ResourcePtr &res) { ...@@ -340,11 +341,10 @@ bool ExecuteAction(const ResourcePtr &res) {
MS_LOG(EXCEPTION) << "Execute args error"; MS_LOG(EXCEPTION) << "Execute args error";
} }
if (IsCtrlSink(nullptr)) { if (IsCtrlSink()) {
if (!res->results()[kOutput].is<GraphId>()) { if (!res->results()[kOutput].is<GraphId>()) {
MS_LOG(EXCEPTION) << "Execute args error"; MS_LOG(EXCEPTION) << "Execute args error";
} }
auto graph_id = res->results()[kOutput].cast<GraphId>(); auto graph_id = res->results()[kOutput].cast<GraphId>();
std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>(); std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr); std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册