Commit 0a24f91d authored by mindspore-ci-bot, committed by Gitee

!4103 Reuse parameter by ref info

Merge pull request !4103 from chenfei_mindspore/reuse-parameter-by-ref-info
@@ -523,38 +523,39 @@ std::vector<Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &no
 TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
   MS_EXCEPTION_IF_NULL(node);
-  TypePtr type_ptr = node->Type();
-  MS_EXCEPTION_IF_NULL(type_ptr);
-  if (type_ptr->isa<TensorType>() && output_idx == 0) {
-    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
-    MS_EXCEPTION_IF_NULL(tensor_ptr);
-    TypePtr elem = tensor_ptr->element();
-    MS_EXCEPTION_IF_NULL(elem);
-    return elem->type_id();
-  } else if (type_ptr->isa<Tuple>()) {
-    auto tuple_ptr = type_ptr->cast<TuplePtr>();
-    MS_EXCEPTION_IF_NULL(tuple_ptr);
-    if (output_idx >= tuple_ptr->size()) {
-      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
-    }
-    auto tuple_i = (*tuple_ptr)[output_idx];
-    MS_EXCEPTION_IF_NULL(tuple_i);
-    if (tuple_i->isa<TensorType>()) {
-      auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
-      MS_EXCEPTION_IF_NULL(tensor_ptr);
-      TypePtr elem = tensor_ptr->element();
-      MS_EXCEPTION_IF_NULL(elem);
-      return elem->type_id();
-    } else if (tuple_i->isa<Number>()) {
-      return tuple_i->type_id();
-    } else {
-      MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
-      return tuple_i->type_id();
-    }
-  } else if (type_ptr->isa<Number>()) {
-    return type_ptr->type_id();
-  }
-  return type_ptr->type_id();
+  auto get_single_type = [](const TypePtr &type_ptr) -> TypeId {
+    MS_EXCEPTION_IF_NULL(type_ptr);
+    if (type_ptr->isa<TensorType>()) {
+      auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
+      MS_EXCEPTION_IF_NULL(tensor_ptr);
+      TypePtr elem = tensor_ptr->element();
+      MS_EXCEPTION_IF_NULL(elem);
+      return elem->type_id();
+    }
+    if (type_ptr->isa<Number>()) {
+      return type_ptr->type_id();
+    }
+    return type_ptr->type_id();
+  };
+  auto get_tuple_type = [get_single_type](const TypePtr &type_ptr, size_t output_idx) -> TypeId {
+    MS_EXCEPTION_IF_NULL(type_ptr);
+    if (!type_ptr->isa<Tuple>()) {
+      return get_single_type(type_ptr);
+    }
+    auto tuple_ptr = type_ptr->cast<TuplePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_ptr);
+    if (output_idx >= tuple_ptr->size()) {
+      MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
+    }
+    return get_single_type((*tuple_ptr)[output_idx]);
+  };
+  TypePtr type_ptr = node->Type();
+  if (type_ptr->isa<RefType>()) {
+    auto ref_type_ptr = type_ptr->cast<RefTypePtr>();
+    MS_EXCEPTION_IF_NULL(ref_type_ptr);
+    return get_tuple_type(ref_type_ptr->subtype(), output_idx);
+  }
+  return get_tuple_type(type_ptr, output_idx);
 }
 TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
...
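
After the change, every dtype query follows the same path: an outer RefType (the type carried by parameters that are updated by reference) is unwrapped via subtype(), get_tuple_type then resolves output_idx when the type is a Tuple, and get_single_type finally reads the element type of a TensorType or falls back to the type's own id. The standalone sketch below uses simplified, hypothetical stand-in types rather than the real MindSpore classes; it only illustrates that same unwrap-then-index-then-read order.

#include <cstdio>
#include <memory>
#include <stdexcept>
#include <vector>

// Simplified stand-ins for the type hierarchy; these are NOT MindSpore's classes.
struct Type {
  virtual ~Type() = default;
  virtual int type_id() const = 0;
};
struct NumberType : Type {
  int type_id() const override { return 1; }
};
struct TensorType : Type {
  explicit TensorType(std::shared_ptr<Type> e) : elem(std::move(e)) {}
  int type_id() const override { return 2; }
  std::shared_ptr<Type> elem;  // element dtype of the tensor
};
struct TupleType : Type {
  int type_id() const override { return 3; }
  std::vector<std::shared_ptr<Type>> items;  // one entry per output
};
struct RefType : Type {  // wrapper for values written by reference (e.g. Parameters)
  explicit RefType(std::shared_ptr<Type> s) : sub(std::move(s)) {}
  int type_id() const override { return 4; }
  std::shared_ptr<Type> sub;
};

// Same dispatch order as the refactored function: unwrap Ref, index Tuple, read element.
int GetOutputInferDataType(std::shared_ptr<Type> t, size_t output_idx) {
  if (auto ref = std::dynamic_pointer_cast<RefType>(t)) t = ref->sub;  // step 1: look through Ref
  if (auto tup = std::dynamic_pointer_cast<TupleType>(t)) {            // step 2: pick the output slot
    if (output_idx >= tup->items.size()) throw std::out_of_range("output_idx");
    t = tup->items[output_idx];
  }
  if (auto tensor = std::dynamic_pointer_cast<TensorType>(t)) {        // step 3: tensor -> element dtype
    return tensor->elem->type_id();
  }
  return t->type_id();                                                 // Number or anything else
}

int main() {
  auto f32_tensor = std::make_shared<TensorType>(std::make_shared<NumberType>());
  auto ref_param = std::make_shared<RefType>(f32_tensor);
  std::printf("%d\n", GetOutputInferDataType(ref_param, 0));  // prints 1, the element's type id
  return 0;
}

Before this PR, a RefType fell through to the final return and reported its own type id; the new RefType branch makes a parameter reused by ref info report the dtype of the value it wraps.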
@@ -414,12 +414,6 @@ void AscendControlParser::ChildGraphDataAssign(
                   << node->DebugString(5) << " gives " << args.size();
   }
   for (size_t i = 0; i < args.size(); ++i) {
-    if (args[i]->isa<Parameter>() && memo->find(child_graph) == memo->end()) {
-      MS_LOG(INFO) << args[i]->DebugString() << " to " << params[i]->DebugString()
-                   << " should be reused, continue.";
-      link_list->emplace_back(args[i], params[i]);
-      continue;
-    }
     InsertMultipleAssignToGraph(kg, node, NOT_NULL(args[i]), NOT_NULL(params[i]));
   }
 }
...
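
The second hunk removes the early-continue branch in AscendControlParser::ChildGraphDataAssign that skipped Parameter arguments on the first visit of a child graph and recorded the (arg, param) pair in link_list for later reuse; with the branch gone, an assign is inserted for every pair. A minimal sketch of the restored control flow, with hypothetical stand-in types and a stubbed assign helper instead of the real InsertMultipleAssignToGraph, assuming the surrounding checks are unchanged:

#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for an ANF node; only the fields this sketch needs.
struct Node {
  std::string name;
  bool is_parameter = false;  // retained only to mirror the check the PR removed
};

// Stub for InsertMultipleAssignToGraph: emit one assign per (arg, param) pair.
void InsertAssign(const Node &from, const Node &to) {
  std::printf("assign %s -> %s\n", from.name.c_str(), to.name.c_str());
}

// After the change the loop is unconditional: no Parameter-reuse fast path.
void ChildGraphDataAssign(const std::vector<Node> &args, const std::vector<Node> &params) {
  for (size_t i = 0; i < args.size(); ++i) {
    InsertAssign(args[i], params[i]);  // previously skipped when args[i] was a Parameter
  }
}

int main() {
  ChildGraphDataAssign({{"weight", true}, {"x"}}, {{"p0"}, {"p1"}});
  return 0;
}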