未验证 提交 7339a124 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2stat]Fix error when calling sublayer's non-forward func in dy2stat (#37296)

* fix error when calling sublayer's non-forward func in dy2stat

* fix circular import using an inelegant way

* deal with parameters

* remove param_guard in __call__

* remove comment

* fix error when jit.load

* rename block var

* remove wrong code

* add unit test
上级 8e6d5d2b
...@@ -1077,8 +1077,13 @@ def append_var_from_block_desc_static(block, ...@@ -1077,8 +1077,13 @@ def append_var_from_block_desc_static(block,
else: else:
lod_level = None lod_level = None
if var_desc.persistable():
current_block = block.program.global_block()
else:
current_block = block
vars_append.append( vars_append.append(
block.create_var( current_block.create_var(
name=var_desc.name(), name=var_desc.name(),
dtype=data_type, dtype=data_type,
type=var_type, type=var_type,
......
...@@ -31,7 +31,7 @@ from .. import unique_name ...@@ -31,7 +31,7 @@ from .. import unique_name
from paddle.fluid import core from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper from .layer_object_helper import LayerObjectHelper
from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder
from .base import program_desc_tracing_guard, param_guard, in_declarative_mode from .base import program_desc_tracing_guard, param_guard, in_declarative_mode, _convert_into_variable
from paddle.fluid import framework from paddle.fluid import framework
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope from paddle.fluid.executor import Executor, global_scope
...@@ -914,15 +914,6 @@ class Layer(object): ...@@ -914,15 +914,6 @@ class Layer(object):
return outputs return outputs
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
# NOTE(Aurelius84): Why we still need param_guard here?
# In case of ControlFlow, true_fn and false_fn will contain
# parameters that may not trigger logic of `Operator` to create
# them. we add this to make sure all parameters is available.
if in_declarative_mode() and not framework.in_dygraph_mode():
with param_guard(self._parameters), param_guard(self._buffers):
return self._dygraph_call_func(*inputs, **kwargs)
else:
return self._dygraph_call_func(*inputs, **kwargs) return self._dygraph_call_func(*inputs, **kwargs)
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
...@@ -1103,6 +1094,8 @@ class Layer(object): ...@@ -1103,6 +1094,8 @@ class Layer(object):
if '_parameters' in self.__dict__: if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters'] _parameters = self.__dict__['_parameters']
if name in self._parameters: if name in self._parameters:
if in_declarative_mode() and not framework.in_dygraph_mode():
return _convert_into_variable(self._parameters[name])
return self._parameters[name] return self._parameters[name]
if '_sub_layers' in self.__dict__: if '_sub_layers' in self.__dict__:
_sub_layers = self.__dict__['_sub_layers'] _sub_layers = self.__dict__['_sub_layers']
...@@ -1111,6 +1104,8 @@ class Layer(object): ...@@ -1111,6 +1104,8 @@ class Layer(object):
if '_buffers' in self.__dict__: if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers'] _buffers = self.__dict__['_buffers']
if name in _buffers: if name in _buffers:
if in_declarative_mode() and not framework.in_dygraph_mode():
return _convert_into_variable(_buffers[name])
return _buffers[name] return _buffers[name]
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
......
...@@ -379,5 +379,34 @@ class TestErrorWithInitFromStaticMode(unittest.TestCase): ...@@ -379,5 +379,34 @@ class TestErrorWithInitFromStaticMode(unittest.TestCase):
net.forward.outputs net.forward.outputs
class CallNonForwardFuncNet(paddle.nn.Layer):
def __init__(self):
super(CallNonForwardFuncNet, self).__init__()
self.sub = CallNonForwardFuncSubNet()
@paddle.jit.to_static
def forward(self):
return self.sub.func()
class CallNonForwardFuncSubNet(paddle.nn.Layer):
def __init__(self):
super(CallNonForwardFuncSubNet, self).__init__()
self.a = paddle.to_tensor([1, 2])
def func(self):
x = self.a * 2
return x
class TestCallNonForwardFunc(unittest.TestCase):
def test_call_non_forward(self):
paddle.disable_static()
net = CallNonForwardFuncNet()
out = net()
self.assertEqual(out.numpy().tolist(), [2, 4])
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册