From 28dabf0332b11be75f2c3e61275cd7efd02d9ddc Mon Sep 17 00:00:00 2001 From: kingfo Date: Wed, 5 Aug 2020 14:32:34 +0800 Subject: [PATCH] fix grad flag update issue in pynative --- .../optimizer/irpass/arithmetic_simplify.cc | 6 ++++++ .../ccsrc/pipeline/pynative/pynative_execute.cc | 1 + mindspore/nn/cell.py | 14 +++++++------- mindspore/ops/composite/base.py | 6 +++--- tests/ut/cpp/optimizer/lib_test.cc | 3 +++ tests/ut/python/pynative_mode/test_hook.py | 2 +- 6 files changed, 21 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc index 5a6b20d78..46cc91443 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -20,6 +20,9 @@ namespace mindspore { namespace opt { namespace irpass { AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { + return nullptr; + } PatternNode x, y, z, xs; PConstant one_(node, false, 1); PConstant one_scalar_(node, false, 1, true); @@ -68,6 +71,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr } AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { + return nullptr; + } PatternNode x, y; PConstant zero_(node, false, 0); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 7328b3b78..0d55b70e5 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1223,6 +1223,7 @@ void PynativeExecutor::Clear(const std::string &flag) { } MS_LOG(DEBUG) << "Clear"; + grad_flag_ = false; top_g_ = nullptr; df_builder_ = nullptr; curr_g_ = nullptr; diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 44ef9c370..93375d15d 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -84,16 +84,16 @@ class Cell: self._backward_hook = None self.enable_hook = False self._bprop_debug = False - self._is_run = False + self._already_run = False self.cell_type = None @property - def is_run(self): - return self._is_run + def already_run(self): + return self._already_run - @is_run.setter - def is_run(self, value): - self._is_run = value + @already_run.setter + def already_run(self, value): + self._already_run = value @property def create_time(self): @@ -260,7 +260,7 @@ class Cell: _pynative_exec.end_graph(self, output, *inputs) for i, cell in enumerate(self.cells()): cell.set_grad(orign_grad[i]) - self._is_run = True + self._already_run = True return output def __setattr__(self, name, value): diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 0f28d9572..766bedc5d 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -129,14 +129,14 @@ class GradOperation(GradOperation_): output = fn(*args) _pynative_exec.end_graph(fn, output, *args) else: - if fn.is_run and not fn.requires_grad: + if fn.already_run and not fn.requires_grad: raise ValueError("obj must set_grad.") - if not fn.is_run: + if not fn.already_run: self.need_forward = True - print("already has forward run before grad by user") if self.need_forward: fn.set_grad() fn(*args) + fn.already_run = False def __call__(self, fn, weights=None): grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 48865a10d..ae7212fe4 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -40,6 +40,9 @@ class TestOptLib : public UT::Common { void SetUp() { UT::InitPythonPath(); parse::data_converter::ClearObjectCache(); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->set_execution_mode(kGraphMode); } FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) { equiv_node.clear(); diff --git a/tests/ut/python/pynative_mode/test_hook.py b/tests/ut/python/pynative_mode/test_hook.py index f34a81ab5..8c5fc9d42 100644 --- a/tests/ut/python/pynative_mode/test_hook.py +++ b/tests/ut/python/pynative_mode/test_hook.py @@ -152,7 +152,7 @@ def test_hook(): assert cell_hook_done assert var_hook_done assert cell_bprop_done - print(loss_output.asnumpy().shape) + print(loss_output.asnumpy()) bprop_debug = False -- GitLab