提交 34214e8f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3946 Ignore create parameter from control depend inputs

Merge pull request !3946 from YuJianfeng/master
...@@ -288,6 +288,22 @@ bool ExistSummaryNode(const KernelGraph *graph) { ...@@ -288,6 +288,22 @@ bool ExistSummaryNode(const KernelGraph *graph) {
} }
return false; return false;
} }
bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &node_inputs = cnode->inputs();
for (size_t i = 1; i < node_inputs.size(); ++i) {
if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) {
return false;
}
}
return true;
}
} // namespace } // namespace
GraphId SessionBasic::graph_sum_ = 0; GraphId SessionBasic::graph_sum_ = 0;
...@@ -354,8 +370,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr ...@@ -354,8 +370,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> parameters; std::vector<AnfNodePtr> parameters;
std::vector<AnfNodePtr> pre_graph_out = {node}; std::vector<AnfNodePtr> pre_graph_out = {node};
if (IgnoreCreateParameterForMakeTuple(node)) {
pre_graph_out.clear();
}
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
if (!AnfAlgo::IsRealKernel(node)) { if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
} }
auto valid_inputs = graph->MutableValidInputs(); auto valid_inputs = graph->MutableValidInputs();
...@@ -431,7 +450,8 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool ...@@ -431,7 +450,8 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, valid_input, graph); auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
if (parameters.empty()) { if (parameters.empty()) {
MS_LOG(EXCEPTION) << "No parameter exist!!"; MS_LOG(INFO) << "Empty parameter from cnode";
return nullptr;
} }
if (parameters.size() == 1) { if (parameters.size() == 1) {
return parameters[0]; return parameters[0];
...@@ -505,11 +525,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K ...@@ -505,11 +525,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
continue; continue;
} else if (optimize_control_depend) { } else if (optimize_control_depend) {
cnode_inputs.push_back(NewValueNode(MakeValue(input_idx))); cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
} else { } else {
*from_other_graph = true; *from_other_graph = true;
// the input node is a cnode from other graph // the input node is a cnode from other graph
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
if (parameter_from_cnode == nullptr) {
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx)));
}
cnode_inputs.push_back(parameter_from_cnode); cnode_inputs.push_back(parameter_from_cnode);
(*other_graph_cnode)[anf] = parameter_from_cnode; (*other_graph_cnode)[anf] = parameter_from_cnode;
} }
...@@ -878,7 +901,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap ...@@ -878,7 +901,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
auto tensor = inputs[i]; auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
auto input_node = input_nodes[i]; auto input_node = input_nodes[i];
if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
if (ms_context->execution_mode() == kPynativeMode || if (ms_context->execution_mode() == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) { AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
......
...@@ -79,6 +79,42 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, c ...@@ -79,6 +79,42 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, c
return output; return output;
} }
namespace {
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr,
AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(inputs_ptr);
MS_EXCEPTION_IF_NULL(eqv_ptr);
MS_EXCEPTION_IF_NULL(node);
auto &inputs = *inputs_ptr;
auto &eqv = *eqv_ptr;
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
eqv[node] = node;
} else if (eqv.find(node) == eqv.end()) {
bool ignore_make_tuple = false;
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
ignore_make_tuple = true;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &node_inputs = cnode->inputs();
for (size_t i = 1; i < node_inputs.size(); ++i) {
if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) {
ignore_make_tuple = false;
break;
}
}
}
if (!ignore_make_tuple) {
inputs.push_back(node);
}
eqv[node] = fg->add_parameter();
eqv[node]->set_abstract(node->abstract());
eqv[node]->set_kernel_info(node->kernel_info_ptr());
}
return eqv[node];
}
} // namespace
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
auto fg = std::make_shared<FuncGraph>(); auto fg = std::make_shared<FuncGraph>();
AnfNodePtrList inputs; AnfNodePtrList inputs;
...@@ -86,17 +122,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr ...@@ -86,17 +122,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
if (lst.empty()) { if (lst.empty()) {
MS_LOG(EXCEPTION) << "Input anf node list is empty"; MS_LOG(EXCEPTION) << "Input anf node list is empty";
} }
auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr {
if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) {
eqv[a] = a;
} else if (eqv.find(a) == eqv.end()) {
inputs.push_back(a);
eqv[a] = fg->add_parameter();
eqv[a]->set_abstract(a->abstract());
eqv[a]->set_kernel_info(a->kernel_info_ptr());
}
return eqv[a];
};
// Merge CNodes into a AnfGraph that represents a linear instruction segment // Merge CNodes into a AnfGraph that represents a linear instruction segment
for (auto n : lst) { for (auto n : lst) {
if (!n->isa<CNode>()) { if (!n->isa<CNode>()) {
...@@ -122,11 +147,12 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr ...@@ -122,11 +147,12 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) { if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {
args.emplace_back(NewValueNode(MakeValue(i))); args.emplace_back(NewValueNode(MakeValue(i)));
} else { } else {
args.emplace_back(ref(inps[i])); args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv));
} }
} }
} else { } else {
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref); (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
} }
eqv[n] = fg->NewCNode(args); eqv[n] = fg->NewCNode(args);
eqv[n]->set_abstract(n->abstract()); eqv[n]->set_abstract(n->abstract());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册