From 4474e0855a896ec02d2abc296ea16bb7ddae67e9 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 9 Mar 2023 14:13:11 +0800 Subject: [PATCH] [Dy2St]Enhance @to_static auto-skip paddle inner API (#50596) * [Dy2St]Enhance @to_static auto-skip paddle inner API * fix comment * fix class method --- .../dygraph_to_static/test_convert_call.py | 28 +++++++++++++++++++ .../dygraph_to_static/test_unuseful_inputs.py | 4 +-- .../jit/dy2static/program_translator.py | 27 ++++++++++-------- python/paddle/jit/dy2static/utils.py | 18 ++++++++++-- 4 files changed, 61 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py index e1b6ec15d48..49e976cd0e4 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py @@ -269,5 +269,33 @@ class TestNotToConvert2(TestRecursiveCall2): self.assertIn("if x.shape[0] > 1", self.dygraph_func.code) +# Situation 3 : test to_static for paddle api +def forward(self, x): + if x.shape[0] > 1: + x = x + 1 + return x + + +class TestConvertPaddleAPI(unittest.TestCase): + def test_functional_api(self): + func = paddle.nn.functional.relu + func = paddle.jit.to_static(func) + self.assertNotIn("_jst.IfElse", func.code) + self.assertIn("if in_dygraph_mode()", func.code) + + def test_class_api(self): + bn = paddle.nn.SyncBatchNorm(2) + paddle.jit.to_static(bn) + self.assertNotIn("_jst.IfElse", bn.forward.code) + self.assertIn("if in_dygraph_mode()", bn.forward.code) + + def test_class_patch_api(self): + paddle.nn.SyncBatchNorm.forward = forward + bn = paddle.nn.SyncBatchNorm(2) + paddle.jit.to_static(bn) + self.assertNotIn("_jst.IfElse", bn.forward.code) + self.assertIn("if x.shape[0] > 1", bn.forward.code) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_unuseful_inputs.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_unuseful_inputs.py index 2e4d12ac4dc..5cafba4e040 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_unuseful_inputs.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_unuseful_inputs.py @@ -42,10 +42,10 @@ class Layer0(nn.Layer): def forward(self, x): out1 = self._linear1(x) out2 = self._linear2(x) - # out2.stop_gradient = True 如果stop_gradient不报错 + # out2.stop_gradient = True not raise error a = [out1, out2] b = self.layer1(a) - # self.layer1(out1, out2) 也出错 + # self.layer1(out1, out2) will raise error return b diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index f654c34d04e..41722ad4bad 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -48,6 +48,7 @@ from .utils import ( ast_to_source_code, func_to_source_code, input_specs_compatible, + is_paddle_func, make_hashable, prim_or_cinn_is_enabled, type_name, @@ -150,6 +151,8 @@ def convert_to_static(function): """ Transforms function of dygraph into static function using the cache mechanism. + Note(dev): It will return function.__func__ if encountering class method. + Args: function(callable): The function with dygraph layers that will be converted into static layers. """ @@ -158,7 +161,11 @@ def convert_to_static(function): # Return directly if decorated with @not_to_static and DO NOT Cache it options = getattr(function, CONVERSION_OPTIONS, None) - if options is not None and options.not_convert: + # or ignore paddle api + need_skip = (options is not None and options.not_convert) or is_paddle_func( + function + ) + if need_skip: return function.__func__ if inspect.ismethod(function) else function with _CACHE_LOCK: @@ -415,7 +422,7 @@ class StaticFunction: def _clone(self): return self.__class__( - self._dygraph_function, self._input_spec, **self._kwargs + self.dygraph_function, self._input_spec, **self._kwargs ) def __call__(self, *args, **kwargs): @@ -513,14 +520,7 @@ class StaticFunction: Return: Outputs of dygraph function. """ - if self._class_instance is not None: - dygraph_function = self._dygraph_function.__get__( - self._class_instance - ) - else: - dygraph_function = self._dygraph_function - - return dygraph_function(*args, **kwargs) + return self.dygraph_function(*args, **kwargs) def _raise_when_property(self): """raise RuntimeError when property=True @@ -586,7 +586,7 @@ class StaticFunction: """ Returns the source code of transformed static function for debugging. """ - static_func = convert_to_static(self._dygraph_function) + static_func = convert_to_static(self.dygraph_function) source_code = func_to_source_code(static_func) return source_code @@ -595,7 +595,10 @@ class StaticFunction: """ Returns the original decorated function. """ - return self._dygraph_function + if self._class_instance is not None: + return self._dygraph_function.__get__(self._class_instance) + else: + return self._dygraph_function @property def concrete_program(self): diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index ee69ccde1a9..9ea21bdfc2f 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -287,8 +287,22 @@ def is_paddle_api(node): def is_paddle_func(func): - m = inspect.getmodule(func) - return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX) + try: + if isinstance(func, functools.partial): + func = func.func + + # In case of dynamically monkey patch customised function + # into paddle class obj, so we consider its class module + # path as prefix. + if hasattr(func, "__self__"): + func = func.__self__ + elif inspect.ismethod(func): + func = func.__func__ + + m = inspect.getmodule(func) + return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX) + except Exception: + return False # Is numpy_api cannot reuse is_api_in_module because of numpy module problem -- GitLab