未验证 提交 5e0dde02 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] support usage: to_static(model) (#27040)

* support to_static(model)

* add warning and unittest
上级 1b84c0bf
......@@ -212,7 +212,16 @@ def declarative(function=None, input_spec=None):
# for usage: `declarative(foo, ...)`
if function is not None:
return decorated(function)
if isinstance(function, Layer):
if isinstance(function.forward, StaticLayer):
class_name = function.__class__.__name__
warnings.warn(
"`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
format(class_name))
function.forward = decorated(function.forward)
return function
else:
return decorated(function)
# for usage: `@declarative`
return decorated
......
......@@ -332,5 +332,31 @@ class TestDeclarativeAPI(unittest.TestCase):
func(np.ones(5).astype("int32"))
class TestDecorateModelDirectly(unittest.TestCase):
def setUp(self):
paddle.disable_static()
program_trans.enable(True)
self.x = to_variable(np.ones([4, 10]).astype('float32'))
def test_fake_input(self):
net = SimpleNet()
net = declarative(net)
y = net(self.x)
self.assertTrue(len(net.forward.program_cache) == 1)
def test_input_spec(self):
net = SimpleNet()
net = declarative(net, input_spec=[InputSpec([None, 8, 10])])
self.assertTrue(len(net.forward.inputs) == 1)
self.assertTrue(len(net.forward.program_cache) == 1)
input_shape = net.forward.inputs[0].shape
self.assertListEqual(list(input_shape), [-1, 8, 10])
# redecorate
net = declarative(net, input_spec=[InputSpec([None, 16, 10])])
input_shape = net.forward.inputs[0].shape
self.assertListEqual(list(input_shape), [-1, 16, 10])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册