diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc index 5a6b20d78381028fe22873b0bc20b902f3243b7e..46cc91443bbb16bbad711f12c9931aa70beb6649 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -20,6 +20,9 @@ namespace mindspore { namespace opt { namespace irpass { AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { + return nullptr; + } PatternNode x, y, z, xs; PConstant one_(node, false, 1); PConstant one_scalar_(node, false, 1, true); @@ -68,6 +71,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr } AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { + return nullptr; + } PatternNode x, y; PConstant zero_(node, false, 0); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 7328b3b78e7e70b5861b805896b9206181c8aaa6..0d55b70e5f43dad6c0b8a8a15b69ed8d8aa4fab6 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1223,6 +1223,7 @@ void PynativeExecutor::Clear(const std::string &flag) { } MS_LOG(DEBUG) << "Clear"; + grad_flag_ = false; top_g_ = nullptr; df_builder_ = nullptr; curr_g_ = nullptr; diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 44ef9c3701c48058d9689a6f036427bcae76ee00..93375d15dd80320910d24f2341a6aeb6b4ed8676 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -84,16 +84,16 @@ class Cell: self._backward_hook = None self.enable_hook = False self._bprop_debug = False - self._is_run = False + self._already_run = False self.cell_type = None @property - def is_run(self): - return self._is_run + def already_run(self): + return self._already_run - @is_run.setter - def is_run(self, value): - self._is_run = value + @already_run.setter + def already_run(self, value): + self._already_run = value @property def create_time(self): @@ -260,7 +260,7 @@ class Cell: _pynative_exec.end_graph(self, output, *inputs) for i, cell in enumerate(self.cells()): cell.set_grad(orign_grad[i]) - self._is_run = True + self._already_run = True return output def __setattr__(self, name, value): diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 0f28d9572fd2bdf0d28a50a62f1ce7a53983c57f..766bedc5d04f673492f9a151c8d5402e2ea7f8ae 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -129,14 +129,14 @@ class GradOperation(GradOperation_): output = fn(*args) _pynative_exec.end_graph(fn, output, *args) else: - if fn.is_run and not fn.requires_grad: + if fn.already_run and not fn.requires_grad: raise ValueError("obj must set_grad.") - if not fn.is_run: + if not fn.already_run: self.need_forward = True - print("already has forward run before grad by user") if self.need_forward: fn.set_grad() fn(*args) + fn.already_run = False def __call__(self, fn, weights=None): grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 48865a10d24a44419443116b29d4b643812a4b3b..ae7212fe41d710865d382070bf700c16c70951e0 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -40,6 +40,9 @@ class TestOptLib : public UT::Common { void SetUp() { UT::InitPythonPath(); parse::data_converter::ClearObjectCache(); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->set_execution_mode(kGraphMode); } FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) { equiv_node.clear(); diff --git a/tests/ut/python/pynative_mode/test_hook.py b/tests/ut/python/pynative_mode/test_hook.py index f34a81ab5c09c8c223d3861805a7e7e01d7bb06e..8c5fc9d42e9a58db7c20ec8d907decb4dc840aa4 100644 --- a/tests/ut/python/pynative_mode/test_hook.py +++ b/tests/ut/python/pynative_mode/test_hook.py @@ -152,7 +152,7 @@ def test_hook(): assert cell_hook_done assert var_hook_done assert cell_bprop_done - print(loss_output.asnumpy().shape) + print(loss_output.asnumpy()) bprop_debug = False