提交 b346f0b3 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5026 Fix the problem of resource clear

Merge pull request !5026 from Simson/enhancement-API
......@@ -130,6 +130,13 @@ static std::string GetId(const py::object &obj) {
}
return prefix + key;
}
if (py::isinstance<mindspore::Type>(to_process)) {
auto type_ptr = py::cast<mindspore::TypePtr>(to_process);
return prefix + type_ptr->ToString();
}
if (py::isinstance<py::str>(to_process)) {
return prefix + std::string(py::str(to_process));
}
if (py::isinstance<py::int_>(to_process)) {
return prefix + std::string(py::str(to_process));
}
......@@ -1258,17 +1265,24 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
pipeline::ReclaimOptimizer();
}
template <typename T>
void MapClear(T map, const std::string &flag) {
for (auto it = map.begin(); it != map.end();) {
if (it->first.find(flag) != std::string::npos) {
it->second = nullptr;
it = map.erase(it);
} else {
it++;
}
}
}
void PynativeExecutor::Clear(const std::string &flag) {
if (!flag.empty()) {
MS_LOG(DEBUG) << "Clear res";
auto key_value = std::find_if(graph_map_.begin(), graph_map_.end(),
[&flag](const auto &item) { return item.first.find(flag) != std::string::npos; });
if (key_value != graph_map_.end()) {
std::string key = key_value->first;
(void)graph_map_.erase(key);
(void)cell_graph_map_.erase(key);
(void)cell_resource_map_.erase(key);
}
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(graph_map_, flag);
MapClear<std::unordered_map<std::string, FuncGraphPtr>>(cell_graph_map_, flag);
MapClear<std::unordered_map<std::string, ResourcePtr>>(cell_resource_map_, flag);
Clean();
// Maybe exit in the pynative runing op, so need reset pynative flag.
auto ms_context = MsContext::GetInstance();
......@@ -1286,7 +1300,6 @@ void PynativeExecutor::Clear(const std::string &flag) {
curr_g_ = nullptr;
graph_info_map_.clear();
op_id_map_.clear();
// node_abs_map_.clear();
std::stack<FuncGraphPtr>().swap(graph_p_);
ConfigManager::GetInstance().ResetIterNum();
}
......@@ -1300,7 +1313,18 @@ void PynativeExecutor::Clean() {
pipeline::ReclaimOptimizer();
}
template <typename T>
void MapErase(T map) {
for (auto it = map.begin(); it != map.end();) {
it = map.erase(it++);
}
}
void PynativeExecutor::ClearRes() {
MapErase<std::unordered_map<std::string, FuncGraphPtr>>(graph_map_);
MapErase<std::unordered_map<std::string, FuncGraphPtr>>(cell_graph_map_);
MapErase<std::unordered_map<std::string, ResourcePtr>>(cell_resource_map_);
MapErase<std::unordered_map<std::string, abstract::AbstractBasePtr>>(node_abs_map_);
Clean();
resource_.reset();
}
......
......@@ -102,13 +102,17 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
py::object grad_dtype = grads[i].attr("dtype");
py::tuple arg_shape = py_args[i].attr("shape");
py::object arg_dtype = py_args[i].attr("dtype");
if (!grad_shape.equal(arg_shape) || !grad_dtype.is(arg_dtype)) {
MS_EXCEPTION(ValueError) << "For user define net bprop, the gradient of the " << i
<< "th arg should have the same shape and dtype as the " << i << "th arg, but the "
<< i << "th arg shape: " << py::cast<py::str>(arg_shape)
<< " and dtype: " << py::cast<py::str>(arg_dtype)
<< ", the gradient shape: " << py::cast<py::str>(grad_shape)
<< " and dtype: " << py::cast<py::str>(grad_dtype) << ".";
if (!grad_shape.equal(arg_shape)) {
MS_EXCEPTION(ValueError) << "When user defines the net bprop, the gradient of the " << i
<< "th arg should have the same shape as the " << i << "th arg, but the " << i
<< "th arg shape is: " << py::cast<py::str>(arg_shape)
<< ", the gradient shape is: " << py::cast<py::str>(grad_shape) << ".";
}
if (!grad_dtype.is(arg_dtype)) {
MS_EXCEPTION(TypeError) << "When user defines the net bprop, the gradient of the " << i
<< "th arg should have the same dtype as the " << i << "th arg, but the " << i
<< "th arg dtype is: " << py::cast<py::str>(arg_dtype)
<< ", the gradient dtype is: " << py::cast<py::str>(grad_dtype) << ".";
}
}
}
......
......@@ -227,6 +227,7 @@ def dtype_to_pytype(type_):
return {
bool_: bool,
int_: int,
int8: int,
int16: int,
int32: int,
......@@ -235,6 +236,7 @@ def dtype_to_pytype(type_):
uint16: int,
uint32: int,
uint64: int,
float_: float,
float16: float,
float32: float,
float64: float,
......
......@@ -116,7 +116,6 @@ def test_user_define_bprop_check_shape():
grad_net = GradNet(net)
with pytest.raises(ValueError) as ex:
ret = grad_net(x, sens)
assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
def test_user_define_bprop_check_dtype():
......@@ -145,9 +144,8 @@ def test_user_define_bprop_check_dtype():
context.set_context(mode=context.PYNATIVE_MODE, check_bprop=True)
net = Net()
grad_net = GradNet(net)
with pytest.raises(ValueError) as ex:
with pytest.raises(TypeError) as ex:
ret = grad_net(x, sens)
assert "the gradient of the 0th arg should have the same shape and dtype as the 0th arg" in str(ex.value)
def test_user_define_bprop_check_parameter():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册