From 5e0dde02b23c9cc2c952daa7c563e9f712b039f4 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 11 Sep 2020 11:11:09 +0800 Subject: [PATCH] [Dy2stat] support usage: to_static(model) (#27040) * support to_static(model) * add warning and unittest --- python/paddle/fluid/dygraph/jit.py | 11 +++++++- .../dygraph_to_static/test_declarative.py | 26 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index ec96bdd978..57864efec8 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -212,7 +212,16 @@ def declarative(function=None, input_spec=None): # for usage: `declarative(foo, ...)` if function is not None: - return decorated(function) + if isinstance(function, Layer): + if isinstance(function.forward, StaticLayer): + class_name = function.__class__.__name__ + warnings.warn( + "`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.". + format(class_name)) + function.forward = decorated(function.forward) + return function + else: + return decorated(function) # for usage: `@declarative` return decorated diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index eed02ea655..5582a65304 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -332,5 +332,31 @@ class TestDeclarativeAPI(unittest.TestCase): func(np.ones(5).astype("int32")) +class TestDecorateModelDirectly(unittest.TestCase): + def setUp(self): + paddle.disable_static() + program_trans.enable(True) + self.x = to_variable(np.ones([4, 10]).astype('float32')) + + def test_fake_input(self): + net = SimpleNet() + net = declarative(net) + y = net(self.x) + self.assertTrue(len(net.forward.program_cache) == 1) + + def test_input_spec(self): + net = SimpleNet() + net = declarative(net, input_spec=[InputSpec([None, 8, 10])]) + self.assertTrue(len(net.forward.inputs) == 1) + self.assertTrue(len(net.forward.program_cache) == 1) + input_shape = net.forward.inputs[0].shape + self.assertListEqual(list(input_shape), [-1, 8, 10]) + + # redecorate + net = declarative(net, input_spec=[InputSpec([None, 16, 10])]) + input_shape = net.forward.inputs[0].shape + self.assertListEqual(list(input_shape), [-1, 16, 10]) + + if __name__ == '__main__': unittest.main() -- GitLab