Unverified commit 4474e085, authored by Aurelius84, committed by GitHub

[Dy2St]Enhance @to_static auto-skip paddle inner API (#50596)

* [Dy2St]Enhance @to_static auto-skip paddle inner API

* fix comment

* fix class method
Parent 0661e8f1
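For orientation, a minimal sketch of the user-visible behavior this commit adds, assuming a Paddle build that includes #50596; it mirrors what the new tests in the first hunk verify. Applying @to_static to a paddle inner API now skips the AST transformation entirely:

import paddle

# relu lives under the `paddle` module prefix, so to_static now returns
# it untransformed instead of rewriting its control flow.
func = paddle.jit.to_static(paddle.nn.functional.relu)

# The original dygraph branch survives; the `_jst.IfElse` rewrite that
# user functions receive is absent.
assert "_jst.IfElse" not in func.code
assert "if in_dygraph_mode()" in func.code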
@@ -269,5 +269,33 @@ class TestNotToConvert2(TestRecursiveCall2):
         self.assertIn("if x.shape[0] > 1", self.dygraph_func.code)
 
 
+# Situation 3: test to_static for paddle api
+def forward(self, x):
+    if x.shape[0] > 1:
+        x = x + 1
+    return x
+
+
+class TestConvertPaddleAPI(unittest.TestCase):
+    def test_functional_api(self):
+        func = paddle.nn.functional.relu
+        func = paddle.jit.to_static(func)
+        self.assertNotIn("_jst.IfElse", func.code)
+        self.assertIn("if in_dygraph_mode()", func.code)
+
+    def test_class_api(self):
+        bn = paddle.nn.SyncBatchNorm(2)
+        paddle.jit.to_static(bn)
+        self.assertNotIn("_jst.IfElse", bn.forward.code)
+        self.assertIn("if in_dygraph_mode()", bn.forward.code)
+
+    def test_class_patch_api(self):
+        paddle.nn.SyncBatchNorm.forward = forward
+        bn = paddle.nn.SyncBatchNorm(2)
+        paddle.jit.to_static(bn)
+        self.assertNotIn("_jst.IfElse", bn.forward.code)
+        self.assertIn("if x.shape[0] > 1", bn.forward.code)
+
+
 if __name__ == '__main__':
     unittest.main()
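The test_class_patch_api case is the subtle one: the patched forward is defined in the test module, yet it is still skipped, because is_paddle_func (see the last hunk below) follows a bound method's __self__ back to the instance, whose class lives under the paddle module prefix. A small standalone check of that lookup, assuming a Paddle install:

import inspect
import paddle

def forward(self, x):  # defined in user (test) code, not in paddle
    if x.shape[0] > 1:
        x = x + 1
    return x

paddle.nn.SyncBatchNorm.forward = forward
bn = paddle.nn.SyncBatchNorm(2)

# bn.forward is a bound method, so is_paddle_func inspects bn itself;
# the instance's class module sits under the "paddle" prefix.
module = inspect.getmodule(bn.forward.__self__)
print(module.__name__)  # a paddle.nn submodule, so the function is skipped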
@@ -42,10 +42,10 @@ class Layer0(nn.Layer):
     def forward(self, x):
         out1 = self._linear1(x)
         out2 = self._linear2(x)
-        # out2.stop_gradient = True 如果stop_gradient不报错
+        # out2.stop_gradient = True does not raise an error
         a = [out1, out2]
         b = self.layer1(a)
-        # self.layer1(out1, out2) 也出错
+        # self.layer1(out1, out2) will also raise an error
         return b
@@ -48,6 +48,7 @@ from .utils import (
     ast_to_source_code,
     func_to_source_code,
     input_specs_compatible,
+    is_paddle_func,
     make_hashable,
     prim_or_cinn_is_enabled,
     type_name,
@@ -150,6 +151,8 @@ def convert_to_static(function):
     """
     Transforms function of dygraph into static function using the cache mechanism.
 
+    Note(dev): It will return function.__func__ if encountering a class method.
+
     Args:
         function(callable): The function with dygraph layers that will be converted into static layers.
     """
@@ -158,7 +161,11 @@ def convert_to_static(function):
     # Return directly if decorated with @not_to_static and DO NOT Cache it
     options = getattr(function, CONVERSION_OPTIONS, None)
-    if options is not None and options.not_convert:
+    # ... or if it is a paddle inner API
+    need_skip = (options is not None and options.not_convert) or is_paddle_func(
+        function
+    )
+    if need_skip:
         return function.__func__ if inspect.ismethod(function) else function
 
     with _CACHE_LOCK:
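Read together, the two conditions mean paddle inner APIs are now treated exactly like functions explicitly marked @not_to_static. A condensed, self-contained sketch of the decision with stand-in helpers (not the real paddle.jit internals):

import inspect

def _is_paddle_func(func):  # simplified stand-in for the real helper
    m = inspect.getmodule(func)
    return m is not None and m.__name__.startswith("paddle")

def _should_skip(function, options):
    # options mirrors the CONVERSION_OPTIONS attribute read above
    return (options is not None and options.not_convert) or _is_paddle_func(function)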
@@ -415,7 +422,7 @@ class StaticFunction:
     def _clone(self):
         return self.__class__(
-            self._dygraph_function, self._input_spec, **self._kwargs
+            self.dygraph_function, self._input_spec, **self._kwargs
         )
 
     def __call__(self, *args, **kwargs):
@@ -513,14 +520,7 @@ class StaticFunction:
         Return:
             Outputs of dygraph function.
         """
-        if self._class_instance is not None:
-            dygraph_function = self._dygraph_function.__get__(
-                self._class_instance
-            )
-        else:
-            dygraph_function = self._dygraph_function
-
-        return dygraph_function(*args, **kwargs)
+        return self.dygraph_function(*args, **kwargs)
 
     def _raise_when_property(self):
         """raise RuntimeError when property=True
@@ -586,7 +586,7 @@ class StaticFunction:
         """
         Returns the source code of transformed static function for debugging.
         """
-        static_func = convert_to_static(self._dygraph_function)
+        static_func = convert_to_static(self.dygraph_function)
         source_code = func_to_source_code(static_func)
         return source_code
@@ -595,7 +595,10 @@ class StaticFunction:
         """
         Returns the original decorated function.
         """
-        return self._dygraph_function
+        if self._class_instance is not None:
+            return self._dygraph_function.__get__(self._class_instance)
+        else:
+            return self._dygraph_function
 
     @property
     def concrete_program(self):
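The rewritten property performs standard descriptor binding: calling __get__ on a plain function with an instance yields a bound method, which is why the __call__ and code hunks above can simply delegate to self.dygraph_function without re-binding. A minimal illustration with hypothetical names, not Paddle code:

class Demo:  # hypothetical, for illustration only
    def __init__(self, name):
        self.name = name

def greet(self):
    return f"hello from {self.name}"

d = Demo("d1")
bound = greet.__get__(d)  # same mechanism as _dygraph_function.__get__(_class_instance)
print(bound())            # hello from d1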
@@ -287,8 +287,22 @@ def is_paddle_api(node):
 def is_paddle_func(func):
-    m = inspect.getmodule(func)
-    return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
+    try:
+        if isinstance(func, functools.partial):
+            func = func.func
+
+        # A customised function may be dynamically monkey patched into
+        # a paddle class object, so we take the class's module path as
+        # the prefix to check.
+        if hasattr(func, "__self__"):
+            func = func.__self__
+        elif inspect.ismethod(func):
+            func = func.__func__
+
+        m = inspect.getmodule(func)
+        return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
+    except Exception:
+        return False
 
 
 # numpy_api cannot reuse is_api_in_module because of the numpy module problem
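A quick exercise of the three unwrapping paths the rewritten helper now handles, assuming a Paddle build with this change; the import path is inferred from the hunk's "from .utils import" and may differ across versions:

import functools
import paddle
from paddle.jit.dy2static.utils import is_paddle_func  # path assumed

assert is_paddle_func(paddle.nn.functional.relu)                     # plain function
assert is_paddle_func(functools.partial(paddle.nn.functional.relu))  # partial, unwrapped via .func
assert is_paddle_func(paddle.nn.SyncBatchNorm(2).forward)            # bound method, via __self__
assert not is_paddle_func(lambda x: x)                               # user code is never skipped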