提交 1e52cb45 编写于 作者: M Margaret_wangrui

reuse parameter in new ConstructKernelGraph

上级 18ecafcf
......@@ -324,8 +324,9 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
if (python_paras_ == nullptr) {
python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
}
if (python_paras_->find(m_tensor) != python_paras_->end()) {
new_parameter = (*python_paras_)[m_tensor];
auto iter = python_paras_->find(m_tensor);
if (iter != python_paras_->end()) {
new_parameter = iter->second;
} else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
......@@ -502,13 +503,23 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
if (!anf->isa<Parameter>()) {
MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
}
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
TraceManager::EndTrace();
graph_inputs->push_back(new_parameter);
graph->FrontBackendlMapAdd(anf, new_parameter);
auto m_tensor = GetParamDefaultInputTensor(anf);
ParameterPtr new_parameter = nullptr;
if (python_paras_ == nullptr) {
python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
}
auto iter = python_paras_->find(m_tensor);
if (iter != python_paras_->end()) {
new_parameter = iter->second;
} else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
if (m_tensor != nullptr) {
(*python_paras_)[m_tensor] = new_parameter;
}
TraceManager::EndTrace();
}
return new_parameter;
}
......@@ -571,7 +582,11 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
MS_EXCEPTION_IF_NULL(node);
MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
if (node->isa<Parameter>()) {
(void)CreateNewParameter(node, graph.get());
auto graph_inputs = graph->MutableInputs();
MS_EXCEPTION_IF_NULL(graph_inputs);
auto new_parameter = CreateNewParameter(node, graph.get());
graph_inputs->push_back(new_parameter);
graph->FrontBackendlMapAdd(node, new_parameter);
continue;
} else if (node->isa<ValueNode>()) {
if (!IsValueNode<FuncGraph>(node)) {
......@@ -629,7 +644,8 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para
auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
if (backend_parameter == nullptr) {
// for example "def f(x,y,z) {return x + y}", parameter z in unused
CreateNewParameterFromParameter(parameter, true, graph);
auto new_parameter = CreateNewParameter(parameter, graph);
graph_inputs->push_back(new_parameter);
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
continue;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册