未验证 提交 4474e085 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2St]Enhance @to_static auto-skip paddle inner API (#50596)

* [Dy2St]Enhance @to_static auto-skip paddle inner API

* fix comment

* fix class method
上级 0661e8f1
......@@ -269,5 +269,33 @@ class TestNotToConvert2(TestRecursiveCall2):
self.assertIn("if x.shape[0] > 1", self.dygraph_func.code)
# Situation 3 : test to_static for paddle api
def forward(self, x):
if x.shape[0] > 1:
x = x + 1
return x
class TestConvertPaddleAPI(unittest.TestCase):
def test_functional_api(self):
func = paddle.nn.functional.relu
func = paddle.jit.to_static(func)
self.assertNotIn("_jst.IfElse", func.code)
self.assertIn("if in_dygraph_mode()", func.code)
def test_class_api(self):
bn = paddle.nn.SyncBatchNorm(2)
paddle.jit.to_static(bn)
self.assertNotIn("_jst.IfElse", bn.forward.code)
self.assertIn("if in_dygraph_mode()", bn.forward.code)
def test_class_patch_api(self):
paddle.nn.SyncBatchNorm.forward = forward
bn = paddle.nn.SyncBatchNorm(2)
paddle.jit.to_static(bn)
self.assertNotIn("_jst.IfElse", bn.forward.code)
self.assertIn("if x.shape[0] > 1", bn.forward.code)
if __name__ == '__main__':
unittest.main()
......@@ -42,10 +42,10 @@ class Layer0(nn.Layer):
def forward(self, x):
out1 = self._linear1(x)
out2 = self._linear2(x)
# out2.stop_gradient = True 如果stop_gradient不报错
# out2.stop_gradient = True not raise error
a = [out1, out2]
b = self.layer1(a)
# self.layer1(out1, out2) 也出错
# self.layer1(out1, out2) will raise error
return b
......
......@@ -48,6 +48,7 @@ from .utils import (
ast_to_source_code,
func_to_source_code,
input_specs_compatible,
is_paddle_func,
make_hashable,
prim_or_cinn_is_enabled,
type_name,
......@@ -150,6 +151,8 @@ def convert_to_static(function):
"""
Transforms function of dygraph into static function using the cache mechanism.
Note(dev): It will return function.__func__ if encountering class method.
Args:
function(callable): The function with dygraph layers that will be converted into static layers.
"""
......@@ -158,7 +161,11 @@ def convert_to_static(function):
# Return directly if decorated with @not_to_static and DO NOT Cache it
options = getattr(function, CONVERSION_OPTIONS, None)
if options is not None and options.not_convert:
# or ignore paddle api
need_skip = (options is not None and options.not_convert) or is_paddle_func(
function
)
if need_skip:
return function.__func__ if inspect.ismethod(function) else function
with _CACHE_LOCK:
......@@ -415,7 +422,7 @@ class StaticFunction:
def _clone(self):
return self.__class__(
self._dygraph_function, self._input_spec, **self._kwargs
self.dygraph_function, self._input_spec, **self._kwargs
)
def __call__(self, *args, **kwargs):
......@@ -513,14 +520,7 @@ class StaticFunction:
Return:
Outputs of dygraph function.
"""
if self._class_instance is not None:
dygraph_function = self._dygraph_function.__get__(
self._class_instance
)
else:
dygraph_function = self._dygraph_function
return dygraph_function(*args, **kwargs)
return self.dygraph_function(*args, **kwargs)
def _raise_when_property(self):
"""raise RuntimeError when property=True
......@@ -586,7 +586,7 @@ class StaticFunction:
"""
Returns the source code of transformed static function for debugging.
"""
static_func = convert_to_static(self._dygraph_function)
static_func = convert_to_static(self.dygraph_function)
source_code = func_to_source_code(static_func)
return source_code
......@@ -595,7 +595,10 @@ class StaticFunction:
"""
Returns the original decorated function.
"""
return self._dygraph_function
if self._class_instance is not None:
return self._dygraph_function.__get__(self._class_instance)
else:
return self._dygraph_function
@property
def concrete_program(self):
......
......@@ -287,8 +287,22 @@ def is_paddle_api(node):
def is_paddle_func(func):
m = inspect.getmodule(func)
return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
try:
if isinstance(func, functools.partial):
func = func.func
# In case of dynamically monkey patch customised function
# into paddle class obj, so we consider its class module
# path as prefix.
if hasattr(func, "__self__"):
func = func.__self__
elif inspect.ismethod(func):
func = func.__func__
m = inspect.getmodule(func)
return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
except Exception:
return False
# Is numpy_api cannot reuse is_api_in_module because of numpy module problem
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册