提交 d953b2b5 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1483 Clean code of pr1459

Merge pull request !1483 from zhoufeng/code-clean
......@@ -19,6 +19,17 @@
#include "session/ascend_control_parser.h"
#include "session/anf_runtime_algorithm.h"
static constexpr size_t kCNodePrim = 0;
static constexpr size_t kCNodeCallArg = 1;
static constexpr size_t kCNodeSwitchCond = 1;
static constexpr size_t kCNodeSwitchTrue = 2;
static constexpr size_t kCNodeSwitchFalse = 3;
static constexpr size_t kCNodeSwitchLength = 4;
static constexpr size_t kCNodePartialLength = 2;
static constexpr size_t kCNodePartialFunc = 1;
static constexpr size_t kCNodeSwitchLayerBranch = 2;
static constexpr size_t kCNodeSwitchLayerLength = 3;
namespace mindspore {
namespace session {
......@@ -61,7 +72,7 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
ChildGraphDataAssign(graph_id_map);
}
CNodePtr AscendControlParser::GetNextRealKernel(std::vector<CNodePtr> list, size_t start) {
CNodePtr AscendControlParser::GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
for (size_t i = start; i < list.size() - 1; ++i) {
if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
return list[i];
......@@ -83,11 +94,11 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
memo->insert(kg.get());
// 2. args replace placeholder
LinkParentGraph(kg, last_node, last_label, memo);
LinkParentGraph(kg, last_node, last_label);
// 3. topological sort
kg->SetExecOrderByDefault();
std::vector<CNodePtr> nodes = kg->execution_order();
const std::vector<CNodePtr> &nodes = kg->execution_order();
if (nodes.empty()) {
MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!";
}
......@@ -149,9 +160,9 @@ void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg,
}
void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo) {
const CNodePtr &last_label) {
auto origin_return = kg->get_return();
std::vector<AnfNodePtr> origin_return_inputs = origin_return->inputs();
const std::vector<AnfNodePtr> &origin_return_inputs = origin_return->inputs();
// if entry graph, replace return with make_tuple
if (from_graph_call_node == nullptr || last_label == nullptr) {
MS_LOG(INFO) << kg->ToString() << " is entry graph.";
......@@ -173,7 +184,7 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
MS_LOG(INFO) << "process call func " << cur_node->DebugString();
// 1 get kernel graph
auto origin_inputs = cur_node->inputs();
const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
......@@ -217,15 +228,14 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
}
// 3 recurse sub graph
auto origin_switch_inputs = cur_node->inputs();
const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]};
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
// 3.1 branch kernel graph and args
CNodePtr partial;
KernelGraphPtr branch_fg;
std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
// 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label);
......@@ -249,9 +259,9 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
MS_EXCEPTION_IF_NULL(branch_tuple);
if (!branch_tuple->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode";
}
auto branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
const std::vector<AnfNodePtr> &branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
// 1 return label
auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
// 2 add depend relationship
......@@ -260,15 +270,14 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
}
// 3 recurse sub graph
auto origin_switch_inputs = cur_node->inputs();
const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]};
for (size_t i = 0; i < branch_partial.size(); ++i) {
// 3.1 branch kernel graph and args
CNodePtr partial;
KernelGraphPtr branch_fg;
std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
// 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label);
......@@ -315,18 +324,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
InsertDependToGraph(kg, NOT_NULL(assign_node));
}
NotNull<AnfNodePtr> AscendControlParser::GetRealInput(NotNull<KernelGraphPtr> from_graph,
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> param) {
std::set<AnfNodePtr> args_list = to_graph->GetRealInput(param);
for (auto arg : args_list) {
if (arg->func_graph() == from_graph.get()) {
return NOT_NULL(arg);
}
}
MS_LOG(EXCEPTION) << to_graph->ToString() << " input " << param->DebugString() << " not from "
<< from_graph->ToString();
}
void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph,
NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param) {
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) {
......@@ -369,10 +366,10 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
return {};
}
memo->insert(graph.get());
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
graph->SetExecOrderByDefault();
std::vector<CNodePtr> cnodes = graph->execution_order();
const std::vector<CNodePtr> &cnodes = graph->execution_order();
std::map<uint32_t, CNodePtr> label_map;
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map;
std::tie(label_map, label_switch_map) = GetLabelNode(cnodes);
......@@ -388,10 +385,10 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
std::find_if(label_map.begin(), label_map.end(),
[node](const std::map<uint32_t, CNodePtr>::value_type iter) { return iter.second == node; });
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
if (!CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) {
if (label_iter == label_map.end() || !CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
auto child_graph = graph->child_graph_order()[label_iter->first];
auto child_graph = child_graph_order[label_iter->first];
if (child_graph == graph->parent_graph()) {
continue;
}
......@@ -407,7 +404,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
auto child_graph = graph->child_graph_order()[label_iter->first + i];
auto child_graph = child_graph_order[label_iter->first + i];
if (child_graph == graph->parent_graph()) {
continue;
}
......@@ -426,10 +423,11 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
NotNull<KernelGraphPtr> graph) {
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
// check index and child order size
if (graph->child_graph_order().size() <= static_cast<size_t>(order_index)) {
if (child_graph_order.size() <= IntToSize(order_index)) {
MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size "
<< graph->child_graph_order().size() << " goto index " << order_index;
<< child_graph_order.size() << " goto index " << order_index;
}
if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) {
......@@ -443,7 +441,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
label_index = label_goto_index;
}
// get start_label_set_index of child graph
auto child_graph = graph->child_graph_order()[order_index];
auto child_graph = child_graph_order[order_index];
MS_EXCEPTION_IF_NULL(child_graph);
auto start_label_set = child_graph->get_start_label();
if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) {
......@@ -468,8 +466,7 @@ std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t
uint32_t index = 0;
for (auto &node : nodes) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
label_map[index] = node;
++index;
label_map[index++] = node;
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
......@@ -479,8 +476,7 @@ std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t
std::vector<uint32_t> label_list = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
label_switch_map.insert({node, label_list});
for (size_t i = 0; i < label_list.size(); ++i) {
label_map[index] = node;
++index;
label_map[index++] = node;
}
}
}
......
......@@ -49,16 +49,15 @@ class AscendControlParser {
NotNull<std::set<KernelGraphPtr> *> memo);
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label, NotNull<std::set<KernelGraphPtr> *> memo);
const CNodePtr &last_label);
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
static void LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph,
NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param);
static NotNull<AnfNodePtr> GetRealInput(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
NotNull<AnfNodePtr> param);
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static CNodePtr GetNextRealKernel(std::vector<CNodePtr> list, size_t start);
static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start);
// root graph order
static std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> GetLabelNode(
......@@ -67,20 +66,7 @@ class AscendControlParser {
NotNull<KernelGraphPtr> graph);
static std::vector<CNodePtr> RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto,
NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
static constexpr size_t kCNodePrim = 0;
static constexpr size_t kCNodeCallArg = 1;
static constexpr size_t kCNodeSwitchCond = 1;
static constexpr size_t kCNodeSwitchTrue = 2;
static constexpr size_t kCNodeSwitchFalse = 3;
static constexpr size_t kCNodeSwitchLength = 4;
static constexpr size_t kCNodePartialLength = 2;
static constexpr size_t kCNodePartialFunc = 1;
static constexpr size_t kCNodeSwitchLayerCond = 1;
static constexpr size_t kCNodeSwitchLayerBranch = 2;
static constexpr size_t kCNodeSwitchLayerLength = 3;
};
} // namespace session
} // namespace mindspore
......
......@@ -256,7 +256,6 @@ static void UpdateRealInput(KernelGraph *graph) {
void RecurseToUpdateCallRealInput(KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "start graph id:" << graph->graph_id();
graph->UpdateCallRealInput();
for (auto &child_graph : graph->child_graph_order()) {
if (child_graph == graph->parent_graph()) {
MS_LOG(INFO) << "Child graph:" << child_graph->graph_id()
......@@ -265,6 +264,8 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) {
}
RecurseToUpdateCallRealInput(child_graph.get());
}
// this action should from bottom to top
graph->UpdateCallRealInput();
}
} // namespace
......@@ -280,27 +281,20 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
MS_LOG(INFO) << "start";
auto graph = ConstructKernelGraph(func_graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// split switch
SplitGraphs(graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// insert goto labels and label_sets
LinkChildGraphs(NOT_NULL(graph));
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// resource initialize
InitRuntimeResource();
// assign label
AssignLabel(NOT_NULL(graph));
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// recurse compile child graph
RecurseCompileGraph(graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// root graph valiate,include genearte execute order and so on
RootGraphExecutorValidate(NOT_NULL(graph));
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// adjust kernel
AdjustKernel(graph);
MS_LOG(INFO) << "graph input size:" << graph->inputs().size();
// assign stream
AssignStream(graph);
// build kernel
......@@ -313,7 +307,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
LoadTask(graph);
// return the graph id to backend
auto graph_id = graph->graph_id();
MS_LOG(INFO) << "Compile graph " << graph_id << " success";
return graph_id;
}
......
......@@ -606,10 +606,6 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
break;
}
}
MS_LOG(INFO) << "Inputs of graph id:" << graph_id();
for (size_t i = 0; i < inputs().size(); i++) {
MS_LOG(INFO) << "[" << i << "]:" << inputs()[i]->DebugString();
}
}
// update front to backend map
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
......@@ -713,6 +709,9 @@ void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
<< " insert real input:" << new_real_input->DebugString();
(void)real_inputs.insert(new_real_input);
if (new_real_input->isa<Parameter>()) {
ReplaceNode(parameter, new_real_input);
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册