提交 d05c22a1 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5741 fix op id issue in pynative mode

Merge pull request !5741 from wangqiuliang/fix-net-id-issue
......@@ -302,6 +302,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
auto inst = pynative::PynativeExecutor::GetInstance();
inst->SaveOpForwardValue(input_value.second, input_value.first);
auto input_value_node = NewValueNode(input_value.first);
input_value_node->set_has_new_value(true);
manager->Replace(paras[i], input_value_node);
}
}
......
......@@ -674,6 +674,9 @@ ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) {
MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second;
return iter->second;
}
if (!first_grad_step_) {
++op_id_map_[id];
}
return nullptr;
}
......@@ -1021,7 +1024,10 @@ void ClearPyNativeSession() { session = nullptr; }
PynativeExecutor::~PynativeExecutor() { ClearRes(); }
PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
PynativeExecutor::PynativeExecutor() {
grad_flag_ = false;
first_grad_step_ = false;
}
void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
auto cell_id = GetCellId(cell, args);
......@@ -1042,6 +1048,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg
cell_resource_map_[cell_id] = resource_;
df_builder_ = std::make_shared<FuncGraph>();
MS_LOG(DEBUG) << "First new graph" << top_g_.get();
first_grad_step_ = true;
top_graph_cells_.insert(cell_id);
Pushp();
} else {
Pushp();
......@@ -1223,7 +1231,9 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
resource_->manager()->AddFuncGraph(curr_g_);
// custom bprop debug
bool need_replace_param = false;
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
need_replace_param = true;
size_t par_number = py::tuple(parse::python_adapter::CallPyObjMethod(cell, "get_parameters")).size();
if (par_number > 0) {
MS_LOG(EXCEPTION) << "When user defines the net bprop, there are " << par_number
......@@ -1237,6 +1247,15 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
}
}
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
if (need_replace_param) {
auto params = newfg->parameters();
auto manager = Manage({newfg}, false);
for (size_t i = 0; i < params.size(); i++) {
ValuePtr value = PyAttrValue(args[i]);
auto v_node = NewValueNode(value);
manager->Replace(params[i], v_node);
}
}
graph_info_map_.erase(curr_g_);
if (curr_g_ != top_g_) {
Popp();
......@@ -1397,6 +1416,9 @@ void PynativeExecutor::Clear(const std::string &flag) {
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
}
ConfigManager::GetInstance().ResetIterNum();
if (top_graph_cells_.find(flag) != top_graph_cells_.end()) {
op_forward_map_.clear();
}
return;
}
......@@ -1405,6 +1427,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
top_g_ = nullptr;
df_builder_ = nullptr;
curr_g_ = nullptr;
first_grad_step_ = false;
graph_info_map_.clear();
op_id_map_.clear();
obj_to_forward_id_.clear();
......@@ -1416,7 +1439,6 @@ void PynativeExecutor::Clean() {
MS_LOG(DEBUG) << "Clean all res";
Clear();
grad_flag_ = false;
op_forward_map_.clear();
ad::CleanRes();
pipeline::ReclaimOptimizer();
}
......
......@@ -24,6 +24,7 @@
#include <unordered_map>
#include <mutex>
#include <stack>
#include <set>
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
......@@ -145,6 +146,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
static ResourcePtr resource_;
static int graph_id_;
bool grad_flag_;
bool first_grad_step_;
std::unordered_map<std::string, FuncGraphPtr> graph_map_;
std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
......@@ -158,6 +160,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
FuncGraphPtr df_builder_;
FuncGraphPtr curr_g_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
std::set<std::string> top_graph_cells_;
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册