未验证 提交 3e897489 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2stat]fix no_grad context error in dy2stat (#35725)

* fix no_grad context error in dy2stat

* remove useless comments

* fix error by drop_kids in python

* add test and fix review
上级 b666fd3c
......@@ -1241,6 +1241,17 @@ All parameter, weight, gradient are variables in Paddle.
return self.GetMutable<framework::ReaderHolder>();
},
py::return_value_policy::reference)
.def("get_scope",
[](Variable &self) -> Scope * {
auto scope_vec =
self.GetMutable<std::vector<framework::Scope *>>();
PADDLE_ENFORCE_GT(
scope_vec->size(), 0,
platform::errors::InvalidArgument(
"The size of scope_vec should be greater than 0"));
return scope_vec->front();
},
py::return_value_policy::reference)
.def("set_scope", [](Variable &self, Scope &scope) {
auto scope_vec = self.GetMutable<std::vector<framework::Scope *>>();
scope_vec->emplace_back(&scope);
......
......@@ -290,10 +290,15 @@ class PartialProgramLayer:
self._valid_vars(self._params),
self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
*attrs)
self.drop_scope_if_no_grad()
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
def drop_scope_if_no_grad(self):
tracer = framework._dygraph_tracer()
if self.training and not tracer._has_grad:
self._tmp_scope_vec.value().get_scope().drop_kids()
@property
def program(self):
if self.training:
......
......@@ -152,6 +152,21 @@ class TestWithTrainAndEval(unittest.TestCase):
partial_layer._train_program)
class TestWithNoGrad(unittest.TestCase):
def test_with_no_grad(self):
with fluid.dygraph.guard():
linear_net = Linear()
x_data = np.random.random((5, 10)).astype('float32')
x = fluid.dygraph.to_variable(x_data)
with paddle.no_grad():
linear_net.train()
linear_net(x)
_, partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual(partial_layer.program,
partial_layer._train_program)
class GPT2LMHeadModel(fluid.dygraph.Layer):
def __init__(self):
super(GPT2LMHeadModel, self).__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册