From 37257d6a8584b437db36f20c43109b1950474ded Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Fri, 15 Oct 2021 13:51:52 +0800 Subject: [PATCH] fix no_grad context error in train mode when using save/load (#36434) * fix no_grad context error in train mode when using save/load * change net to train mode in test case --- python/paddle/fluid/dygraph/io.py | 8 ++++++++ .../fluid/tests/unittests/test_io_save_load.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index 2318a08462d..75a27f25696 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -844,6 +844,8 @@ def _run_dygraph(instance, input, program_holder): continue persistable_var._set_grad_type(grad_var.type()) + drop_scope_if_no_grad(instance, tmp_scope_vec) + # 3. prepare output, keep same form with inputs outs = output_vars if len(output_vars) == 1: @@ -851,6 +853,12 @@ def _run_dygraph(instance, input, program_holder): return outs +def drop_scope_if_no_grad(instance, scope_vec): + tracer = framework._dygraph_tracer() + if (not instance._is_test) and (not tracer._has_grad): + scope_vec.value().get_scope().drop_kids() + + def _run_static_graph(input, program_holder, trace_program): main_program = framework.default_main_program() param_var_names = _get_persistable_var_names(trace_program) diff --git a/python/paddle/fluid/tests/unittests/test_io_save_load.py b/python/paddle/fluid/tests/unittests/test_io_save_load.py index c532c1bdbaa..89ca28510b9 100644 --- a/python/paddle/fluid/tests/unittests/test_io_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_io_save_load.py @@ -15,6 +15,7 @@ from __future__ import print_function import unittest +import paddle import paddle.fluid as fluid from paddle.fluid import core @@ -69,5 +70,22 @@ class TestSaveInferenceModelAPIError(unittest.TestCase): main_program=main_prog) +class TestWhenTrainWithNoGrad(unittest.TestCase): + def test_when_train_with_no_grad(self): + paddle.disable_static() + net = paddle.nn.Linear(1024, 1) + net = paddle.jit.to_static(net) + x = paddle.rand([1024], 'float32') + net(x) + save_path = './train_with_no_grad' + paddle.jit.save(net, save_path) + net = paddle.jit.load(save_path) + net.train() + + with paddle.no_grad(): + x = paddle.rand([1024], 'float32') + net(x) + + if __name__ == '__main__': unittest.main() -- GitLab