提交 ea65e61c 编写于 作者: W wenchunjiang

adding new constructKernelGraph to transform all subgraphs to

kernel_graph from root func_graph
上级 92d196f0
......@@ -909,5 +909,15 @@ bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
auto kernel_name = AnfAlgo::GetCNodeName(node);
return kernel_name == kGetNextOpName;
}
FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
auto func_graph = value->cast<FuncGraphPtr>();
return func_graph;
}
} // namespace session
} // namespace mindspore
......@@ -181,6 +181,7 @@ class AnfRuntimeAlgorithm {
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;
......
......@@ -451,6 +451,126 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
return new_cnode;
}
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> cnode_inputs;
auto attr_input = cnode->input(kAnfPrimitiveIndex);
MS_EXCEPTION_IF_NULL(attr_input);
if (IsValueNode<FuncGraph>(attr_input)) {
// create primitive of cnode:call
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
// create a ValueNode<KernelGraph> as input of cnode:call
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
} else {
auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
if (new_value_node != nullptr) {
cnode_inputs.emplace_back(new_value_node);
}
}
} else if (attr_input->isa<CNode>()) {
// create primitive of cnode:call(switch)
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kCallOpName))};
if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
auto prim = GetCNodePrimitive(cnode_input);
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() != kSwitchOpName) {
MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
}
cnode_inputs.emplace_back(cnode_input);
} else {
MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
<< ", but input[0] has not been created.";
}
} else {
// get primitive of old node
auto prim = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
// push attr to inputs[0] of new cnode
cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
}
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto anf = cnode->inputs()[input_idx];
MS_EXCEPTION_IF_NULL(anf);
// anf has been created before
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (anf->isa<ValueNode>()) {
if (!IsValueNode<FuncGraph>(anf)) {
// if input is a common value node,
auto new_value_node = CreateNewValueNode(anf, graph);
if (new_value_node != nullptr) {
cnode_inputs.emplace_back(new_value_node);
}
} else {
// if input is a ValueNode<FuncGraph>
auto new_value_node = CreateValueNodeKernelGraph(anf, graph);
if (new_value_node != nullptr) {
cnode_inputs.emplace_back(new_value_node);
}
}
continue;
} else if (anf->isa<Parameter>()) {
auto new_parameter = CreateNewParameter(anf, graph);
cnode_inputs.push_back(new_parameter);
continue;
}
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
}
TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
auto new_cnode = graph->NewCNode(cnode_inputs);
TraceManager::EndTrace();
return new_cnode;
}
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
auto value_node = anf->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
MS_EXCEPTION_IF_NULL(sub_func_graph);
if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
}
auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
new_value_node->set_abstract(value_node->abstract());
// create new kernel_info of new value_node
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->SetFeatureMapFlag(false);
new_value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
graph->FrontBackendlMapAdd(anf, new_value_node);
graph->AddValueNodeToGraph(new_value_node);
return new_value_node;
}
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(anf);
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
graph_inputs->push_back(new_parameter);
graph->FrontBackendlMapAdd(anf, new_parameter);
return new_parameter;
}
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
auto graph = NewKernelGraph();
......@@ -494,7 +614,69 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
return graph;
}
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &) { return nullptr; }
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) {
MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph.";
return front_backend_graph_map_[func_graph];
}
auto node_list = TopoSort(func_graph->get_return());
auto graph = NewKernelGraph();
front_backend_graph_map_[func_graph] = graph;
MS_LOG(INFO) << "Create graph: " << graph->graph_id();
for (const auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
if (!node->isa<CNode>()) {
MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode";
continue;
} else {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// recurse control ops: call, partial
auto attr_input = cnode->input(kAnfPrimitiveIndex);
MS_EXCEPTION_IF_NULL(attr_input);
if (IsValueNode<FuncGraph>(attr_input)) {
// recurse call subgraph
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input);
ConstructKernelGraph(sub_func_graph);
} else if (IsValueNode<Primitive>(attr_input)) {
auto prim = GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == kPartialOpName) {
// recurse partial subgraph
auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex);
MS_EXCEPTION_IF_NULL(func_graph_node);
auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node);
ConstructKernelGraph(sub_func_graph);
}
}
// create a new cnode object
auto new_cnode = CreateNewCNode(cnode, graph.get());
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
graph->FrontBackendlMapAdd(node, new_cnode);
// set original return to kernel_graph
if (IsPrimitive(new_cnode->input(kAnfPrimitiveIndex), prim::kPrimReturn)) {
graph->set_return(new_cnode);
}
}
}
MS_EXCEPTION_IF_NULL(context_);
FuncGraphManagerPtr manager = context_->manager();
if (manager) {
manager->AddFuncGraph(graph);
graph->set_manager(manager);
}
graph->SetExecOrderByDefault();
return graph;
}
// run graph steps
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
......
......@@ -78,6 +78,7 @@ class SessionBasic {
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
// set parameters of final graph
virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
......@@ -111,9 +112,12 @@ class SessionBasic {
// create a new kernel graph and update the graph sum
KernelGraphPtr NewKernelGraph();
ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_;
std::shared_ptr<Context> context_;
CallBackFunc summary_callback_;
static GraphId graph_sum_;
......
......@@ -141,6 +141,9 @@ constexpr auto kLabelSetOpName = "LabelSet";
constexpr auto kLabelSwitchOpName = "LabelSwitch";
constexpr auto kLabelGotoOpName = "LabelGoto";
constexpr auto kBNInferGradOpName = "BNInferGrad";
constexpr auto kCallOpName = "call";
constexpr auto kPartialOpName = "partial";
constexpr auto kSwitchOpName = "switch";
// attr key name
constexpr auto kAttrInputNames = "input_names";
......@@ -196,6 +199,7 @@ const size_t kMemAlignSize = 512;
// define special index in special node
constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kAnfPartialFuncGraphIndex = 1;
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
constexpr auto kTupleGetItemInputSize = 3;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册