From d4de0c5af1955417e42f8cf796aba256eea19d1d Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Mon, 22 Jun 2020 11:37:11 +0800 Subject: [PATCH] fix BackendCommonOptimization order Signed-off-by: zhoufeng --- mindspore/ccsrc/session/ascend_session.cc | 45 ++++++++++++++--------- mindspore/ccsrc/session/ascend_session.h | 1 + mindspore/ccsrc/session/session_basic.cc | 8 ++-- mindspore/ccsrc/session/session_basic.h | 3 +- 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 1cc10b243..bae10ed94 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -29,6 +29,7 @@ #include "device/ascend/ascend_kernel_runtime.h" #include "device/ascend/ascend_device_address.h" #include "pre_activate/ascend/ascend_backend_optimization.h" +#include "pre_activate/common/common_backend_optimization.h" #include "device/kernel_adjust.h" #include "device/ascend/ascend_stream_assign.h" #include "device/ascend/ascend_label_assign.h" @@ -283,36 +284,38 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL GraphId AscendSession::CompileGraph(NotNull func_graph) { MS_LOG(INFO) << "start"; - auto graph = ConstructKernelGraph(func_graph); + std::vector all_graphs; + auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); + BackendOptimization(all_graphs); // split switch - SplitGraphs(NOT_NULL(graph)); + SplitGraphs(NOT_NULL(root_graph)); // insert goto labels and label_sets - LinkChildGraphs(NOT_NULL(graph)); + LinkChildGraphs(NOT_NULL(root_graph)); // resource initialize InitRuntimeResource(); // assign label - AssignLabel(NOT_NULL(graph)); - // recurse compile child graph + AssignLabel(NOT_NULL(root_graph)); + // recurse compile child root_graph std::set memo; - RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo)); - // root graph valiate,include genearte execute order and so on - RootGraphExecutorValidate(NOT_NULL(graph)); + RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo)); + // root root_graph valiate,include genearte execute order and so on + RootGraphExecutorValidate(NOT_NULL(root_graph)); // adjust kernel - AdjustKernel(graph); + AdjustKernel(root_graph); // assign stream - AssignStream(graph); + AssignStream(root_graph); // insert profiling point - device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get())); + device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get())); // build kernel - BuildKernel(graph); + BuildKernel(root_graph); // alloc mem - MemoryAlloc(graph.get()); + MemoryAlloc(root_graph.get()); // task generate - GenerateTaskInfo(graph); + GenerateTaskInfo(root_graph); // load task into device - LoadTask(graph); - // return the graph id to backend - auto graph_id = graph->graph_id(); + LoadTask(root_graph); + // return the root_graph id to backend + auto graph_id = root_graph->graph_id(); return graph_id; } @@ -1569,6 +1572,14 @@ std::vector AscendSession::ConstructSplitedGraph(const KernelGraphPt return call_node_inputs; } +void AscendSession::BackendOptimization(const std::vector &all_graphs) { + MS_LOG(INFO) << "Start BackendCommonOptimization"; + for (auto &graph : all_graphs) { + opt::BackendCommonOptimization(graph); + } + MS_LOG(INFO) << "End."; +} + void AscendSession::SplitGraphs(NotNull root_graph) { std::set memo; // if root graph output is a call node ,the root graph is condition graph of 'if' sentence diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index e035f84c9..785733011 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -102,6 +102,7 @@ class AscendSession : public SessionBasic { void SplitGraph(NotNull graph, const std::set &cut_prims); // split graphs with recurse from root graph void SplitGraphs(NotNull root_graph); + void BackendOptimization(const std::vector &all_graphs); void LinkChildGraphs(NotNull graph); void RootGraphExecutorValidate(NotNull graph); std::vector ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index 46c395dd8..47d726451 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -579,8 +579,10 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con return graph; } -std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) { +std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph, + std::vector *all_out_graph) { MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(all_out_graph); auto node_list = TopoSort(func_graph->get_return()); auto graph = NewKernelGraph(); front_backend_graph_map_[func_graph] = graph; @@ -607,7 +609,7 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) { is_trace_back = true; } else { - (void)ConstructKernelGraph(child_graph); + (void)ConstructKernelGraph(child_graph, all_out_graph); } (void)CreateValueNodeKernelGraph(node, graph.get()); } @@ -634,7 +636,7 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP if (ExistSummaryNode(graph.get())) { graph->set_summary_node_exist(true); } - opt::BackendCommonOptimization(graph); + all_out_graph->push_back(graph); return graph; } diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index f4f391d0f..ea156d3c7 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -75,7 +75,8 @@ class SessionBasic { virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); std::shared_ptr ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); - std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph); + std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph, + std::vector *all_out_graph); CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, std::unordered_map *other_graph_cnode); -- GitLab