提交 5aae0d91 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1459 Insert assign nodes for linking sub graph

Merge pull request !1459 from zhoufeng/link-assign
...@@ -28,6 +28,9 @@ namespace device { ...@@ -28,6 +28,9 @@ namespace device {
namespace ascend { namespace ascend {
static void UpdateLabelGoto(NotNull<CNodePtr> node) { static void UpdateLabelGoto(NotNull<CNodePtr> node) {
if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) {
return;
}
if (node->size() <= kLabelGotoLabelId) { if (node->size() <= kLabelGotoLabelId) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
} }
...@@ -42,6 +45,9 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) { ...@@ -42,6 +45,9 @@ static void UpdateLabelGoto(NotNull<CNodePtr> node) {
} }
static void UpdateLabelSwitch(NotNull<CNodePtr> node) { static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) {
return;
}
if (node->size() <= kLabelGotoLabelId) { if (node->size() <= kLabelGotoLabelId) {
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
} }
...@@ -69,9 +75,12 @@ static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph> ...@@ -69,9 +75,12 @@ static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>
if (memo->find(graph.get()) != memo->end()) { if (memo->find(graph.get()) != memo->end()) {
return; return;
} }
memo->insert(graph.get());
MS_LOG(INFO) << "Assign label for " << graph->ToString(); MS_LOG(INFO) << "Assign label for " << graph->ToString();
auto nodes = TopoSort(graph->get_return()); graph->SetExecOrderByDefault();
auto nodes = graph->execution_order();
for (auto &node : nodes) { for (auto &node : nodes) {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
...@@ -97,9 +106,15 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap ...@@ -97,9 +106,15 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap
if (memo->find(graph.get()) != memo->end()) { if (memo->find(graph.get()) != memo->end()) {
return; return;
} }
memo->insert(graph.get());
MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString();
auto nodes = TopoSort(graph->get_return()); graph->SetExecOrderByDefault();
auto nodes = graph->execution_order();
auto end_goto = graph->get_end_goto();
if (end_goto != nullptr) {
nodes.push_back(end_goto);
}
for (auto &node : nodes) { for (auto &node : nodes) {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
......
...@@ -53,6 +53,7 @@ class KernelRuntime { ...@@ -53,6 +53,7 @@ class KernelRuntime {
virtual bool GenTask(const session::KernelGraph *graph); virtual bool GenTask(const session::KernelGraph *graph);
bool LaunchKernel(const session::KernelGraph *graph); bool LaunchKernel(const session::KernelGraph *graph);
virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); virtual void AssignStaticMemoryInput(const session::KernelGraph *graph);
virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph);
#ifdef ENABLE_DUMP_E2E #ifdef ENABLE_DUMP_E2E
DumpConfPtr GetDumpConf(); DumpConfPtr GetDumpConf();
...@@ -67,7 +68,6 @@ class KernelRuntime { ...@@ -67,7 +68,6 @@ class KernelRuntime {
TypeId type_id) = 0; TypeId type_id) = 0;
virtual bool SyncStream() = 0; virtual bool SyncStream() = 0;
void AssignStaticMemory(session::KernelGraph *graph); void AssignStaticMemory(session::KernelGraph *graph);
void AssignStaticMemoryValueNode(session::KernelGraph *graph);
void AssignDynamicMemory(session::KernelGraph *graph); void AssignDynamicMemory(session::KernelGraph *graph);
void ReuseAssignDynamicMemory(session::KernelGraph *graph); void ReuseAssignDynamicMemory(session::KernelGraph *graph);
void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H #define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H
#include <set> #include <set>
#include <map>
#include <vector> #include <vector>
#include <tuple> #include <tuple>
#include "session/kernel_graph.h" #include "session/kernel_graph.h"
...@@ -28,31 +29,44 @@ namespace session { ...@@ -28,31 +29,44 @@ namespace session {
class AscendControlParser { class AscendControlParser {
public: public:
static void ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map);
static void LinkGraph(NotNull<KernelGraphPtr> kg); static void LinkGraph(NotNull<KernelGraphPtr> kg);
static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node); static void InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node);
static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, static void InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
NotNull<AnfNodePtr> second_node); NotNull<AnfNodePtr> second_node);
static void ExecutorValidate(NotNull<KernelGraphPtr> root_graph);
static void UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg);
private: private:
static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node, static NotNull<CNodePtr> ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
const CNodePtr &last_label, const VectorRef &args, const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo);
NotNull<std::set<KernelGraphPtr> *> memo);
static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, static void RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
NotNull<std::set<KernelGraphPtr> *> memo); NotNull<std::set<KernelGraphPtr> *> memo);
static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, static void RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
NotNull<std::set<KernelGraphPtr> *> memo); NotNull<std::set<KernelGraphPtr> *> memo);
static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, static void RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
NotNull<std::set<KernelGraphPtr> *> memo); NotNull<std::set<KernelGraphPtr> *> memo);
static std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &in);
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label, const VectorRef &args); const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo);
static void SetSubGraphInput(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> from_graph_call_node, static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
const VectorRef &args);
static std::tuple<CNodePtr, KernelGraphPtr, VectorRef> ParsePartial(NotNull<AnfNodePtr> node); static void LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph,
NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param);
static NotNull<AnfNodePtr> GetRealInput(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
NotNull<AnfNodePtr> param);
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static size_t SetChildGraphInput(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> node, size_t input_index);
static CNodePtr GetNextRealKernel(std::vector<CNodePtr> list, size_t start);
// root graph order
static std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> GetLabelNode(
const std::vector<CNodePtr> &nodes);
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
NotNull<KernelGraphPtr> graph);
static std::vector<CNodePtr> RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto,
NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
static constexpr size_t kCNodePrim = 0; static constexpr size_t kCNodePrim = 0;
static constexpr size_t kCNodeCallArg = 1; static constexpr size_t kCNodeCallArg = 1;
......
...@@ -177,10 +177,6 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co ...@@ -177,10 +177,6 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
for (size_t i = 0; i < cnodes.size(); i++) { for (size_t i = 0; i < cnodes.size(); i++) {
if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) { if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) {
auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]); auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]);
// if graph is the true branch of while,no need split graph
if (call_kernel_graph.size() == 1 && call_kernel_graph[0] == cur_graph.parent_graph()) {
continue;
}
auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i); auto prev_call_list = std::vector<CNodePtr>(cnodes.begin() + after_call_index, cnodes.begin() + i);
auto call_list = std::vector<CNodePtr>(1, cnodes[i]); auto call_list = std::vector<CNodePtr>(1, cnodes[i]);
after_call_index = i + 1; after_call_index = i + 1;
...@@ -195,9 +191,9 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co ...@@ -195,9 +191,9 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
void UpdateRealInput(KernelGraph *graph) { static void UpdateRealInput(KernelGraph *graph) {
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> &parameters, auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> &parameters,
const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void { const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void {
MS_EXCEPTION_IF_NULL(child_graph); MS_EXCEPTION_IF_NULL(child_graph);
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id(); MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
...@@ -208,8 +204,21 @@ void UpdateRealInput(KernelGraph *graph) { ...@@ -208,8 +204,21 @@ void UpdateRealInput(KernelGraph *graph) {
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
<< " and args size:" << args.size() << " not equal!"; << " and args size:" << args.size() << " not equal!";
} }
child_graph->SetExecOrderByDefault();
for (size_t i = 0; i < parameters.size(); i++) { for (size_t i = 0; i < parameters.size(); i++) {
MS_LOG(INFO) << "bind paramreter:" << parameters[i]->DebugString() << " ,arg:" << args[i]->DebugString(); if (args[i] == parameters[i]) {
child_graph->SetRealInput(parameters[i], args[i]);
MS_LOG(INFO) << "Parameter and arg are same";
continue;
}
// if arg is a parameter ,then reuse this parameter
if (args[i]->isa<Parameter>()) {
MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id()
<< " reuse parameter:" << args[i]->DebugString()
<< " of graph:" << AnfAlgo::GetGraphId(args[i].get());
child_graph->ReplaceNode(parameters[i], args[i]);
continue;
}
child_graph->SetRealInput(parameters[i], args[i]); child_graph->SetRealInput(parameters[i], args[i]);
} }
}; };
...@@ -218,9 +227,10 @@ void UpdateRealInput(KernelGraph *graph) { ...@@ -218,9 +227,10 @@ void UpdateRealInput(KernelGraph *graph) {
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node); auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node);
if (child_graphs.size() == 1) { if (child_graphs.size() == 1) {
MS_EXCEPTION_IF_NULL(child_graphs[0]); MS_EXCEPTION_IF_NULL(child_graphs[0]);
bind_call_partial_with_parameter( std::vector<AnfNodePtr> real_args =
child_graphs[0]->inputs(), std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()), std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end());
child_graphs[0].get()); std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs();
bind_call_arg_with_parameter(child_inputs, real_args, child_graphs[0].get());
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2)); call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
} else if (child_graphs.size() == 2) { } else if (child_graphs.size() == 2) {
auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> { auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
...@@ -237,8 +247,8 @@ void UpdateRealInput(KernelGraph *graph) { ...@@ -237,8 +247,8 @@ void UpdateRealInput(KernelGraph *graph) {
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
return ret; return ret;
}; };
bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); bind_call_arg_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); bind_call_arg_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
} }
} }
} }
...@@ -248,6 +258,11 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) { ...@@ -248,6 +258,11 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) {
MS_LOG(INFO) << "start graph id:" << graph->graph_id(); MS_LOG(INFO) << "start graph id:" << graph->graph_id();
graph->UpdateCallRealInput(); graph->UpdateCallRealInput();
for (auto &child_graph : graph->child_graph_order()) { for (auto &child_graph : graph->child_graph_order()) {
if (child_graph == graph->parent_graph()) {
MS_LOG(INFO) << "Child graph:" << child_graph->graph_id()
<< ",parent graph:" << graph->parent_graph()->graph_id();
continue;
}
RecurseToUpdateCallRealInput(child_graph.get()); RecurseToUpdateCallRealInput(child_graph.get());
} }
} }
...@@ -265,31 +280,31 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL ...@@ -265,31 +280,31 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph); auto graph = ConstructKernelGraph(func_graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// split switch // split switch
SplitGraphs(graph); SplitGraphs(graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// insert goto labels and label_sets // insert goto labels and label_sets
LinkChildGraphs(NOT_NULL(graph)); LinkChildGraphs(NOT_NULL(graph));
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// resource initialize // resource initialize
InitRuntimeResource(); InitRuntimeResource();
// assign label // assign label
AssignLabel(NOT_NULL(graph)); AssignLabel(NOT_NULL(graph));
if (!graph->executable()) { MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
return graph->graph_id(); // recurse compile child graph
} RecurseCompileGraph(graph);
for (auto iter : graphs_) { MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
if (iter.second == graph) { // root graph valiate,include genearte execute order and so on
MS_LOG(INFO) << "Entry graph " << graph->ToString() << " graph id " << graph->graph_id(); RootGraphExecutorValidate(NOT_NULL(graph));
final_graph_id_ = graph->graph_id(); MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
}
MS_LOG(INFO) << "CompileChildGraph " << iter.second->ToString();
CompileChildGraph(iter.second);
}
// adjust kernel // adjust kernel
AdjustKernel(graph); AdjustKernel(graph);
// root graph valiate,include genearte execute order and so on MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
RootGraphExecutorValidate(graph.get());
// assign stream // assign stream
AssignStream(graph); AssignStream(graph);
// build kernel
BuildKernel(graph);
// alloc mem // alloc mem
MemoryAlloc(graph.get()); MemoryAlloc(graph.get());
// task generate // task generate
...@@ -365,6 +380,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { ...@@ -365,6 +380,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
MS_EXCEPTION_IF_NULL(child_graph); MS_EXCEPTION_IF_NULL(child_graph);
MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString();
opt::AscendBackendIRFusionOptimization(child_graph); opt::AscendBackendIRFusionOptimization(child_graph);
// select kernel build info // select kernel build info
SelectKernel(*child_graph); SelectKernel(*child_graph);
...@@ -376,12 +392,14 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { ...@@ -376,12 +392,14 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance); MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignStaticMemoryInput(child_graph.get()); runtime_instance->AssignStaticMemoryInput(child_graph.get());
runtime_instance->AssignStaticMemoryValueNode(child_graph.get());
} }
void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *const outputs) { VectorRef *const outputs) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "start";
auto kernel_graph = GetGraph(graph_id); auto kernel_graph = GetGraph(graph_id);
DumpIR("./run_graph.ir", kernel_graph);
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
// if none of child graph and no anf output exists // if none of child graph and no anf output exists
if (!kernel_graph->executable()) { if (!kernel_graph->executable()) {
...@@ -1378,10 +1396,10 @@ void AscendSession::SyncInitialTenosrToDevice() { ...@@ -1378,10 +1396,10 @@ void AscendSession::SyncInitialTenosrToDevice() {
} }
} }
KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
const std::vector<CNodePtr> &list) { const std::vector<CNodePtr> &list) {
MS_EXCEPTION_IF_NULL(new_kernel_graph); MS_EXCEPTION_IF_NULL(new_kernel_graph);
MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id(); MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
// count the output of every anf node // count the output of every anf node
std::set<AnfNodePtr> has_output_nodes; std::set<AnfNodePtr> has_output_nodes;
for (auto &anf_node : list) { for (auto &anf_node : list) {
...@@ -1390,21 +1408,23 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke ...@@ -1390,21 +1408,23 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke
} }
} }
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
std::vector<AnfNodePtr> call_node_inputs;
auto graph_inputs = new_kernel_graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
// create new parameter from cnode // create new parameter from cnode
for (auto &anf_node : list) { for (auto &anf_node : list) {
auto cnode = anf_node->cast<CNodePtr>(); auto cnode = anf_node->cast<CNodePtr>();
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto input = cnode->inputs()[input_idx]; auto input = cnode->inputs()[input_idx];
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
if (!input->isa<CNode>()) { if (input->isa<Parameter>()) {
graph_inputs->push_back(input);
cnode->set_input(input_idx, input); cnode->set_input(input_idx, input);
continue; } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
}
if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
cnode->set_input(input_idx, new_parameter); cnode->set_input(input_idx, new_parameter);
new_kernel_graph->SetRealInput(new_parameter, input);
} }
call_node_inputs.push_back(input);
} }
} }
MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
...@@ -1424,7 +1444,7 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke ...@@ -1424,7 +1444,7 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke
new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs)); new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
} }
MS_LOG(INFO) << "end"; MS_LOG(INFO) << "end";
return new_kernel_graph; return call_node_inputs;
} }
void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) { void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) {
...@@ -1438,7 +1458,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { ...@@ -1438,7 +1458,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto apply_list = GetCNodes(TopoSort(graph->get_return())); auto apply_list = GetCNodes(TopoSort(graph->get_return()));
// update the root graph child graph order // update the root graph child graph order
graph->UpdateChildGraphOrder(); AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph));
// get child list from current graph // get child list from current graph
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list); std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list);
auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr { auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
...@@ -1457,7 +1477,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { ...@@ -1457,7 +1477,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
for (auto &child_graph_node : child_graph_list) { for (auto &child_graph_node : child_graph_list) {
AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get()); AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
} }
ConstructSplitedGraph(child_graph, child_graph_list); auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
auto new_call = graph->NewCNode(new_call_input); auto new_call = graph->NewCNode(new_call_input);
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call); AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
return new_call; return new_call;
...@@ -1466,26 +1487,59 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { ...@@ -1466,26 +1487,59 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
std::list<AnfNodePtr> depend_input = {}; std::list<AnfNodePtr> depend_input = {};
for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) { for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]); auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
MS_EXCEPTION_IF_NULL(call_node);
// if call node is the last call of true graph,no need create child graph after that
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
depend_input.push_front(call_node); depend_input.push_front(call_node);
if (child_graphs.size() == 1 && child_graphs[0] == graph->parent_graph()) {
break;
}
} }
depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())))); depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))));
auto depend = graph->NewCNode(std::vector<AnfNodePtr>(depend_input.begin(), depend_input.end())); auto depend = graph->NewCNode(std::vector<AnfNodePtr>(depend_input.begin(), depend_input.end()));
auto new_return_primitive = auto new_return_primitive =
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))); graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())));
graph->set_return(graph->NewCNode({new_return_primitive, depend})); graph->set_return(graph->NewCNode({new_return_primitive, depend}));
AnfNodePtr pre_call_node = nullptr;
AnfNodePtr cur_call_node = nullptr;
auto iter = depend_input.begin();
for (++iter; iter != depend_input.end(); ++iter) {
pre_call_node = cur_call_node;
cur_call_node = *iter;
if (pre_call_node != nullptr && cur_call_node != nullptr) {
AscendControlParser::InsertControlDependToGraph(NOT_NULL(graph), NOT_NULL(cur_call_node),
NOT_NULL(pre_call_node));
}
}
} }
graph->UpdateChildGraphOrder(); AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph));
UpdateRealInput(graph.get()); UpdateRealInput(graph.get());
auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id())); auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id()));
DumpIR(graph_name, graph); DumpIR(graph_name, graph);
MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end"; MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end";
// recurse to split child graph // recurse to split child graph
for (auto &child_graph : graph->child_graph_order()) { for (auto &child_graph : graph->child_graph_order()) {
if (child_graph != graph->parent_graph()) {
SplitGraph(child_graph); SplitGraph(child_graph);
} }
}
} }
void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); } void AscendSession::LinkChildGraphs(NotNull<KernelGraphPtr> graph) { AscendControlParser::LinkGraph(graph); }
void AscendSession::RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph) {
AscendControlParser::ExecutorValidate(graph);
}
void AscendSession::RecurseCompileGraph(const KernelGraphPtr &graph) {
CompileChildGraph(graph);
for (auto child_graph : graph->child_graph_order()) {
if (child_graph == graph->parent_graph()) {
continue;
}
RecurseCompileGraph(child_graph);
}
}
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore
...@@ -104,10 +104,10 @@ class AscendSession : public SessionBasic { ...@@ -104,10 +104,10 @@ class AscendSession : public SessionBasic {
void SelectKernelGraphKernel(const KernelGraph &graph) {} void SelectKernelGraphKernel(const KernelGraph &graph) {}
void ConvertPredictModel(const KernelGraphPtr graph) {} void ConvertPredictModel(const KernelGraphPtr graph) {}
void HardwareOptimizeGraphs(const KernelGraphPtr graph) {} void HardwareOptimizeGraphs(const KernelGraphPtr graph) {}
void RootGraphExecutorValidate(KernelGraph *graph) {} void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph); std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
KernelGraphPtr ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list); const std::vector<CNodePtr> &list);
void ChildGraphCommunicationDecrease(std::vector<std::vector<AnfNodePtr>> *anf_node_lists); void RecurseCompileGraph(const KernelGraphPtr &graph);
// merge execution order list of child graphs // merge execution order list of child graphs
void MergeGraphExecOrder(); void MergeGraphExecOrder();
......
...@@ -165,6 +165,21 @@ void KernelGraph::SetExecOrderByDefault() { ...@@ -165,6 +165,21 @@ void KernelGraph::SetExecOrderByDefault() {
} }
} }
CheckLoop(); CheckLoop();
// resort start label / end goto
std::vector<CNodePtr> re_order;
if (start_label_ != nullptr) {
re_order.push_back(start_label_);
}
for (auto &node : execution_order_) {
if (node == start_label_ || node == end_goto_) {
continue;
}
re_order.push_back(node);
}
if (end_goto_ != nullptr) {
re_order.push_back(end_goto_);
}
execution_order_ = re_order;
} }
void KernelGraph::CheckLoop() { void KernelGraph::CheckLoop() {
...@@ -360,7 +375,8 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode ...@@ -360,7 +375,8 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode
void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) { void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
MS_EXCEPTION_IF_NULL(old_backend_anf); MS_EXCEPTION_IF_NULL(old_backend_anf);
MS_EXCEPTION_IF_NULL(new_backend_anf); MS_EXCEPTION_IF_NULL(new_backend_anf);
if (old_backend_anf.get() == new_backend_anf.get()) { if (old_backend_anf == new_backend_anf) {
MS_LOG(INFO) << "old:" << old_backend_anf->DebugString() << ",new:" << new_backend_anf->DebugString();
MS_LOG(EXCEPTION) << "old can't be same with new"; MS_LOG(EXCEPTION) << "old can't be same with new";
} }
if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) { if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
...@@ -569,14 +585,13 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf ...@@ -569,14 +585,13 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
MS_EXCEPTION_IF_NULL(new_anf_node); MS_EXCEPTION_IF_NULL(new_anf_node);
MS_EXCEPTION_IF_NULL(inputs_); MS_EXCEPTION_IF_NULL(inputs_);
auto it = node_output_edges_.find(old_anf_node); auto it = node_output_edges_.find(old_anf_node);
if (it == node_output_edges_.end()) { if (it != node_output_edges_.end()) {
MS_LOG(EXCEPTION) << "Can't find anf node in node_output_edges map"; const auto &outputs = it->second;
}
auto &outputs = it->second;
for (auto &output_node : outputs) { for (auto &output_node : outputs) {
MS_EXCEPTION_IF_NULL(output_node.first);
auto output_cnode = output_node.first->cast<CNodePtr>(); auto output_cnode = output_node.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode); MS_EXCEPTION_IF_NULL(output_cnode);
auto &output_node_inputs = output_cnode->inputs(); const auto &output_node_inputs = output_cnode->inputs();
for (size_t i = 1; i < output_node_inputs.size(); i++) { for (size_t i = 1; i < output_node_inputs.size(); i++) {
if (output_node_inputs[i] == old_anf_node) { if (output_node_inputs[i] == old_anf_node) {
output_cnode->set_input(i, new_anf_node); output_cnode->set_input(i, new_anf_node);
...@@ -585,16 +600,37 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf ...@@ -585,16 +600,37 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
// update graph inputs // update graph inputs
for (size_t i = 0; i < inputs_->size(); i++) { for (size_t i = 0; i < inputs_->size(); i++) {
if ((*inputs_)[i] == old_anf_node) { if ((*inputs_)[i] == old_anf_node) {
MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
<< ",new graph input:" << new_anf_node->DebugString();
(*inputs_)[i] = new_anf_node; (*inputs_)[i] = new_anf_node;
break; break;
} }
} }
MS_LOG(INFO) << "Inputs of graph id:" << graph_id();
for (size_t i = 0; i < inputs().size(); i++) {
MS_LOG(INFO) << "[" << i << "]:" << inputs()[i]->DebugString();
}
} }
// update front to backend map // update front to backend map
FrontBackendlMapUpdate(old_anf_node, new_anf_node); FrontBackendlMapUpdate(old_anf_node, new_anf_node);
// update output depend relations // update output depend relations
node_output_edges_[new_anf_node] = it->second; node_output_edges_[new_anf_node] = it->second;
(void)node_output_edges_.erase(old_anf_node); (void)node_output_edges_.erase(old_anf_node);
}
// update graph inputs in child graph
auto it_real_inputs = real_inputs_.find(old_anf_node);
if (it_real_inputs != real_inputs_.end()) {
// insert new parameter to map
auto iter = real_inputs_.find(new_anf_node);
if (iter != real_inputs_.end()) {
MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited.";
iter->second = it_real_inputs->second;
} else {
real_inputs_[new_anf_node] = it_real_inputs->second;
}
// erase old parameter in map
real_inputs_.erase(old_anf_node);
}
} }
void KernelGraph::UpdateExecuteKernelStreamLabel() { void KernelGraph::UpdateExecuteKernelStreamLabel() {
...@@ -603,29 +639,6 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() { ...@@ -603,29 +639,6 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
} }
} }
void KernelGraph::UpdateChildGraphOrder() {
MS_LOG(INFO) << "graph id:" << graph_id_;
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
for (auto &old_child_graph : child_graph_order_) {
old_child_graph->set_parent_graph(nullptr);
}
child_graph_order_.clear();
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
auto call_child_graphs = AnfAlgo ::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
for (const auto &child_graph : call_child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
if (child_graph != parent_graph()) {
child_graph->set_parent_graph(shared_from_this()->cast<std::shared_ptr<KernelGraph>>());
child_graph_order_.push_back(child_graph);
}
}
}
for (size_t i = 0; i < child_graph_order_.size(); i++) {
MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order_[i]->graph_id() << "]";
}
}
std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() { std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order; std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
if (IsLeafGraph()) { if (IsLeafGraph()) {
...@@ -643,9 +656,8 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() { ...@@ -643,9 +656,8 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); } bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const { std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
auto anf_list = TopoSort(get_return());
std::vector<CNodePtr> result; std::vector<CNodePtr> result;
for (const auto &anf : anf_list) { for (const auto &anf : execution_order_) {
if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
result.push_back(anf->cast<CNodePtr>()); result.push_back(anf->cast<CNodePtr>());
} }
...@@ -653,14 +665,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi ...@@ -653,14 +665,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
return result; return result;
} }
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
MS_EXCEPTION_IF_NULL(parameter);
if (real_inputs_.find(parameter) == real_inputs_.end()) {
return {};
}
return real_inputs_[parameter];
}
void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) { void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
...@@ -674,39 +678,43 @@ void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &ar ...@@ -674,39 +678,43 @@ void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &ar
(void)args.insert(arg); (void)args.insert(arg);
} }
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
MS_EXCEPTION_IF_NULL(parameter);
auto iter = real_inputs_.find(parameter);
if (iter != real_inputs_.end()) {
return iter->second;
}
MS_LOG(EXCEPTION) << parameter->DebugString() << " not found.";
}
void KernelGraph::UpdateCallRealInput() { void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "Update graph id: " << graph_id_; MS_LOG(INFO) << "Update graph id: " << graph_id_;
for (auto &it : real_inputs_) { for (auto &it : real_inputs_) {
auto &parameter = it.first; auto &parameter = it.first;
MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(parameter);
auto &real_inputs = it.second; auto &real_inputs = it.second;
std::set<AnfNodePtr> new_real_inputs; std::vector<AnfNodePtr> new_real_inputs;
std::set<AnfNodePtr> erase_real_inputs; std::set<AnfNodePtr> erase_real_inputs;
for (auto &real_input : real_inputs) { for (auto &real_input : real_inputs) {
// if real input is a call node ,find the child graph output act as the new real input // if real input is a call node ,find the child graph output act as the new real input
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0); auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0);
MS_EXCEPTION_IF_NULL(item_with_index.first); MS_EXCEPTION_IF_NULL(item_with_index.first);
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) { if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
<< " erase real input:" << item_with_index.first->DebugString();
(void)erase_real_inputs.insert(item_with_index.first); (void)erase_real_inputs.insert(item_with_index.first);
auto call_node_outputs = GetCallRealOutputs(item_with_index.first); new_real_inputs = GetCallRealOutputs(item_with_index.first);
for (auto &call_node_output : call_node_outputs) {
MS_EXCEPTION_IF_NULL(call_node_output);
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
<< " insert real input:" << call_node_output->DebugString();
(void)new_real_inputs.insert(call_node_output);
}
continue; continue;
} }
}
for (auto &erase_node : erase_real_inputs) { for (auto &erase_node : erase_real_inputs) {
MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString();
(void)real_inputs.erase(erase_node); (void)real_inputs.erase(erase_node);
} }
for (auto &new_real_input : new_real_inputs) { for (auto &new_real_input : new_real_inputs) {
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
<< " insert real input:" << new_real_input->DebugString();
(void)real_inputs.insert(new_real_input); (void)real_inputs.insert(new_real_input);
} }
} }
}
} }
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
......
...@@ -103,10 +103,9 @@ class KernelGraph : public FuncGraph { ...@@ -103,10 +103,9 @@ class KernelGraph : public FuncGraph {
void UpdateExecuteKernelStreamLabel(); void UpdateExecuteKernelStreamLabel();
// calculate the leaf graph order of root graph // calculate the leaf graph order of root graph
std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder(); std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder();
// update the child graph order of graph // the child graph of current graph
void UpdateChildGraphOrder(); const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order() const { return child_graph_order_; }
// get the child graph of current graph void set_child_graph_order(const std::vector<std::shared_ptr<KernelGraph>> &order) { child_graph_order_ = order; }
std::vector<std::shared_ptr<KernelGraph>> child_graph_order() const { return child_graph_order_; }
// checkout whether current graph is leaf graph // checkout whether current graph is leaf graph
bool IsLeafGraph() const; bool IsLeafGraph() const;
...@@ -123,6 +122,7 @@ class KernelGraph : public FuncGraph { ...@@ -123,6 +122,7 @@ class KernelGraph : public FuncGraph {
// find anf node in graph // find anf node in graph
std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const; std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
// get real inputs // get real inputs
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs() const { return real_inputs_; }
std::set<AnfNodePtr> GetRealInput(const AnfNodePtr &parameter); std::set<AnfNodePtr> GetRealInput(const AnfNodePtr &parameter);
void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg); void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
// used to dump ir // used to dump ir
...@@ -132,6 +132,8 @@ class KernelGraph : public FuncGraph { ...@@ -132,6 +132,8 @@ class KernelGraph : public FuncGraph {
void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
CNodePtr get_start_label() { return start_label_; } CNodePtr get_start_label() { return start_label_; }
void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; }
CNodePtr get_end_goto() { return end_goto_; }
private: private:
// remove value node form graph // remove value node form graph
...@@ -185,6 +187,7 @@ class KernelGraph : public FuncGraph { ...@@ -185,6 +187,7 @@ class KernelGraph : public FuncGraph {
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_; std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_;
CNodePtr start_label_; CNodePtr start_label_;
CNodePtr end_goto_;
}; };
} // namespace session } // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
......
...@@ -147,6 +147,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, ...@@ -147,6 +147,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]"; MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]";
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
MS_EXCEPTION_IF_NULL(item_with_index.first); MS_EXCEPTION_IF_NULL(item_with_index.first);
MS_LOG(INFO) << "create tensor for output after visit:" << item_with_index.first->DebugString();
// special handle for maketuple // special handle for maketuple
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
auto cnode = item_with_index.first->cast<CNodePtr>(); auto cnode = item_with_index.first->cast<CNodePtr>();
...@@ -479,31 +480,12 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) ...@@ -479,31 +480,12 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
} }
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->inputs()[input_idx]; auto anf = cnode->input(input_idx);
MS_EXCEPTION_IF_NULL(anf); MS_EXCEPTION_IF_NULL(anf);
// anf has been created before // anf has been created before
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue; continue;
} else if (anf->isa<ValueNode>()) {
if (!IsValueNode<FuncGraph>(anf)) {
// if input is a common value node,
auto new_value_node = CreateNewValueNode(anf, graph);
if (new_value_node != nullptr) {
cnode_inputs.emplace_back(new_value_node);
}
} else {
// if input is a ValueNode<FuncGraph>
auto new_value_node = CreateValueNodeKernelGraph(anf, graph);
if (new_value_node != nullptr) {
cnode_inputs.emplace_back(new_value_node);
}
}
continue;
} else if (anf->isa<Parameter>()) {
auto new_parameter = CreateNewParameter(anf, graph);
cnode_inputs.push_back(new_parameter);
continue;
} }
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
} }
...@@ -613,32 +595,22 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP ...@@ -613,32 +595,22 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
for (const auto &node : node_list) { for (const auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
if (!node->isa<CNode>()) { if (node->isa<Parameter>()) {
MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode"; (void)CreateNewParameter(node, graph.get());
continue;
} else if (node->isa<ValueNode>()) {
if (!IsValueNode<FuncGraph>(node)) {
// if input is a common value node,
(void)CreateNewValueNode(node, graph.get());
} else {
// if input is a ValueNode<FuncGraph>
auto child_graph = ConstructKernelGraph(AnfAlgo::GetValueNodeFuncGraph(node));
auto new_value_node = CreateValueNodeKernelGraph(node, graph.get());
}
continue; continue;
} else { } else {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
// recurse control ops: call, partial
auto attr_input = cnode->input(kAnfPrimitiveIndex);
MS_EXCEPTION_IF_NULL(attr_input);
if (IsValueNode<FuncGraph>(attr_input)) {
// recurse call subgraph
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input);
ConstructKernelGraph(sub_func_graph);
} else if (IsValueNode<Primitive>(attr_input)) {
auto prim = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == kPartialOpName) {
// recurse partial subgraph
auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex);
MS_EXCEPTION_IF_NULL(func_graph_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node);
ConstructKernelGraph(sub_func_graph);
}
}
// create a new cnode object // create a new cnode object
auto new_cnode = CreateNewCNode(cnode, graph.get()); auto new_cnode = CreateNewCNode(cnode, graph.get());
MS_EXCEPTION_IF_NULL(new_cnode); MS_EXCEPTION_IF_NULL(new_cnode);
...@@ -650,7 +622,21 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP ...@@ -650,7 +622,21 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
} }
} }
} }
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
graph_inputs->clear();
for (auto &parameter : func_graph->parameters()) {
MS_EXCEPTION_IF_NULL(parameter);
auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
if (backend_parameter == nullptr) {
// for example "def f(x,y,z) {return x + y}", parameter z in unused
CreateNewParameterFromParameter(parameter, false, graph.get());
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
continue;
}
MS_LOG(INFO) << "graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
graph_inputs->push_back(backend_parameter);
}
MS_EXCEPTION_IF_NULL(context_); MS_EXCEPTION_IF_NULL(context_);
FuncGraphManagerPtr manager = context_->manager(); FuncGraphManagerPtr manager = context_->manager();
if (manager) { if (manager) {
...@@ -716,6 +702,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap ...@@ -716,6 +702,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap
const std::vector<tensor::TensorPtr> &input_tensors) const { const std::vector<tensor::TensorPtr> &input_tensors) const {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(outputs); MS_EXCEPTION_IF_NULL(outputs);
if (!kernel_graph->child_graph_order().empty()) {
// use the last child graph output as the root graph output
UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
return;
}
auto anf_outputs = kernel_graph->outputs(); auto anf_outputs = kernel_graph->outputs();
for (auto &item : anf_outputs) { for (auto &item : anf_outputs) {
MS_LOG(INFO) << "update output[" << item->DebugString() << "]"; MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
......
...@@ -487,8 +487,7 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { ...@@ -487,8 +487,7 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
} }
void TraverseGraphMap( void TraverseGraphMap(
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphSet &fgs,
const FuncGraphSet &fgs,
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(manager_ptr);
MS_EXCEPTION_IF_NULL(tr); MS_EXCEPTION_IF_NULL(tr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册