未验证 提交 1e3f01ed 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2st]Fix error when set buffer in forward (#38540)

* fix error when set buffer in forward

* add unittest

* refine class name

* refine not framework.in_dygraph_mode() in if

* fix UT error

* add comment

* refine code

* remove useless import
上级 719f7419
...@@ -1094,7 +1094,7 @@ class Layer(object): ...@@ -1094,7 +1094,7 @@ class Layer(object):
if '_parameters' in self.__dict__: if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters'] _parameters = self.__dict__['_parameters']
if name in self._parameters: if name in self._parameters:
if in_declarative_mode() and not framework.in_dygraph_mode(): if in_declarative_mode():
return _convert_into_variable(self._parameters[name]) return _convert_into_variable(self._parameters[name])
return self._parameters[name] return self._parameters[name]
if '_sub_layers' in self.__dict__: if '_sub_layers' in self.__dict__:
...@@ -1104,7 +1104,7 @@ class Layer(object): ...@@ -1104,7 +1104,7 @@ class Layer(object):
if '_buffers' in self.__dict__: if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers'] _buffers = self.__dict__['_buffers']
if name in _buffers: if name in _buffers:
if in_declarative_mode() and not framework.in_dygraph_mode(): if in_declarative_mode():
return _convert_into_variable(_buffers[name]) return _convert_into_variable(_buffers[name])
return _buffers[name] return _buffers[name]
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
...@@ -1176,11 +1176,16 @@ class Layer(object): ...@@ -1176,11 +1176,16 @@ class Layer(object):
# but should all non-Variable _buffers[name] be re-assign? We # but should all non-Variable _buffers[name] be re-assign? We
# should consider it in the future. I current wrote this as # should consider it in the future. I current wrote this as
# conservative code. # conservative code.
if _buffers[name] is None or type(_buffers[ if in_declarative_mode() and _buffers[name] is None:
name]) == core.VarBase: raise RuntimeError(
'In Dy2stat, self.{0} is a buffer and self.{0} is '
'not allowed to be set to Variable when self.{0} is None.'.
format(name))
elif _buffers[name] is None or type(
getattr(self, name)) == core.VarBase:
_buffers[name] = assign(value) _buffers[name] = assign(value)
else: else:
assign(value, _buffers[name]) assign(value, getattr(self, name))
elif value is not None: elif value is not None:
raise TypeError( raise TypeError(
"assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'" "assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'"
......
...@@ -205,6 +205,7 @@ class NetWithControlFlowIf(fluid.dygraph.Layer): ...@@ -205,6 +205,7 @@ class NetWithControlFlowIf(fluid.dygraph.Layer):
self.alpha = 10. self.alpha = 10.
self.constant_vars = {} self.constant_vars = {}
@paddle.jit.to_static
def forward(self, input): def forward(self, input):
hidden_dim = input.shape[-1] hidden_dim = input.shape[-1]
if hidden_dim != self.hidden_dim: if hidden_dim != self.hidden_dim:
......
...@@ -408,5 +408,45 @@ class TestCallNonForwardFunc(unittest.TestCase): ...@@ -408,5 +408,45 @@ class TestCallNonForwardFunc(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
class SetBuffersNet1(paddle.nn.Layer):
def __init__(self):
super(SetBuffersNet1, self).__init__()
self.a = paddle.to_tensor([1])
@paddle.jit.to_static
def forward(self):
self.a = self.a + 1
return self.a
class SetBuffersNet2(paddle.nn.Layer):
def __init__(self):
super(SetBuffersNet2, self).__init__()
self.b = paddle.to_tensor([2])
@paddle.jit.to_static
def forward(self):
self.b = None
self.b = paddle.to_tensor([3])
return self.b
class TestSetBuffers(unittest.TestCase):
def test_set_buffers1(self):
paddle.disable_static()
net = SetBuffersNet1()
out = net()
self.assertEqual(out.numpy().tolist(), [2])
paddle.jit.save(net, './SetBuffersNet1')
paddle.enable_static()
def test_set_buffers2(self):
paddle.disable_static()
net = SetBuffersNet2()
with self.assertRaises(RuntimeError):
out = net()
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -410,14 +410,17 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1): ...@@ -410,14 +410,17 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
self.dyfunc = dyfunc_ifelse_ret_int4 self.dyfunc = dyfunc_ifelse_ret_int4
def test_ast_to_func(self): def test_ast_to_func(self):
ProgramTranslator().enable(True)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
ProgramTranslator().enable(True)
static_func = paddle.jit.to_static(self.dyfunc) static_func = paddle.jit.to_static(self.dyfunc)
out = static_func(self.x) out = static_func(self.x)
# Why need set `_in_declarative_mode_` here?
def __del__(self): # In Dy2St we use `with _switch_declarative_mode_guard_()` to indicate
# that the code block is under @to_static, but in this UT
# an exception is thrown during Dy2St, making the `_in_declarative_mode_`
# a wrong value. So We need set `_in_declarative_mode_` to False manually.
paddle.fluid.dygraph.base._in_declarative_mode_ = False
ProgramTranslator().enable(False) ProgramTranslator().enable(False)
super(TestDy2StIfElseRetInt4, self).__del__()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册