未验证 提交 0a7c7c1c 编写于 作者: L Leo Chen 提交者: GitHub

use has_grad instead of train_mode (#29309) (#29346)

* use has_grad instead of train_mode

* add vlog for debug

* fix ut

* fix ut
上级 f616daaa
......@@ -35,6 +35,8 @@ class VariableWrapper {
explicit VariableWrapper(const std::string& name) : name_(name) {}
~VariableWrapper() { VLOG(10) << "Destruct VariableWrapper: " << Name(); }
const framework::Variable& Var() const { return var_; }
framework::Variable* MutableVar() { return &var_; }
......
......@@ -326,13 +326,13 @@ class no_grad_:
def __enter__(self):
tracer = framework._dygraph_tracer()
if tracer:
self.orig = tracer._train_mode
tracer._train_mode = False
self.orig = tracer._has_grad
tracer._has_grad = False
def __exit__(self, *args):
tracer = framework._dygraph_tracer()
if tracer:
tracer._train_mode = self.orig
tracer._has_grad = self.orig
@signature_safe_contextmanager
......
......@@ -288,13 +288,13 @@ class TestImperative(unittest.TestCase):
self.assertTrue(l1.weight.stop_gradient is False)
tmp = l1.weight * 2
print(tmp)
self.assertFalse(tmp.stop_gradient)
self.assertTrue(tmp.stop_gradient)
x = fluid.dygraph.to_variable(data)
y = l0(x) + tmp
o = l1(y)
o.backward()
self.assertTrue(tmp._grad_ivar() is not None)
self.assertTrue(tmp._grad_ivar() is None)
self.assertTrue(l0.weight._grad_ivar() is not None)
def test_sum_op(self):
......
......@@ -79,7 +79,8 @@ class TestTracerMode2(TestTracerMode):
class TestNoGradClass(unittest.TestCase):
@paddle.no_grad()
def no_grad_func(self, a):
self.assertEqual(self.tracer._train_mode, False)
self.assertEqual(self.tracer._train_mode, True)
self.assertEqual(self.tracer._has_grad, False)
return a
def test_main(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册