提交 9fe6074c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1384 [control sink refactor]Update real input if it is a call

Merge pull request !1384 from chenfei_mindspore/sort-call-node
......@@ -942,7 +942,6 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
} else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
auto switch_node = input1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(switch_node);
MS_LOG(INFO) << "switch : " << switch_node->DebugString();
auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
auto partial = switch_node->input(input_index);
MS_EXCEPTION_IF_NULL(partial);
......@@ -950,7 +949,6 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
MS_EXCEPTION_IF_NULL(partial_cnode);
auto graph_node = partial_cnode->input(1);
MS_EXCEPTION_IF_NULL(graph_node);
MS_LOG(INFO) << graph_node->DebugString();
auto graph_value_node = graph_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(graph_value_node);
auto graph_value = graph_value_node->value();
......@@ -976,5 +974,17 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
}
MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
}
bool AnfRuntimeAlgorithm::IsWhileTrueGraph(const KernelGraphPtr &child_graph) {
auto call_nodes = child_graph->FindNodeByPrimitive(prim::kPrimCall);
for (const auto &call_node : call_nodes) {
auto graphs = GetCallNodeKernelGraph(call_node);
if (graphs.size() == 1 && graphs[0] == child_graph->parent_graph()) {
return true;
}
}
return false;
}
} // namespace session
} // namespace mindspore
......@@ -185,6 +185,7 @@ class AnfRuntimeAlgorithm {
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
static std::vector<KernelGraphPtr> GetCallNodeKernelGraph(const CNodePtr &call_node);
static bool IsSwitchCall(const CNodePtr &call_node);
static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;
......
......@@ -18,6 +18,7 @@
#include <map>
#include <tuple>
#include <set>
#include <list>
#include "operator/ops.h"
#include "ir/meta_tensor.h"
#include "ir/anf.h"
......@@ -160,7 +161,7 @@ void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) {
std::vector<CNodePtr> cnodes = {};
size_t i = 0;
for (auto anf : anf_nodes) {
for (const auto &anf : anf_nodes) {
MS_LOG(INFO) << "apply_list[" << i++ << "] = " << anf->DebugString();
MS_EXCEPTION_IF_NULL(anf);
if (anf->isa<CNode>()) {
......@@ -192,6 +193,8 @@ std::vector<std::vector<CNodePtr>> GetChildList(const KernelGraph &cur_graph, co
return ret;
}
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
void UpdateRealInput(KernelGraph *graph) {
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
auto bind_call_partial_with_parameter = [&](const std::vector<AnfNodePtr> &parameters,
......@@ -239,6 +242,15 @@ void UpdateRealInput(KernelGraph *graph) {
}
}
}
void RecurseToUpdateCallRealInput(KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "start graph id:" << graph->graph_id();
graph->UpdateCallRealInput();
for (auto &child_graph : graph->child_graph_order()) {
RecurseToUpdateCallRealInput(child_graph.get());
}
}
} // namespace
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
......@@ -254,7 +266,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph);
// split switch
SplitGraph(graph);
SplitGraphs(graph);
// insert goto labels and label_sets
LinkChildGraphs(NOT_NULL(graph));
// resource initialize
......@@ -1366,7 +1378,7 @@ void AscendSession::SyncInitialTenosrToDevice() {
}
}
KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_graph,
KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
const std::vector<CNodePtr> &list) {
MS_EXCEPTION_IF_NULL(new_kernel_graph);
MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id();
......@@ -1376,9 +1388,6 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
for (auto &input : anf_node->inputs()) {
(void)has_output_nodes.insert(input);
}
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
new_kernel_graph->set_return(anf_node->cast<CNodePtr>());
}
}
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
// create new parameter from cnode
......@@ -1386,6 +1395,7 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
auto cnode = anf_node->cast<CNodePtr>();
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto input = cnode->inputs()[input_idx];
MS_EXCEPTION_IF_NULL(input);
if (!input->isa<CNode>()) {
cnode->set_input(input_idx, input);
continue;
......@@ -1417,6 +1427,12 @@ KernelGraphPtr AscendSession::SplitKernelGraph(const KernelGraphPtr &new_kernel_
return new_kernel_graph;
}
void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) {
SplitGraph(root_graph);
// replace the real input if the real input is a call
RecurseToUpdateCallRealInput(root_graph.get());
}
void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
MS_EXCEPTION_IF_NULL(graph);
......@@ -1426,6 +1442,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
// get child list from current graph
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(*graph, apply_list);
auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
// if child graph list only has a call ,then return the exist call
if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
return child_graph_list[0];
}
......@@ -1440,22 +1457,22 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) {
for (auto &child_graph_node : child_graph_list) {
AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
}
SplitKernelGraph(child_graph, child_graph_list);
ConstructSplitedGraph(child_graph, child_graph_list);
auto new_call = graph->NewCNode(new_call_input);
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
return new_call;
};
if (child_graph_lists.size() > 1) {
std::list<AnfNodePtr> depend_input = {};
for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
if (call_index == 0) {
depend_input.push_front(call_node);
}
depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))));
auto depend = graph->NewCNode(std::vector<AnfNodePtr>(depend_input.begin(), depend_input.end()));
auto new_return_primitive =
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())));
graph->set_return(graph->NewCNode({new_return_primitive, call_node}));
continue;
}
InsertDependToGraph(graph->graph_id(), call_node);
}
graph->set_return(graph->NewCNode({new_return_primitive, depend}));
}
graph->UpdateChildGraphOrder();
UpdateRealInput(graph.get());
......
......@@ -97,15 +97,16 @@ class AscendSession : public SessionBasic {
void SetFinalGraphOutput(const VectorRef &vec_output);
void SplitGraph(const KernelGraphPtr &graph);
// split graphs with recurse from root graph
void SplitGraphs(const KernelGraphPtr &root_graph);
void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
void IRFusion(const KernelGraphPtr &graph) {}
void SelectKernelGraphKernel(const KernelGraph &graph) {}
void ConvertPredictModel(const KernelGraphPtr graph) {}
void HardwareOptimizeGraphs(const KernelGraphPtr graph) {}
void RootGraphExecutorValidate(KernelGraph *graph) {}
void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph);
KernelGraphPtr SplitKernelGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list);
KernelGraphPtr ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list);
void ChildGraphCommunicationDecrease(std::vector<std::vector<AnfNodePtr>> *anf_node_lists);
// merge execution order list of child graphs
......
......@@ -39,16 +39,35 @@ void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
MS_LOG(DEBUG) << "Push que:" << node->DebugString();
}
}
std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(call_node, 0);
MS_EXCEPTION_IF_NULL(item_with_index.first);
if (!AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
return {item_with_index.first};
}
std::vector<AnfNodePtr> real_inputs;
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(item_with_index.first->cast<CNodePtr>());
for (const auto &child_graph : child_graphs) {
if (AnfAlgo::IsWhileTrueGraph(child_graph)) {
continue;
}
auto real_input = child_graph->output();
auto child_real_inputs = GetCallRealOutputs(real_input);
std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs));
}
return real_inputs;
}
} // namespace
std::vector<AnfNodePtr> KernelGraph::outputs() const {
MS_EXCEPTION_IF_NULL(output());
if (IsPrimitiveCNode(output(), prim::kPrimMakeTuple)) {
auto graph_output = output();
if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
auto make_tuple = output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
auto &inputs = make_tuple->inputs();
return std::vector<AnfNodePtr>(inputs.begin() + 1, inputs.end());
}
return std::vector<AnfNodePtr>();
return std::vector<AnfNodePtr>(1, graph_output);
}
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
......@@ -587,6 +606,9 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
void KernelGraph::UpdateChildGraphOrder() {
MS_LOG(INFO) << "graph id:" << graph_id_;
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
for (auto &old_child_graph : child_graph_order_) {
old_child_graph->set_parent_graph(nullptr);
}
child_graph_order_.clear();
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
......@@ -640,6 +662,9 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
}
void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(arg);
MS_LOG(INFO) << "parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString();
MS_EXCEPTION_IF_NULL(parameter);
MS_EXCEPTION_IF_NULL(arg);
if (real_inputs_.find(parameter) == real_inputs_.end()) {
......@@ -649,6 +674,41 @@ void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &ar
(void)args.insert(arg);
}
void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "Update graph id: " << graph_id_;
for (auto &it : real_inputs_) {
auto &parameter = it.first;
MS_EXCEPTION_IF_NULL(parameter);
auto &real_inputs = it.second;
std::set<AnfNodePtr> new_real_inputs;
std::set<AnfNodePtr> erase_real_inputs;
for (auto &real_input : real_inputs) {
// if real input is a call node ,find the child graph output act as the new real input
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0);
MS_EXCEPTION_IF_NULL(item_with_index.first);
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
<< " erase real input:" << item_with_index.first->DebugString();
(void)erase_real_inputs.insert(item_with_index.first);
auto call_node_outputs = GetCallRealOutputs(item_with_index.first);
for (auto &call_node_output : call_node_outputs) {
MS_EXCEPTION_IF_NULL(call_node_output);
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
<< " insert real input:" << call_node_output->DebugString();
(void)new_real_inputs.insert(call_node_output);
}
continue;
}
for (auto &erase_node : erase_real_inputs) {
(void)real_inputs.erase(erase_node);
}
for (auto &new_real_input : new_real_inputs) {
(void)real_inputs.insert(new_real_input);
}
}
}
}
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
} // namespace session
} // namespace mindspore
......@@ -127,6 +127,8 @@ class KernelGraph : public FuncGraph {
void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
// used to dump ir
std::string ToString() const override;
// update the real input if the node is a call
void UpdateCallRealInput();
void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
CNodePtr get_start_label() { return start_label_; }
......
......@@ -640,16 +640,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL(func_graph_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node);
ConstructKernelGraph(sub_func_graph);
} else if (prim->name() == kReturnOpName) {
std::vector<AnfNodePtr> outputs;
auto inputs = cnode->inputs();
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "CNode[return] must have two inputs at least, actual inputs size is " << inputs.size();
}
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outputs));
// add a make_tuple before return as graph output
graph->set_output(ConstructOutput(outputs, graph));
continue;
}
}
......@@ -659,6 +649,9 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
graph->FrontBackendlMapAdd(node, new_cnode);
if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
graph->set_return(new_cnode);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册