diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index 2318a08462d5d536cffffa366d6f954c1d5b0a14..75a27f256962c9b897c5b325b1e4aada90e7c13b 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -844,6 +844,8 @@ def _run_dygraph(instance, input, program_holder): continue persistable_var._set_grad_type(grad_var.type()) + drop_scope_if_no_grad(instance, tmp_scope_vec) + # 3. prepare output, keep same form with inputs outs = output_vars if len(output_vars) == 1: @@ -851,6 +853,12 @@ def _run_dygraph(instance, input, program_holder): return outs +def drop_scope_if_no_grad(instance, scope_vec): + tracer = framework._dygraph_tracer() + if (not instance._is_test) and (not tracer._has_grad): + scope_vec.value().get_scope().drop_kids() + + def _run_static_graph(input, program_holder, trace_program): main_program = framework.default_main_program() param_var_names = _get_persistable_var_names(trace_program) diff --git a/python/paddle/fluid/tests/unittests/test_io_save_load.py b/python/paddle/fluid/tests/unittests/test_io_save_load.py index c532c1bdbaa0518620eaf54c865fc1e8466317ea..89ca28510b9b929b1fe36e0c9883da020e71555c 100644 --- a/python/paddle/fluid/tests/unittests/test_io_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_io_save_load.py @@ -15,6 +15,7 @@ from __future__ import print_function import unittest +import paddle import paddle.fluid as fluid from paddle.fluid import core @@ -69,5 +70,22 @@ class TestSaveInferenceModelAPIError(unittest.TestCase): main_program=main_prog) +class TestWhenTrainWithNoGrad(unittest.TestCase): + def test_when_train_with_no_grad(self): + paddle.disable_static() + net = paddle.nn.Linear(1024, 1) + net = paddle.jit.to_static(net) + x = paddle.rand([1024], 'float32') + net(x) + save_path = './train_with_no_grad' + paddle.jit.save(net, save_path) + net = paddle.jit.load(save_path) + net.train() + + with paddle.no_grad(): + x = paddle.rand([1024], 'float32') + net(x) + + if __name__ == '__main__': unittest.main()