未验证 提交 0a7c7c1c 编写于 作者: L Leo Chen 提交者: GitHub

use has_grad instead of train_mode (#29309) (#29346)

* use has_grad instead of train_mode

* add vlog for debug

* fix ut

* fix ut
上级 f616daaa
...@@ -35,6 +35,8 @@ class VariableWrapper { ...@@ -35,6 +35,8 @@ class VariableWrapper {
explicit VariableWrapper(const std::string& name) : name_(name) {} explicit VariableWrapper(const std::string& name) : name_(name) {}
~VariableWrapper() { VLOG(10) << "Destruct VariableWrapper: " << Name(); }
const framework::Variable& Var() const { return var_; } const framework::Variable& Var() const { return var_; }
framework::Variable* MutableVar() { return &var_; } framework::Variable* MutableVar() { return &var_; }
......
...@@ -326,13 +326,13 @@ class no_grad_: ...@@ -326,13 +326,13 @@ class no_grad_:
def __enter__(self): def __enter__(self):
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
if tracer: if tracer:
self.orig = tracer._train_mode self.orig = tracer._has_grad
tracer._train_mode = False tracer._has_grad = False
def __exit__(self, *args): def __exit__(self, *args):
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
if tracer: if tracer:
tracer._train_mode = self.orig tracer._has_grad = self.orig
@signature_safe_contextmanager @signature_safe_contextmanager
......
...@@ -288,13 +288,13 @@ class TestImperative(unittest.TestCase): ...@@ -288,13 +288,13 @@ class TestImperative(unittest.TestCase):
self.assertTrue(l1.weight.stop_gradient is False) self.assertTrue(l1.weight.stop_gradient is False)
tmp = l1.weight * 2 tmp = l1.weight * 2
print(tmp) print(tmp)
self.assertFalse(tmp.stop_gradient) self.assertTrue(tmp.stop_gradient)
x = fluid.dygraph.to_variable(data) x = fluid.dygraph.to_variable(data)
y = l0(x) + tmp y = l0(x) + tmp
o = l1(y) o = l1(y)
o.backward() o.backward()
self.assertTrue(tmp._grad_ivar() is not None) self.assertTrue(tmp._grad_ivar() is None)
self.assertTrue(l0.weight._grad_ivar() is not None) self.assertTrue(l0.weight._grad_ivar() is not None)
def test_sum_op(self): def test_sum_op(self):
......
...@@ -79,7 +79,8 @@ class TestTracerMode2(TestTracerMode): ...@@ -79,7 +79,8 @@ class TestTracerMode2(TestTracerMode):
class TestNoGradClass(unittest.TestCase): class TestNoGradClass(unittest.TestCase):
@paddle.no_grad() @paddle.no_grad()
def no_grad_func(self, a): def no_grad_func(self, a):
self.assertEqual(self.tracer._train_mode, False) self.assertEqual(self.tracer._train_mode, True)
self.assertEqual(self.tracer._has_grad, False)
return a return a
def test_main(self): def test_main(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册