提交 d4de0c5a 编写于 作者: Z zhoufeng

fix BackendCommonOptimization order

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
上级 beb436f4
......@@ -29,6 +29,7 @@
#include "device/ascend/ascend_kernel_runtime.h"
#include "device/ascend/ascend_device_address.h"
#include "pre_activate/ascend/ascend_backend_optimization.h"
#include "pre_activate/common/common_backend_optimization.h"
#include "device/kernel_adjust.h"
#include "device/ascend/ascend_stream_assign.h"
#include "device/ascend/ascend_label_assign.h"
......@@ -283,36 +284,38 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
// Compiles the kernel graph built from `func_graph` and returns the graph id of
// the resulting root kernel graph.
//
// The visible sequence is: construct the kernel graph(s), run
// BackendOptimization over the collected graphs, split and link child graphs,
// initialise runtime resources, assign labels, recursively compile child
// graphs, validate the root graph, then adjust kernels, assign streams, insert
// profiling points, build kernels, allocate memory, generate task info and
// load the tasks.
//
// NOTE(review): every step below appears twice, once on `graph` and once on
// `root_graph`. This looks like the removed and added lines of a diff shown
// together rather than real duplicated work -- confirm against the actual
// file before relying on this listing.
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph);
// all_graphs is filled by ConstructKernelGraph through the out-pointer and is
// then handed to BackendOptimization as a whole.
std::vector<KernelGraphPtr> all_graphs;
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
BackendOptimization(all_graphs);
// split switch
SplitGraphs(NOT_NULL(graph));
SplitGraphs(NOT_NULL(root_graph));
// insert goto labels and label_sets
LinkChildGraphs(NOT_NULL(graph));
LinkChildGraphs(NOT_NULL(root_graph));
// initialize runtime resources
InitRuntimeResource();
// assign label
AssignLabel(NOT_NULL(graph));
// recursively compile child graphs
AssignLabel(NOT_NULL(root_graph));
// recursively compile child graphs; memo is passed along to the recursion
std::set<KernelGraphPtr> memo;
RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo));
// validate the root graph, including generating the execution order and so on
RootGraphExecutorValidate(NOT_NULL(graph));
RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
// validate the root graph, including generating the execution order and so on
RootGraphExecutorValidate(NOT_NULL(root_graph));
// adjust kernel
AdjustKernel(graph);
AdjustKernel(root_graph);
// assign stream
AssignStream(graph);
AssignStream(root_graph);
// insert profiling point
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
// build kernel
BuildKernel(graph);
BuildKernel(root_graph);
// alloc mem
MemoryAlloc(graph.get());
MemoryAlloc(root_graph.get());
// task generate
GenerateTaskInfo(graph);
GenerateTaskInfo(root_graph);
// load task into device
LoadTask(graph);
// return the graph id to backend
auto graph_id = graph->graph_id();
LoadTask(root_graph);
// return the root graph id to backend
auto graph_id = root_graph->graph_id();
return graph_id;
}
......@@ -1569,6 +1572,14 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
return call_node_inputs;
}
void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) {
MS_LOG(INFO) << "Start BackendCommonOptimization";
for (auto &graph : all_graphs) {
opt::BackendCommonOptimization(graph);
}
MS_LOG(INFO) << "End.";
}
void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
std::set<KernelGraphPtr> memo;
// if root graph output is a call node ,the root graph is condition graph of 'if' sentence
......
......@@ -102,6 +102,7 @@ class AscendSession : public SessionBasic {
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
// split graphs with recurse from root graph
void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
......
......@@ -579,8 +579,10 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
return graph;
}
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
std::vector<KernelGraphPtr> *all_out_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(all_out_graph);
auto node_list = TopoSort(func_graph->get_return());
auto graph = NewKernelGraph();
front_backend_graph_map_[func_graph] = graph;
......@@ -607,7 +609,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
is_trace_back = true;
} else {
(void)ConstructKernelGraph(child_graph);
(void)ConstructKernelGraph(child_graph, all_out_graph);
}
(void)CreateValueNodeKernelGraph(node, graph.get());
}
......@@ -634,7 +636,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
if (ExistSummaryNode(graph.get())) {
graph->set_summary_node_exist(true);
}
opt::BackendCommonOptimization(graph);
all_out_graph->push_back(graph);
return graph;
}
......
......@@ -75,7 +75,8 @@ class SessionBasic {
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
std::vector<KernelGraphPtr> *all_out_graph);
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册