提交 34214e8f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3946 Ignore create parameter from control depend inputs

Merge pull request !3946 from YuJianfeng/master
......@@ -288,6 +288,22 @@ bool ExistSummaryNode(const KernelGraph *graph) {
}
return false;
}
bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &node_inputs = cnode->inputs();
for (size_t i = 1; i < node_inputs.size(); ++i) {
if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) {
return false;
}
}
return true;
}
} // namespace
GraphId SessionBasic::graph_sum_ = 0;
......@@ -354,8 +370,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> parameters;
std::vector<AnfNodePtr> pre_graph_out = {node};
if (IgnoreCreateParameterForMakeTuple(node)) {
pre_graph_out.clear();
}
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
if (!AnfAlgo::IsRealKernel(node)) {
if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
}
auto valid_inputs = graph->MutableValidInputs();
......@@ -431,7 +450,8 @@ AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool
MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
if (parameters.empty()) {
MS_LOG(EXCEPTION) << "No parameter exist!!";
MS_LOG(INFO) << "Empty parameter from cnode";
return nullptr;
}
if (parameters.size() == 1) {
return parameters[0];
......@@ -505,11 +525,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
continue;
} else if (optimize_control_depend) {
cnode_inputs.push_back(NewValueNode(MakeValue(input_idx)));
cnode_inputs.push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
} else {
*from_other_graph = true;
// the input node is a cnode from other graph
auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
if (parameter_from_cnode == nullptr) {
parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx)));
}
cnode_inputs.push_back(parameter_from_cnode);
(*other_graph_cnode)[anf] = parameter_from_cnode;
}
......@@ -878,7 +901,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor);
auto input_node = input_nodes[i];
if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
if (ms_context->execution_mode() == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
......
......@@ -79,6 +79,42 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, c
return output;
}
namespace {
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *inputs_ptr,
AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
MS_EXCEPTION_IF_NULL(fg);
MS_EXCEPTION_IF_NULL(inputs_ptr);
MS_EXCEPTION_IF_NULL(eqv_ptr);
MS_EXCEPTION_IF_NULL(node);
auto &inputs = *inputs_ptr;
auto &eqv = *eqv_ptr;
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
eqv[node] = node;
} else if (eqv.find(node) == eqv.end()) {
bool ignore_make_tuple = false;
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
ignore_make_tuple = true;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &node_inputs = cnode->inputs();
for (size_t i = 1; i < node_inputs.size(); ++i) {
if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) {
ignore_make_tuple = false;
break;
}
}
}
if (!ignore_make_tuple) {
inputs.push_back(node);
}
eqv[node] = fg->add_parameter();
eqv[node]->set_abstract(node->abstract());
eqv[node]->set_kernel_info(node->kernel_info_ptr());
}
return eqv[node];
}
} // namespace
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGraph(const AnfNodePtrList &lst) {
auto fg = std::make_shared<FuncGraph>();
AnfNodePtrList inputs;
......@@ -86,17 +122,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
if (lst.empty()) {
MS_LOG(EXCEPTION) << "Input anf node list is empty";
}
auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr {
if (a->isa<ValueNode>() && !IsValueNode<FuncGraph>(a)) {
eqv[a] = a;
} else if (eqv.find(a) == eqv.end()) {
inputs.push_back(a);
eqv[a] = fg->add_parameter();
eqv[a]->set_abstract(a->abstract());
eqv[a]->set_kernel_info(a->kernel_info_ptr());
}
return eqv[a];
};
// Merge CNodes into a AnfGraph that represents a linear instruction segment
for (auto n : lst) {
if (!n->isa<CNode>()) {
......@@ -122,11 +147,12 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {
args.emplace_back(NewValueNode(MakeValue(i)));
} else {
args.emplace_back(ref(inps[i]));
args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv));
}
}
} else {
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
}
eqv[n] = fg->NewCNode(args);
eqv[n]->set_abstract(n->abstract());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册