From 285f225ecaa673f183aa8663808644ffa51fc66d Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Mon, 1 Jun 2020 19:08:31 +0800 Subject: [PATCH] Improve performance of finding summary nodes --- mindspore/ccsrc/session/kernel_graph.h | 7 +++++++ mindspore/ccsrc/session/session_basic.cc | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index 8c8ba5f8b..497bc8df9 100644 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -40,6 +40,7 @@ class KernelGraph : public FuncGraph { inputs_ = std::make_shared>(); execution_order_ = {}; executable_ = true; + summary_node_exist_ = false; stream_distinction_label_ = kInvalidDistincLabel; } ~KernelGraph() override; @@ -90,6 +91,10 @@ class KernelGraph : public FuncGraph { bool executable() const { return executable_; } // set executable of graph void set_executable(bool executable) { executable_ = executable; } + // set summary_node of graph + void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; } + // check whether exist summary node in graph + bool summary_node_exist() const { return summary_node_exist_; } // set invalid inputs for control sink std::vector *MutableValidInputs() { return &valid_inputs_; } std::vector valid_inputs() const { return valid_inputs_; } @@ -172,6 +177,8 @@ class KernelGraph : public FuncGraph { std::unordered_map>> node_output_edges_; // graph needn't execute bool executable_; + // exist summary node in graph + bool summary_node_exist_; // valid inputs std::vector valid_inputs_; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index b1bfefcac..fc40fd3d5 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -291,6 +291,19 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { (void)tab_str.append(any.ToString()); MS_LOG(INFO) << tab_str; } + +bool ExistSummaryNode(const KernelGraph *graph) { + auto ret = graph->get_return(); + MS_EXCEPTION_IF_NULL(ret); + auto all_nodes = DeepLinkedGraphSearch(ret); + for (auto &n : all_nodes) { + if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || + IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { + return true; + } + } + return false; +} } // namespace GraphId SessionBasic::graph_sum_ = 0; @@ -537,6 +550,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con graph->set_manager(manager); } graph->SetExecOrderByDefault(); + if (ExistSummaryNode(graph.get())) { + graph->set_summary_node_exist(true); + } opt::BackendCommonOptimization(graph); return graph; } @@ -594,6 +610,9 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP graph->set_manager(manager); } graph->SetExecOrderByDefault(); + if (ExistSummaryNode(graph.get())) { + graph->set_summary_node_exist(true); + } return graph; } @@ -716,6 +735,9 @@ void SessionBasic::GetSummaryNodes(const KernelGraph *graph, MS_LOG(DEBUG) << "Update summary Start"; MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(summary); + if (!graph->summary_node_exist()) { + return; + } auto apply_list = TopoSort(graph->get_return()); for (auto &n : apply_list) { MS_EXCEPTION_IF_NULL(n); -- GitLab