From 2b9d1922da4972db946a5cea290dfe9a12cdd1e4 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Mon, 18 Oct 2021 14:34:15 +0800 Subject: [PATCH] [Cherry-pick][Dy2stat]fix no_grad context error in train mode when using save/load (#36434) (#36463) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复使用jit.save/load接口加载模型后,在train模式和no_grad上下文中,显存会一直增长的问题 --- python/paddle/fluid/dygraph/io.py | 8 ++++++++ .../fluid/tests/unittests/test_io_save_load.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index 2318a08462..75a27f2569 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -844,6 +844,8 @@ def _run_dygraph(instance, input, program_holder): continue persistable_var._set_grad_type(grad_var.type()) + drop_scope_if_no_grad(instance, tmp_scope_vec) + # 3. prepare output, keep same form with inputs outs = output_vars if len(output_vars) == 1: @@ -851,6 +853,12 @@ def _run_dygraph(instance, input, program_holder): return outs +def drop_scope_if_no_grad(instance, scope_vec): + tracer = framework._dygraph_tracer() + if (not instance._is_test) and (not tracer._has_grad): + scope_vec.value().get_scope().drop_kids() + + def _run_static_graph(input, program_holder, trace_program): main_program = framework.default_main_program() param_var_names = _get_persistable_var_names(trace_program) diff --git a/python/paddle/fluid/tests/unittests/test_io_save_load.py b/python/paddle/fluid/tests/unittests/test_io_save_load.py index c532c1bdba..89ca28510b 100644 --- a/python/paddle/fluid/tests/unittests/test_io_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_io_save_load.py @@ -15,6 +15,7 @@ from __future__ import print_function import unittest +import paddle import paddle.fluid as fluid from paddle.fluid import core @@ -69,5 +70,22 @@ class TestSaveInferenceModelAPIError(unittest.TestCase): main_program=main_prog) +class TestWhenTrainWithNoGrad(unittest.TestCase): + def test_when_train_with_no_grad(self): + paddle.disable_static() + net = paddle.nn.Linear(1024, 1) + net = paddle.jit.to_static(net) + x = paddle.rand([1024], 'float32') + net(x) + save_path = './train_with_no_grad' + paddle.jit.save(net, save_path) + net = paddle.jit.load(save_path) + net.train() + + with paddle.no_grad(): + x = paddle.rand([1024], 'float32') + net(x) + + if __name__ == '__main__': unittest.main() -- GitLab