未验证 提交 2500dca8 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat] Fix bug in convert_call (#28368)

* Fix bug in convert_call

* refine unittest

* refine code

* refine code

* fix unittest failed

* add assert
上级 ca415414
......@@ -142,7 +142,7 @@ def convert_call(func):
# Note(Aurelius84): Because `@declarative` returns a class instance instead of
# a function. This will modify the value referring to itself in `__globals__`.
# For example:
# For example:
#
# @declarative
# def foo(x):
......@@ -150,7 +150,7 @@ def convert_call(func):
#
# `foo` will be converted into a wrapper class, suppose as `StaticFunction`.
# And `foo.__globals__['foo']` will still return this `StaticFunction` instead of
# `foo` function. So `isinstance(fn, StaticFunction)` is added here.
# `foo` function. So `isinstance(fn, StaticFunction)` is added here.
global_functions = set()
for fn in func.__globals__.values():
if inspect.isfunction(fn):
......@@ -193,8 +193,10 @@ def convert_call(func):
try:
_, forward_func = unwrap_decorators(func.forward)
forward_func = convert_to_static(forward_func)
setattr(func, 'forward', forward_func)
func_self = func
# Bound mothod will be convert into plain function after `convert_to_static`.
# So descriptor mechanism is used to bound `self` instance on function to
# keep it as bound method.
setattr(func, 'forward', forward_func.__get__(func))
except Exception:
# NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self
......
......@@ -18,14 +18,24 @@ import unittest
from paddle import nn
class LSTMLayer(nn.Layer):
def __init__(self, in_channels, hidden_size):
super(LSTMLayer, self).__init__()
self.cell = nn.LSTM(
in_channels, hidden_size, direction='bidirectional', num_layers=2)
def forward(self, x):
x, _ = self.cell(x)
return x
class Net(nn.Layer):
def __init__(self, in_channels, hidden_size):
super(Net, self).__init__()
self.lstm = nn.LSTM(
in_channels, hidden_size, direction='bidirectional', num_layers=2)
self.lstm = LSTMLayer(in_channels, hidden_size)
def forward(self, x):
x, _ = self.lstm(x)
x = self.lstm(x)
return x
......@@ -115,5 +125,21 @@ class TestSaveInEvalMode(unittest.TestCase):
infer_out))
class TestEvalAfterSave(unittest.TestCase):
def test_eval_after_save(self):
x = paddle.randn((2, 10, 12)).astype('float32')
net = Net(12, 2)
dy_out = net(x)
# save model
paddle.jit.save(net, 'jit.save/lstm', input_spec=[x])
load_net = paddle.jit.load('jit.save/lstm')
load_out = load_net(x)
self.assertTrue(np.allclose(dy_out.numpy(), load_out.numpy()))
# eval
net.eval()
out = net(x)
self.assertTrue(np.allclose(dy_out.numpy(), out.numpy()))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册