提交 1f9fc652 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1154 [control sink refactoring]new compile graph

Merge pull request !1154 from chenfei_mindspore/add-child-graph-to-kernel-graph
......@@ -138,6 +138,43 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
return graph_id;
}
GraphId AscendSession::CompileGraph(const FuncGraphPtr &func_graph) {
MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph);
// split switch
SplitSwitch(graph.get());
// insert goto labels and label_sets
LinkChildGraphs(graph.get());
// resource initialize
InitRuntimeResource();
// ir fusion
IRFusion(graph);
// kernel select
SelectKernelGraphKernel(*graph);
// convert model of predict module
ConvertPredictModel(graph);
// hardware optimize
HardwareOptimizeGraphs(graph);
// adjust kernel
AdjustKernel(graph);
// root graph valiate,include genearte execute order and so on
RootGraphExecutorValidate(graph.get());
// assign stream
AssignStream(graph);
// build kernel if node is cnode
BuildKernel(graph);
// alloc mem
MemoryAlloc(graph.get());
// task generate
GenerateTaskInfo(graph);
// load task into device
LoadTask(graph);
// return the graph id to backend
auto graph_id = graph->graph_id();
MS_LOG(INFO) << "Compile graph " << graph_id << " success";
return graph_id;
}
void AscendSession::BuildGraph(GraphId graph_id) {
MS_LOG(INFO) << "start";
auto graph = GetGraph(graph_id);
......
......@@ -42,6 +42,7 @@ class AscendSession : public SessionBasic {
context_ = std::make_shared<Context>(kAscendDevice, device_id);
}
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
GraphId CompileGraph(const FuncGraphPtr &func_graph) override;
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildGraph(GraphId) override;
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
......@@ -92,6 +93,14 @@ class AscendSession : public SessionBasic {
void SetFinalGraphOutput(const ValuePtr &value);
void SetFinalGraphOutput(const VectorRef &vec_output);
void SplitSwitch(KernelGraph *graph) {}
void LinkChildGraphs(KernelGraph *graph) {}
void IRFusion(const KernelGraphPtr &graph) {}
void SelectKernelGraphKernel(const KernelGraph &graph) {}
void ConvertPredictModel(const KernelGraphPtr graph) {}
void HardwareOptimizeGraphs(const KernelGraphPtr graph) {}
void RootGraphExecutorValidate(KernelGraph *graph) {}
// merge execution order list of child graphs
void MergeGraphExecOrder();
// insert assion op to sync data bettween different graphs
......
......@@ -580,5 +580,23 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
}
}
void KernelGraph::UpdateChildGraphOrder() {}
std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
if (IsLeafGraph()) {
leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
} else {
for (const auto &child_graph : child_graph_order_) {
MS_EXCEPTION_IF_NULL(child_graph);
auto child_leaf_graph_order = child_graph->GetLeafGraphOrder();
std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
}
}
return leaf_graph_order;
}
bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
} // namespace session
} // namespace mindspore
......@@ -99,6 +99,14 @@ class KernelGraph : public FuncGraph {
uint32_t stream_distinction_label() { return stream_distinction_label_; }
// refresh execute kernel stream label
void UpdateExecuteKernelStreamLabel();
// calculate the leaf graph order of root graph
std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder();
// update the child graph order of graph
void UpdateChildGraphOrder();
// get the child graph of current graph
std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; }
// checkout whether current graph is leaf graph
bool IsLeafGraph() const;
private:
// remove value node form graph
......@@ -136,6 +144,12 @@ class KernelGraph : public FuncGraph {
bool executable_;
// valid inputs
std::vector<bool> valid_inputs_;
// new members for control sink process
// all child grahs refers to partial node
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
// child graph execute order in root graph
std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
};
} // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
......
......@@ -494,6 +494,8 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
return graph;
}
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &) { return nullptr; }
// run graph steps
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const {
......
......@@ -57,6 +57,7 @@ class SessionBasic {
virtual ~SessionBasic() { summary_callback_ = nullptr; }
virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
virtual GraphId CompileGraph(const FuncGraphPtr &) { return kInvalidGraphId; }
// build graph, used to handle multiple child graphs
virtual void BuildGraph(GraphId) {}
......@@ -72,6 +73,7 @@ class SessionBasic {
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册