未验证 提交 2b9d1922 编写于 作者: 0 0x45f 提交者: GitHub

[Cherry-pick][Dy2stat]fix no_grad context error in train mode when using...

[Cherry-pick][Dy2stat]fix no_grad context error in train mode when using save/load (#36434) (#36463)

修复使用jit.save/load接口加载模型后,在train模式和no_grad上下文中,显存会一直增长的问题
上级 cc449652
...@@ -844,6 +844,8 @@ def _run_dygraph(instance, input, program_holder): ...@@ -844,6 +844,8 @@ def _run_dygraph(instance, input, program_holder):
continue continue
persistable_var._set_grad_type(grad_var.type()) persistable_var._set_grad_type(grad_var.type())
drop_scope_if_no_grad(instance, tmp_scope_vec)
# 3. prepare output, keep same form with inputs # 3. prepare output, keep same form with inputs
outs = output_vars outs = output_vars
if len(output_vars) == 1: if len(output_vars) == 1:
...@@ -851,6 +853,12 @@ def _run_dygraph(instance, input, program_holder): ...@@ -851,6 +853,12 @@ def _run_dygraph(instance, input, program_holder):
return outs return outs
def drop_scope_if_no_grad(instance, scope_vec):
tracer = framework._dygraph_tracer()
if (not instance._is_test) and (not tracer._has_grad):
scope_vec.value().get_scope().drop_kids()
def _run_static_graph(input, program_holder, trace_program): def _run_static_graph(input, program_holder, trace_program):
main_program = framework.default_main_program() main_program = framework.default_main_program()
param_var_names = _get_persistable_var_names(trace_program) param_var_names = _get_persistable_var_names(trace_program)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
...@@ -69,5 +70,22 @@ class TestSaveInferenceModelAPIError(unittest.TestCase): ...@@ -69,5 +70,22 @@ class TestSaveInferenceModelAPIError(unittest.TestCase):
main_program=main_prog) main_program=main_prog)
class TestWhenTrainWithNoGrad(unittest.TestCase):
def test_when_train_with_no_grad(self):
paddle.disable_static()
net = paddle.nn.Linear(1024, 1)
net = paddle.jit.to_static(net)
x = paddle.rand([1024], 'float32')
net(x)
save_path = './train_with_no_grad'
paddle.jit.save(net, save_path)
net = paddle.jit.load(save_path)
net.train()
with paddle.no_grad():
x = paddle.rand([1024], 'float32')
net(x)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册