Commit 00672a47 authored by mindspore-ci-bot, committed by Gitee

!1730 control sink refactor

Merge pull request !1730 from wenchunjiang/fix_code_check
......@@ -15,13 +15,6 @@
"""tbe common"""
import json
import os
from attrdict import AttrDict
class ParamType(AttrDict):
Required = "required"
Dynamic = "dynamic"
Optional = "optional"
class TBEException(Exception):
"""tbe exception class"""
......@@ -112,7 +105,7 @@ def get_input_output(io_info, args):
if len(item) > 1:
arg.append(info)
else:
if info['param_type'] == ParamType.Dynamic:
if info['param_type'] == 'dynamic':
arg.append(info)
args.append(arg)
else:
......
......@@ -542,7 +542,7 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph>
GetNeedActiveStreams(graph_ptr);
MS_LOG(INFO) << "after finish stream assign";
PrintGraphExeOrders(graph_ptr);
graph_ptr->PrintGraphExecuteOrder();
// Get info for D Model
generator::IRModelUtil::GetInstance().set_event_num(total_event_num());
......@@ -810,26 +810,6 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size();
graph_ptr->set_execution_order(exe_orders);
}
void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
MS_EXCEPTION_IF_NULL(graph_ptr);
auto cnode_ptr_list = graph_ptr->execution_order();
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) {
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
MS_LOG(INFO) << "node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id["
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], event_id["
<< GetValue<uint32_t>(primitive->GetAttr(kAttrEventId)) << "]";
} else {
MS_LOG(INFO) << "node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "]";
}
}
}
} // namespace ascend
} // namespace device
} // namespace mindspore
......@@ -87,7 +87,6 @@ class AscendStreamAssign {
void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id);
void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr);
void SetCommonStreamNum(uint32_t cur_stream_id);
......
......@@ -32,7 +32,6 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
namespace mindspore {
namespace session {
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
for (auto &iter : graph_id_map) {
auto &kg = iter.second;
......@@ -356,12 +355,6 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
auto print_vector = [&](std::vector<CNodePtr> vec) -> void {
MS_LOG(INFO) << "graph:" << graph->graph_id() << "execution order";
for (size_t i = 0; i < vec.size(); i++) {
MS_LOG(INFO) << "[" << i << "][" << vec[i]->DebugString() << "]";
}
};
if (memo->find(graph) != memo->end()) {
return {};
}
......@@ -403,7 +396,7 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
}
}
graph->set_execution_order(execution_order);
print_vector(graph->execution_order());
graph->PrintGraphExecuteOrder();
return execution_order;
}
......@@ -474,6 +467,5 @@ void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
}
kg->set_child_graph_order(child_graph_order);
}
} // namespace session
} // namespace mindspore
......@@ -26,7 +26,6 @@
namespace mindspore {
namespace session {
class AscendControlParser {
public:
static void ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map);
......
......@@ -206,39 +206,40 @@ static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePt
return ret;
}
static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args,
KernelGraph *child_graph) {
MS_EXCEPTION_IF_NULL(child_graph);
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
if (args.empty()) {
return;
}
if (parameters.size() != args.size()) {
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
<< " and args size:" << args.size() << " not equal!";
}
child_graph->SetExecOrderByDefault();
for (size_t i = 0; i < parameters.size(); i++) {
if (args[i] == parameters[i]) {
child_graph->SetRealInput(parameters[i], args[i]);
MS_LOG(INFO) << "Parameter and arg are same";
continue;
}
// if the arg is a parameter, then reuse this parameter
if (args[i]->isa<Parameter>()) {
MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id()
<< " reuse parameter:" << args[i]->DebugString()
<< " of graph:" << AnfAlgo::GetGraphId(args[i].get());
child_graph->ReplaceNode(parameters[i], args[i]);
continue;
}
child_graph->SetRealInput(parameters[i], args[i]);
}
}
// if a call has kernel inputs, it's a child graph split from ME, so these kernel inputs should be set as the real
// inputs of the graph. For example, if call input = (prim, graph, kernel1, kernel2), then real_input = [kernel1, kernel2]
static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
auto bind_call_arg_with_parameter = [&](const std::vector<AnfNodePtr> &parameters,
const std::vector<AnfNodePtr> &args, KernelGraph *child_graph) -> void {
MS_EXCEPTION_IF_NULL(child_graph);
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
if (args.empty()) {
return;
}
if (parameters.size() != args.size()) {
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
<< " and args size:" << args.size() << " not equal!";
}
child_graph->SetExecOrderByDefault();
for (size_t i = 0; i < parameters.size(); i++) {
if (args[i] == parameters[i]) {
child_graph->SetRealInput(parameters[i], args[i]);
MS_LOG(INFO) << "Parameter and arg are same";
continue;
}
// if the arg is a parameter, then reuse this parameter
if (args[i]->isa<Parameter>()) {
MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id()
<< " reuse parameter:" << args[i]->DebugString()
<< " of graph:" << AnfAlgo::GetGraphId(args[i].get());
child_graph->ReplaceNode(parameters[i], args[i]);
continue;
}
child_graph->SetRealInput(parameters[i], args[i]);
}
};
for (auto &call_node : call_nodes) {
MS_EXCEPTION_IF_NULL(call_node);
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node);
......@@ -247,7 +248,7 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
std::vector<AnfNodePtr> real_args =
std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end());
std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs();
bind_call_arg_with_parameter(child_inputs, real_args, child_graphs[0].get());
BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get());
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
} else if (child_graphs.size() == 2) {
auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> {
......@@ -264,8 +265,8 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2));
return ret;
};
bind_call_arg_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
bind_call_arg_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
}
}
}
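A minimal standalone sketch of the binding rule described in the comment above, using plain strings in place of AnfNodePtr; the names and types here are illustrative stand-ins, not the MindSpore API. Given a call whose inputs are (prim, graph, kernel1, kernel2), the trailing kernel inputs are bound positionally to the child graph's formal parameters, and an argument that is itself a parameter is reused rather than duplicated.

```cpp
// Toy model of BindCallArgsWithParameter: strings stand in for AnfNodePtr.
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// In this sketch a node counts as "a parameter" if its name starts with "param".
static bool IsParameter(const std::string &node) { return node.rfind("param", 0) == 0; }

// parameters: formal parameters of the child graph.
// args: trailing kernel inputs of the call node, i.e. call = (prim, graph, arg0, arg1, ...).
static void BindCallArgsWithParameter(const std::vector<std::string> &parameters,
                                      const std::vector<std::string> &args,
                                      std::map<std::string, std::string> *real_inputs) {
  if (args.empty()) {
    return;
  }
  if (parameters.size() != args.size()) {
    throw std::runtime_error("parameters size and args size not equal");
  }
  for (size_t i = 0; i < parameters.size(); ++i) {
    if (args[i] != parameters[i] && IsParameter(args[i])) {
      // arg is itself a parameter: the real code replaces the child graph's formal
      // parameter with it (ReplaceNode) instead of recording a separate real input
      (*real_inputs)[parameters[i]] = "reused:" + args[i];
      continue;
    }
    // otherwise the arg is recorded as the real input of the formal parameter
    (*real_inputs)[parameters[i]] = args[i];
  }
}

int main() {
  std::map<std::string, std::string> real_inputs;
  // call input = (prim, graph, kernel1, kernel2) -> real args are kernel1, kernel2
  BindCallArgsWithParameter({"param_x", "param_y"}, {"kernel1", "kernel2"}, &real_inputs);
  for (const auto &kv : real_inputs) {
    std::cout << kv.first << " <- " << kv.second << "\n";
  }
  return 0;
}
```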
......@@ -1429,10 +1430,7 @@ void AscendSession::SyncInitialTenosrToDevice() {
}
}
std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
const std::vector<CNodePtr> &list) {
MS_EXCEPTION_IF_NULL(new_kernel_graph);
MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph, const std::vector<CNodePtr> &list) {
// count the output of every anf node
std::set<AnfNodePtr> has_output_nodes;
for (auto &anf_node : list) {
......@@ -1440,6 +1438,28 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
(void)has_output_nodes.insert(input);
}
}
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
int output_idx = 0;
for (auto &anf_node : list) {
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
new_kernel_graph->set_return(anf_node);
}
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
make_tuple_inputs.push_back(anf_node);
}
}
if (new_kernel_graph->get_return() == nullptr) {
new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
}
}
std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
const std::vector<CNodePtr> &list) {
MS_EXCEPTION_IF_NULL(new_kernel_graph);
MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id();
MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
std::vector<AnfNodePtr> call_node_inputs;
std::vector<AnfNodePtr> new_graph_inputs;
......@@ -1479,22 +1499,9 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
MS_EXCEPTION_IF_NULL(graph_inputs);
graph_inputs->clear();
std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs));
MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
int output_idx = 0;
for (auto &anf_node : list) {
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
new_kernel_graph->set_return(anf_node);
}
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
make_tuple_inputs.push_back(anf_node);
}
}
if (new_kernel_graph->get_return() == nullptr) {
new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
}
ConstructSplitedGraphOutput(new_kernel_graph, list);
MS_LOG(INFO) << "end";
return call_node_inputs;
}
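A minimal standalone sketch of the output rule implemented by the new ConstructSplitedGraphOutput helper, with simplified stand-in types rather than the MindSpore API: every node of the split segment that no other node in the segment consumes is collected as an output, and the real code wraps those outputs in a MakeTuple (unless the segment already contains a Return).

```cpp
// Toy model of ConstructSplitedGraphOutput's output collection.
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  std::vector<std::string> inputs;
};

static std::vector<std::string> CollectOutputs(const std::vector<ToyNode> &list) {
  std::set<std::string> has_output_nodes;  // nodes consumed by some other node in the segment
  for (const auto &node : list) {
    for (const auto &input : node.inputs) {
      has_output_nodes.insert(input);
    }
  }
  std::vector<std::string> make_tuple_inputs;
  for (const auto &node : list) {
    if (has_output_nodes.find(node.name) == has_output_nodes.end()) {
      make_tuple_inputs.push_back(node.name);  // unconsumed node -> graph output
    }
  }
  return make_tuple_inputs;
}

int main() {
  std::vector<ToyNode> segment = {{"add", {"x", "y"}}, {"mul", {"add", "z"}}, {"sub", {"x", "z"}}};
  for (const auto &out : CollectOutputs(segment)) {
    std::cout << "output: " << out << "\n";  // prints mul and sub; add feeds mul, so it is not an output
  }
  return 0;
}
```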
......@@ -1516,6 +1523,30 @@ void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo));
}
AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
const std::vector<CNodePtr> &child_graph_list) {
// if the child graph list only has a call, then return the existing call
if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
return child_graph_list[0];
}
// create new child graph
auto child_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(child_graph);
// create new value node to bind child graph
auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
graph_value_node};
// set the graph id of all node of child graph
for (auto &child_graph_node : child_graph_list) {
AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
}
auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
auto new_call = graph->NewCNode(new_call_input);
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
return new_call;
}
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
......@@ -1523,32 +1554,10 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
AscendControlParser::UpdateChildGraphOrder(graph);
// get child list from current graph
std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
auto bind_new_call_to_new_graph = [&](std::vector<CNodePtr> child_graph_list) -> AnfNodePtr {
// if the child graph list only has a call, then return the existing call
if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
return child_graph_list[0];
}
// create new child graph
auto child_graph = NewKernelGraph();
MS_EXCEPTION_IF_NULL(child_graph);
// create new value node to bind child graph
auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
graph_value_node};
// set the graph id of all node of child graph
for (auto &child_graph_node : child_graph_list) {
AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
}
auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
auto new_call = graph->NewCNode(new_call_input);
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
return new_call;
};
if (child_graph_lists.size() > 1) {
std::list<AnfNodePtr> depend_input = {};
for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]);
auto call_node = BindNewCallToNewGraph(graph, child_graph_lists[call_index]);
MS_EXCEPTION_IF_NULL(call_node);
// if the call node is the last call of the true graph, no need to create a child graph after it
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
......@@ -1605,6 +1614,5 @@ void AscendSession::RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const Not
RecurseCompileGraph(NOT_NULL(child_graph), memo);
}
}
} // namespace session
} // namespace mindspore
......@@ -107,6 +107,7 @@ class AscendSession : public SessionBasic {
const std::vector<CNodePtr> &list);
void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
AnfNodePtr BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, const std::vector<CNodePtr> &child_graph_list);
// merge execution order list of child graphs
void MergeGraphExecOrder();
......
......@@ -735,6 +735,26 @@ void KernelGraph::UpdateCallRealInput() {
real_inputs_ = real_inputs_map;
}
void KernelGraph::PrintGraphExecuteOrder() const {
MS_LOG(INFO) << "graph:" << graph_id_ << "execution order";
for (size_t i = 0; i < execution_order_.size(); i++) {
CNodePtr cur_cnode_ptr = execution_order_[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) {
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
MS_LOG(INFO) << "index[" << i << "], node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id["
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], event_id["
<< GetValue<uint32_t>(primitive->GetAttr(kAttrEventId)) << "], node info["
<< cur_cnode_ptr->DebugString() << "]";
} else {
MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]";
}
}
}
std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
KernelGraph::~KernelGraph() {
......
......@@ -136,6 +136,7 @@ class KernelGraph : public FuncGraph {
CNodePtr get_end_goto() { return end_goto_; }
bool get_output_null() { return null_output_; }
void set_output_null(bool is_output_null) { null_output_ = is_output_null; }
void PrintGraphExecuteOrder() const;
private:
// remove value node form graph
......
......@@ -563,7 +563,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
// if input is a ValueNode<FuncGraph>
FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
MS_LOG(INFO) << "FuncGraph: " << child_graph->ToString() << " has been transformed to KernelGraph.";
is_trace_back = true;
} else {
(void)ConstructKernelGraph(child_graph);
......@@ -587,29 +586,34 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
}
// if a graph jumps back unconditionally, the return op of this graph will never be executed, so the output is null.
graph->set_output_null(is_trace_back);
AddParameterToGraphInputs(func_graph->parameters(), graph.get());
MS_EXCEPTION_IF_NULL(context_);
FuncGraphManagerPtr manager = context_->manager();
if (manager) {
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}
graph->SetExecOrderByDefault();
return graph;
}
void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph);
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
graph_inputs->clear();
for (auto &parameter : func_graph->parameters()) {
for (auto &parameter : parameters) {
MS_EXCEPTION_IF_NULL(parameter);
auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
if (backend_parameter == nullptr) {
// for example, in "def f(x, y, z) { return x + y }", parameter z is unused
CreateNewParameterFromParameter(parameter, false, graph.get());
CreateNewParameterFromParameter(parameter, false, graph);
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
continue;
}
MS_LOG(INFO) << "graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
graph_inputs->push_back(backend_parameter);
}
MS_EXCEPTION_IF_NULL(context_);
FuncGraphManagerPtr manager = context_->manager();
if (manager) {
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}
graph->SetExecOrderByDefault();
return graph;
}
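A minimal standalone sketch of the lookup used by the new AddParameterToGraphInputs helper, with toy types in place of the MindSpore graph classes: each front-end parameter is resolved to its backend parameter, and a parameter with no backend node yet (such as the unused z in the comment's "def f(x, y, z)" example) is handled by creating a new one, which the real code does via CreateNewParameterFromParameter.

```cpp
// Toy model of AddParameterToGraphInputs: strings stand in for AnfNodePtr.
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Front-end parameters of the function being converted.
  std::vector<std::string> parameters = {"x", "y", "z"};
  // Front-end -> backend mapping built while converting the graph body; z is unused,
  // so no backend node exists for it yet.
  std::map<std::string, std::string> front_to_backend = {{"x", "backend_x"}, {"y", "backend_y"}};

  std::vector<std::string> graph_inputs;
  for (const auto &parameter : parameters) {
    auto it = front_to_backend.find(parameter);
    if (it == front_to_backend.end()) {
      // No backend parameter yet: the real code creates one here
      // (CreateNewParameterFromParameter) and moves on to the next parameter.
      std::cout << "Can't find parameter: " << parameter << "\n";
      continue;
    }
    graph_inputs.push_back(it->second);  // collected as a graph input
  }
  for (const auto &input : graph_inputs) {
    std::cout << "graph input: " << input << "\n";  // backend_x, backend_y
  }
  return 0;
}
```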
// run graph steps
......
......@@ -118,6 +118,7 @@ class SessionBasic {
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
......