提交 28dabf03 编写于 作者: K kingfo

fix grad flag update issue in pynative

上级 57fd31b2
......@@ -20,6 +20,9 @@ namespace mindspore {
namespace opt {
namespace irpass {
AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
return nullptr;
}
PatternNode x, y, z, xs;
PConstant one_(node, false, 1);
PConstant one_scalar_(node, false, 1, true);
......@@ -68,6 +71,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
}
AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
if (MsContext::GetInstance()->execution_mode() == kPynativeMode) {
return nullptr;
}
PatternNode x, y;
PConstant zero_(node, false, 0);
......
......@@ -1223,6 +1223,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
}
MS_LOG(DEBUG) << "Clear";
grad_flag_ = false;
top_g_ = nullptr;
df_builder_ = nullptr;
curr_g_ = nullptr;
......
......@@ -84,16 +84,16 @@ class Cell:
self._backward_hook = None
self.enable_hook = False
self._bprop_debug = False
self._is_run = False
self._already_run = False
self.cell_type = None
@property
def is_run(self):
return self._is_run
def already_run(self):
return self._already_run
@is_run.setter
def is_run(self, value):
self._is_run = value
@already_run.setter
def already_run(self, value):
self._already_run = value
@property
def create_time(self):
......@@ -260,7 +260,7 @@ class Cell:
_pynative_exec.end_graph(self, output, *inputs)
for i, cell in enumerate(self.cells()):
cell.set_grad(orign_grad[i])
self._is_run = True
self._already_run = True
return output
def __setattr__(self, name, value):
......
......@@ -129,14 +129,14 @@ class GradOperation(GradOperation_):
output = fn(*args)
_pynative_exec.end_graph(fn, output, *args)
else:
if fn.is_run and not fn.requires_grad:
if fn.already_run and not fn.requires_grad:
raise ValueError("obj must set_grad.")
if not fn.is_run:
if not fn.already_run:
self.need_forward = True
print("already has forward run before grad by user")
if self.need_forward:
fn.set_grad()
fn(*args)
fn.already_run = False
def __call__(self, fn, weights=None):
grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
......
......@@ -40,6 +40,9 @@ class TestOptLib : public UT::Common {
// Per-test fixture initialization for TestOptLib; runs before every test case.
void SetUp() {
// Make the embedded Python interpreter importable for parse/convert helpers.
UT::InitPythonPath();
// Drop cached Python-object conversions so each test starts from a clean state.
parse::data_converter::ClearObjectCache();
// Force graph mode: the arithmetic-simplify passes under test return nullptr
// immediately when execution_mode() == kPynativeMode (see the irpass change in
// this same commit), so the optimizer tests would silently no-op otherwise.
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
}
FuncGraphPtr RunTransform(FuncGraphPtr gbefore, const SubstitutionList &transform) {
equiv_node.clear();
......
......@@ -152,7 +152,7 @@ def test_hook():
assert cell_hook_done
assert var_hook_done
assert cell_bprop_done
print(loss_output.asnumpy().shape)
print(loss_output.asnumpy())
bprop_debug = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册