提交 1e52cb45 编写于 作者: M Margaret_wangrui

reuse parameter in new ConstructKernelGraph

上级 18ecafcf
...@@ -324,8 +324,9 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf ...@@ -324,8 +324,9 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
if (python_paras_ == nullptr) { if (python_paras_ == nullptr) {
python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>(); python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
} }
if (python_paras_->find(m_tensor) != python_paras_->end()) { auto iter = python_paras_->find(m_tensor);
new_parameter = (*python_paras_)[m_tensor]; if (iter != python_paras_->end()) {
new_parameter = iter->second;
} else { } else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
...@@ -502,13 +503,23 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph ...@@ -502,13 +503,23 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
if (!anf->isa<Parameter>()) { if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
} }
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs); auto m_tensor = GetParamDefaultInputTensor(anf);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info())); ParameterPtr new_parameter = nullptr;
auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); if (python_paras_ == nullptr) {
TraceManager::EndTrace(); python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
graph_inputs->push_back(new_parameter); }
graph->FrontBackendlMapAdd(anf, new_parameter); auto iter = python_paras_->find(m_tensor);
if (iter != python_paras_->end()) {
new_parameter = iter->second;
} else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
if (m_tensor != nullptr) {
(*python_paras_)[m_tensor] = new_parameter;
}
TraceManager::EndTrace();
}
return new_parameter; return new_parameter;
} }
...@@ -571,7 +582,11 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP ...@@ -571,7 +582,11 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
if (node->isa<Parameter>()) { if (node->isa<Parameter>()) {
(void)CreateNewParameter(node, graph.get()); auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto new_parameter = CreateNewParameter(node, graph.get());
graph_inputs->push_back(new_parameter);
graph->FrontBackendlMapAdd(node, new_parameter);
continue; continue;
} else if (node->isa<ValueNode>()) { } else if (node->isa<ValueNode>()) {
if (!IsValueNode<FuncGraph>(node)) { if (!IsValueNode<FuncGraph>(node)) {
...@@ -629,7 +644,8 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para ...@@ -629,7 +644,8 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para
auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter); auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
if (backend_parameter == nullptr) { if (backend_parameter == nullptr) {
// for example "def f(x,y,z) {return x + y}", parameter z in unused // for example "def f(x,y,z) {return x + y}", parameter z in unused
CreateNewParameterFromParameter(parameter, true, graph); auto new_parameter = CreateNewParameter(parameter, graph);
graph_inputs->push_back(new_parameter);
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString(); MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
continue; continue;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册