Bind summary nodes to KernelGraph to enable memory reuse

无相关合并请求
......@@ -302,6 +302,18 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
return graph_id;
}
void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto graph_order = GetGraphOrder(kernel_graph->graph_id());
for (auto graph_id : graph_order) {
auto child_graph = GetGraph(graph_id);
if (child_graph->summary_node_exist()) {
kernel_graph->set_summary_node_exist(true);
return;
}
}
kernel_graph->set_summary_node_exist(false);
}
void AscendSession::BuildGraph(GraphId graph_id) {
MS_LOG(INFO) << "start";
auto graph = GetGraph(graph_id);
......@@ -317,6 +329,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
InsertAllAssigns();
// insert switch and active to child graph
MergeSwitchCompile();
SetFinalGraphSummaryFlag(graph);
// OptChildGraphs
auto graph_order = GetGraphOrder(final_graph_id_);
auto &graph_type = GetGraphOrderType(final_graph_id_);
......@@ -328,6 +341,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
auto child_graph = GetGraph(graph_order[i]);
CompileChildGraph(child_graph);
}
GetSummaryNodes(graph.get());
// merge child graph
MergeGraphExecOrder();
} else {
......@@ -725,6 +739,28 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
return final_graph_id_;
}
// Collect the summary nodes of graph and bind them to it.
// When graph has no entry in graph_execute_orders_ (it has no child graphs), the work is delegated
// to SessionBasic::GetSummaryNodes. Otherwise the summary nodes of every child graph in its graph
// order are collected and merged into the summary nodes of graph.
void AscendSession::GetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  // the final graph has no child graph
  auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
  if (graph_order_iter == graph_execute_orders_.end()) {
    SessionBasic::GetSummaryNodes(graph);
    return;
  }
  // start from the nodes already bound to graph, then merge in the nodes of every child graph
  auto summary = graph->summary_nodes();
  const auto &graph_order = GetGraphOrder(graph->graph_id());
  for (auto child_graph_id : graph_order) {
    auto child_graph = GetGraph(child_graph_id);
    // the child graph is dereferenced right below, so reject a missing graph explicitly
    MS_EXCEPTION_IF_NULL(child_graph);
    SessionBasic::GetSummaryNodes(child_graph.get());
    // summary_nodes() returns a const reference; bind to it instead of copying the whole map
    const auto &child_graph_summary = child_graph->summary_nodes();
    // range insert keeps the existing entry when a key is already present
    summary.insert(child_graph_summary.begin(), child_graph_summary.end());
  }
  graph->set_summary_nodes(summary);
  MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
}
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
auto fake_graph = GetGraph(fake_graph_id);
auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
......
......@@ -67,6 +67,7 @@ class AscendSession : public SessionBasic {
void SetActive(GraphId, GraphId) override;
// compile child graph when session have multiple child graphs
void CompileChildGraph(const KernelGraphPtr &child_graph);
// collect summary nodes of the graph, including those of its child graphs when it has any;
// overrides the virtual SessionBasic::GetSummaryNodes
void GetSummaryNodes(KernelGraph *graph) override;
private:
void InitRuntimeResource();
......@@ -149,6 +150,7 @@ class AscendSession : public SessionBasic {
AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
// sync initial tensors' data to device
void SyncInitialTenosrToDevice();
// set the summary_node_exist flag of kernel_graph to true if any child graph in its graph order
// has a summary node, otherwise to false
void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
// member variables
// key is final_graph_id,value is child graph execute order of final graph
......
......@@ -40,6 +40,7 @@ class KernelGraph : public FuncGraph {
inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
execution_order_ = {};
executable_ = true;
summary_node_exist_ = false;
stream_distinction_label_ = kInvalidDistincLabel;
}
~KernelGraph() override;
......@@ -90,6 +91,10 @@ class KernelGraph : public FuncGraph {
bool executable() const { return executable_; }
// set executable of graph
void set_executable(bool executable) { executable_ = executable; }
// record whether the graph contains a summary node
void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; }
// return the recorded flag telling whether a summary node exists in the graph
bool summary_node_exist() const { return summary_node_exist_; }
// set invalid inputs for control sink
std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
std::vector<bool> valid_inputs() const { return valid_inputs_; }
......@@ -132,6 +137,8 @@ class KernelGraph : public FuncGraph {
void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
CNodePtr get_start_label() { return start_label_; }
// get the summary nodes bound to the graph (returned by const reference, no copy is made)
const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
// replace the summary nodes bound to the graph with a copy of nodes
void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
private:
// remove value node from graph
......@@ -165,6 +172,9 @@ class KernelGraph : public FuncGraph {
// record map between ref final output anf with index and ref origin input with index
std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
// summary nodes bound to the graph; as filled by SessionBasic::GetSummaryNodes the key is the summary
// node's fullname_with_scope and the value is the (node, index) pair returned by
// AnfAlgo::VisitKernelWithReturnType for the summary's input
std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
// exist summary node in graph
bool summary_node_exist_;
// graph needn't execute
bool executable_;
// valid inputs
......
......@@ -54,46 +54,6 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
return py_param.ptr();
}
void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(summary);
summary->clear();
auto apply_list = TopoSort(graph->get_return());
for (auto &n : apply_list) {
MS_EXCEPTION_IF_NULL(n);
if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
auto cnode = n->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() <= kSummaryGetItem) {
MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
}
auto node = cnode->input(kSummaryGetItem);
MS_EXCEPTION_IF_NULL(node);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
}
(*summary)[n->fullname_with_scope()] = item_with_index;
}
}
MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
}
// Return true when any node found by DeepLinkedGraphSearch from the graph's return node is a
// ScalarSummary, TensorSummary, ImageSummary or HistogramSummary primitive CNode, false otherwise.
bool ExistSummaryNode(const KernelGraph *graph) {
  // graph is dereferenced right below, so reject a null pointer up front
  MS_EXCEPTION_IF_NULL(graph);
  auto ret = graph->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  auto all_nodes = DeepLinkedGraphSearch(ret);
  for (auto &n : all_nodes) {
    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
      return true;
    }
  }
  return false;
}
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
const std::vector<tensor::TensorPtr> &input_tensors) {
MS_EXCEPTION_IF_NULL(node);
......@@ -330,6 +290,19 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
(void)tab_str.append(any.ToString());
MS_LOG(INFO) << tab_str;
}
// Return true when any node found by DeepLinkedGraphSearch from the graph's return node is a
// ScalarSummary, TensorSummary, ImageSummary or HistogramSummary primitive CNode, false otherwise.
bool ExistSummaryNode(const KernelGraph *graph) {
  // graph is dereferenced right below, so reject a null pointer up front
  MS_EXCEPTION_IF_NULL(graph);
  auto ret = graph->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  auto all_nodes = DeepLinkedGraphSearch(ret);
  for (auto &n : all_nodes) {
    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
      return true;
    }
  }
  return false;
}
} // namespace
GraphId SessionBasic::graph_sum_ = 0;
......@@ -595,6 +568,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
graph->set_manager(manager);
}
graph->SetExecOrderByDefault();
if (ExistSummaryNode(graph.get())) {
graph->set_summary_node_exist(true);
}
opt::BackendCommonOptimization(graph);
return graph;
}
......@@ -658,6 +634,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
graph->set_manager(manager);
}
graph->SetExecOrderByDefault();
if (ExistSummaryNode(graph.get())) {
graph->set_summary_node_exist(true);
}
return graph;
}
......@@ -751,6 +730,36 @@ void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
}
// Collect the summary nodes of graph and store them on the graph.
// Nothing is done when the graph's summary_node_exist flag is false. Otherwise, starting from the
// summary nodes already bound to the graph, every Scalar/Tensor/Image/Histogram summary CNode found
// by TopoSort from the return node is mapped, by its fullname_with_scope, to the (node, index) pair
// that AnfAlgo::VisitKernelWithReturnType yields for the summary's kSummaryGetItem input.
// An exception is logged when a summary has too few inputs or that pair's node is not a real kernel.
void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {
    return;
  }
  // true for the four kinds of summary primitives
  const auto is_summary_cnode = [](const AnfNodePtr &candidate) {
    return IsPrimitiveCNode(candidate, prim::kPrimScalarSummary) ||
           IsPrimitiveCNode(candidate, prim::kPrimTensorSummary) ||
           IsPrimitiveCNode(candidate, prim::kPrimImageSummary) ||
           IsPrimitiveCNode(candidate, prim::kPrimHistogramSummary);
  };
  auto collected = graph->summary_nodes();
  auto topo_nodes = TopoSort(graph->get_return());
  for (auto &n : topo_nodes) {
    MS_EXCEPTION_IF_NULL(n);
    if (!is_summary_cnode(n)) {
      continue;
    }
    auto cnode = n->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->inputs().size() <= kSummaryGetItem) {
      MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
    }
    auto node = cnode->input(kSummaryGetItem);
    MS_EXCEPTION_IF_NULL(node);
    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
    if (!AnfAlgo::IsRealKernel(kernel_with_index.first)) {
      MS_LOG(EXCEPTION) << "Unexpected node:" << kernel_with_index.first->DebugString();
    }
    collected[n->fullname_with_scope()] = kernel_with_index;
  }
  graph->set_summary_nodes(collected);
  MS_LOG(DEBUG) << "Update summary end size: " << collected.size();
}
void SessionBasic::Summary(KernelGraph *graph) {
if (summary_callback_ == nullptr) {
return;
......@@ -760,8 +769,12 @@ void SessionBasic::Summary(KernelGraph *graph) {
if (!exist_summary) {
return;
}
std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs;
GetSummaryNodes(graph, &summary_outputs);
GetSummaryNodes(graph);
auto summary_outputs = graph->summary_nodes();
// do not exist summary node
if (summary_outputs.empty()) {
return;
}
std::map<std::string, tensor::TensorPtr> params_list;
// fetch outputs apply kernel in session & run callback functions
for (auto &output_item : summary_outputs) {
......
......@@ -92,6 +92,7 @@ class SessionBasic {
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
virtual void SetActive(GraphId, GraphId) {}
// collect the summary nodes of graph and store them on the graph; virtual so that a derived
// session can provide its own collection logic
virtual void GetSummaryNodes(KernelGraph *graph);
protected:
virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册