From b58cfff89dc9bb12e47926113d197d1ab95b9776 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 3 Dec 2020 18:07:56 +0800 Subject: [PATCH] use has_grad instead of train_mode (#29309) * use has_grad instead of train_mode * add vlog for debug * fix ut * fix ut --- paddle/fluid/imperative/variable_wrapper.h | 2 ++ python/paddle/fluid/dygraph/base.py | 6 +++--- .../paddle/fluid/tests/unittests/test_imperative_basic.py | 4 ++-- .../fluid/tests/unittests/test_imperative_decorator.py | 3 ++- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index fec12f2da13..5922bfcdb9f 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -35,6 +35,8 @@ class VariableWrapper { explicit VariableWrapper(const std::string& name) : name_(name) {} + ~VariableWrapper() { VLOG(10) << "Destruct VariableWrapper: " << Name(); } + const framework::Variable& Var() const { return var_; } framework::Variable* MutableVar() { return &var_; } diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 5868c9d078c..78cc9afde07 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -326,13 +326,13 @@ class no_grad_: def __enter__(self): tracer = framework._dygraph_tracer() if tracer: - self.orig = tracer._train_mode - tracer._train_mode = False + self.orig = tracer._has_grad + tracer._has_grad = False def __exit__(self, *args): tracer = framework._dygraph_tracer() if tracer: - tracer._train_mode = self.orig + tracer._has_grad = self.orig @signature_safe_contextmanager diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index d2f143d7ad4..e33e7247d02 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -288,13 +288,13 @@ class TestImperative(unittest.TestCase): self.assertTrue(l1.weight.stop_gradient is False) tmp = l1.weight * 2 print(tmp) - self.assertFalse(tmp.stop_gradient) + self.assertTrue(tmp.stop_gradient) x = fluid.dygraph.to_variable(data) y = l0(x) + tmp o = l1(y) o.backward() - self.assertTrue(tmp._grad_ivar() is not None) + self.assertTrue(tmp._grad_ivar() is None) self.assertTrue(l0.weight._grad_ivar() is not None) def test_sum_op(self): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py index 7d20a9b952e..6f86a0c0d65 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_decorator.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_decorator.py @@ -79,7 +79,8 @@ class TestTracerMode2(TestTracerMode): class TestNoGradClass(unittest.TestCase): @paddle.no_grad() def no_grad_func(self, a): - self.assertEqual(self.tracer._train_mode, False) + self.assertEqual(self.tracer._train_mode, True) + self.assertEqual(self.tracer._has_grad, False) return a def test_main(self): -- GitLab