提交 329ddbeb 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!429 Fix bug of parameter nums don't match with input args when set child graph input

Merge pull request !429 from chenfei/expand-tuple-output-of-node-when-set-child-graph-input
......@@ -92,6 +92,51 @@ GraphId GetDistinctionLabel(const KernelGraphPtr &graph) {
// else use first node of execution order as label
return AnfAlgo::GetStreamDistinctionLabel(graph->execution_order()[0].get());
}
std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> graph_inputs = graph->inputs();
auto valid_inputs = graph->ValidInputs();
size_t real_args_size = 0;
std::vector<BaseRef> real_args = {};
for (size_t i = 0; i < args.size(); i++) {
if (utils::isa<AnfNodePtr>(args[i])) {
auto tmp_args = AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem});
for (auto &real_arg : tmp_args) {
auto anf_node = utils::cast<AnfNodePtr>(real_arg);
MS_EXCEPTION_IF_NULL(anf_node);
auto abstract = anf_node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
// create multiple parameters if is a tuple output real kernel
if (abstract->isa<abstract::AbstractTuple>() &&
!AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
real_args_size += tuple_abstract->size();
continue;
}
real_args_size += 1;
real_args.push_back(real_arg);
}
} else {
real_args_size += 1;
real_args.push_back(args[i]);
}
}
if (graph_inputs.size() != valid_inputs.size()) {
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
}
if (real_args_size != graph_inputs.size()) {
for (size_t j = 0; j < valid_inputs.size(); j++) {
if (valid_inputs[j]) {
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
}
}
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
<< " not equal";
}
return real_args;
}
} // namespace
GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
......@@ -763,38 +808,26 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
UpdateGraphOrder(g);
std::vector<AnfNodePtr> graph_inputs = to_graph->inputs();
auto valid_inputs = to_graph->ValidInputs();
size_t real_args_size = 0;
for (size_t i = 0; i < args.size(); i++) {
real_args_size += AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}).size();
}
if (real_args_size != graph_inputs.size()) {
for (size_t j = 0; j < valid_inputs.size(); j++) {
if (valid_inputs[j]) {
MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString();
}
}
MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size()
<< " not equal";
}
auto real_args = GetRealArgs(to_graph, args);
size_t input_index = 0;
if (graph_inputs.size() != valid_inputs.size()) {
MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size()
<< ", valid_inputs.size(): " << valid_inputs.size() << " not equal";
}
for (size_t i = 0; i < args.size(); i++) {
for (size_t i = 0; i < real_args.size(); i++) {
if (input_index >= graph_inputs.size()) {
MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size();
}
if (utils::isa<AnfNodePtr>(args[i])) {
if (utils::isa<AnfNodePtr>(real_args[i])) {
// arg is a anf node
for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem})) {
if (!valid_inputs[input_index]) {
MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString();
continue;
}
auto real_arg = utils::cast<AnfNodePtr>(real_args[i]);
auto real_arg_output_num = AnfAlgo::GetOutputTensorNum(real_arg);
if (!AnfAlgo::CheckPrimitiveType(real_arg, prim::kPrimTupleGetItem) && real_arg_output_num > 1) {
input_index += real_arg_output_num;
continue;
}
if (valid_inputs[input_index]) {
SetChildGraphParameter(real_arg, graph_inputs[input_index]);
input_index++;
} else {
MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString();
}
input_index++;
} else if (utils::isa<ValuePtr>(args[i])) {
auto value = utils::cast<ValuePtr>(args[i]);
MS_EXCEPTION_IF_NULL(value);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册