未验证 提交 7339a124 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2stat]Fix error when calling sublayer's non-forward func in dy2stat (#37296)

* fix error when calling sublayer's non-forward func in dy2stat

* fix circular import using an inelegant way

* deal with parameters

* remove param_guard in __call__

* remove comment

* fix error when jit.load

* rename block var

* remove wrong code

* add unit test
上级 8e6d5d2b
......@@ -1077,8 +1077,13 @@ def append_var_from_block_desc_static(block,
else:
lod_level = None
if var_desc.persistable():
current_block = block.program.global_block()
else:
current_block = block
vars_append.append(
block.create_var(
current_block.create_var(
name=var_desc.name(),
dtype=data_type,
type=var_type,
......
......@@ -31,7 +31,7 @@ from .. import unique_name
from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper
from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder
from .base import program_desc_tracing_guard, param_guard, in_declarative_mode
from .base import program_desc_tracing_guard, param_guard, in_declarative_mode, _convert_into_variable
from paddle.fluid import framework
from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
......@@ -914,16 +914,7 @@ class Layer(object):
return outputs
def __call__(self, *inputs, **kwargs):
# NOTE(Aurelius84): Why we still need param_guard here?
# In case of ControlFlow, true_fn and false_fn will contain
# parameters that may not trigger logic of `Operator` to create
# them. we add this to make sure all parameters is available.
if in_declarative_mode() and not framework.in_dygraph_mode():
with param_guard(self._parameters), param_guard(self._buffers):
return self._dygraph_call_func(*inputs, **kwargs)
else:
return self._dygraph_call_func(*inputs, **kwargs)
return self._dygraph_call_func(*inputs, **kwargs)
def forward(self, *inputs, **kwargs):
"""
......@@ -1103,6 +1094,8 @@ class Layer(object):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in self._parameters:
if in_declarative_mode() and not framework.in_dygraph_mode():
return _convert_into_variable(self._parameters[name])
return self._parameters[name]
if '_sub_layers' in self.__dict__:
_sub_layers = self.__dict__['_sub_layers']
......@@ -1111,6 +1104,8 @@ class Layer(object):
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
if in_declarative_mode() and not framework.in_dygraph_mode():
return _convert_into_variable(_buffers[name])
return _buffers[name]
return object.__getattribute__(self, name)
......
......@@ -379,5 +379,34 @@ class TestErrorWithInitFromStaticMode(unittest.TestCase):
net.forward.outputs
class CallNonForwardFuncNet(paddle.nn.Layer):
def __init__(self):
super(CallNonForwardFuncNet, self).__init__()
self.sub = CallNonForwardFuncSubNet()
@paddle.jit.to_static
def forward(self):
return self.sub.func()
class CallNonForwardFuncSubNet(paddle.nn.Layer):
def __init__(self):
super(CallNonForwardFuncSubNet, self).__init__()
self.a = paddle.to_tensor([1, 2])
def func(self):
x = self.a * 2
return x
class TestCallNonForwardFunc(unittest.TestCase):
def test_call_non_forward(self):
paddle.disable_static()
net = CallNonForwardFuncNet()
out = net()
self.assertEqual(out.numpy().tolist(), [2, 4])
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册