From 1e52cb450c647ca2c90bc64b1e9fd0e8dea27bf3 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Thu, 4 Jun 2020 15:04:12 +0800 Subject: [PATCH] reuse parameter in new ConstructKernelGraph --- mindspore/ccsrc/session/session_basic.cc | 38 +++++++++++++++++------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index d3befcefe..331f5071e 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -324,8 +324,9 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf if (python_paras_ == nullptr) { python_paras_ = std::make_shared>(); } - if (python_paras_->find(m_tensor) != python_paras_->end()) { - new_parameter = (*python_paras_)[m_tensor]; + auto iter = python_paras_->find(m_tensor); + if (iter != python_paras_->end()) { + new_parameter = iter->second; } else { TraceManager::DebugTrace(std::make_shared(anf->debug_info())); new_parameter = graph->NewParameter(anf->cast()); @@ -502,13 +503,23 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph if (!anf->isa()) { MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; } - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - TraceManager::DebugTrace(std::make_shared(anf->debug_info())); - auto new_parameter = graph->NewParameter(anf->cast()); - TraceManager::EndTrace(); - graph_inputs->push_back(new_parameter); - graph->FrontBackendlMapAdd(anf, new_parameter); + + auto m_tensor = GetParamDefaultInputTensor(anf); + ParameterPtr new_parameter = nullptr; + if (python_paras_ == nullptr) { + python_paras_ = std::make_shared>(); + } + auto iter = python_paras_->find(m_tensor); + if (iter != python_paras_->end()) { + new_parameter = iter->second; + } else { + TraceManager::DebugTrace(std::make_shared(anf->debug_info())); + new_parameter = graph->NewParameter(anf->cast()); + if (m_tensor != nullptr) { + (*python_paras_)[m_tensor] = new_parameter; + } + TraceManager::EndTrace(); + } return new_parameter; } @@ -571,7 +582,11 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); if (node->isa()) { - (void)CreateNewParameter(node, graph.get()); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + auto new_parameter = CreateNewParameter(node, graph.get()); + graph_inputs->push_back(new_parameter); + graph->FrontBackendlMapAdd(node, new_parameter); continue; } else if (node->isa()) { if (!IsValueNode(node)) { @@ -629,7 +644,8 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector ¶ auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter); if (backend_parameter == nullptr) { // for example "def f(x,y,z) {return x + y}", parameter z in unused - CreateNewParameterFromParameter(parameter, true, graph); + auto new_parameter = CreateNewParameter(parameter, graph); + graph_inputs->push_back(new_parameter); MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString(); continue; } -- GitLab