diff --git a/python/paddle/distributed/auto_parallel/helper.py b/python/paddle/distributed/auto_parallel/helper.py
index 164f51fbcc7dda5e892cf22dd088c9b7f9439ab8..e901c861131cc0d306fc3c9ccbc719edf9f8c8dc 100644
--- a/python/paddle/distributed/auto_parallel/helper.py
+++ b/python/paddle/distributed/auto_parallel/helper.py
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import logging
 from collections import defaultdict
 
 from paddle.jit import not_to_static, to_static
 from paddle.jit.dy2static.program_translator import StaticFunction
+from paddle.jit.dy2static.utils import as_not_paddle_func
 from paddle.nn import Layer
 from paddle.static import Parameter, global_scope, program_guard
 
@@ -47,6 +49,12 @@ class ProxyLayer(Layer):
         self._loss_vars = defaultdict(list)
         self._metric_vars = defaultdict(list)
 
+        # Treat ProxyLayer as a non-Paddle function, because it wraps a
+        # user-defined layer.
+        as_not_paddle_func(
+            inspect.getmodule(ProxyLayer).__name__ + ".ProxyLayer"
+        )
+
     def _train(self, inputs, labels):
         """
         Train process of inner_layer with forward/loss/metric logic.
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py
index 04d34c92ac58a0ee9a15e902aa8e945f53b037df..dd593bca3e7fdf1452fce1004597d38596ed69c3 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py
@@ -20,10 +20,11 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import LazyGuard
-from paddle.distributed.auto_parallel.helper import ProgramHelper
+from paddle.distributed.auto_parallel.helper import ProgramHelper, ProxyLayer
 from paddle.distributed.fleet import auto
 from paddle.fluid.framework import _non_static_mode
 from paddle.io import Dataset
+from paddle.jit.dy2static.utils import is_paddle_func
 from paddle.static import InputSpec
 
 batch_size = 4
@@ -182,5 +183,23 @@ class TestLazyInit(unittest.TestCase):
         program_helper.reset()
 
 
+class TestIgnoreProxyLayer(unittest.TestCase):
+    def test_is_paddle_func(self):
+        mlp = MLPLayer(
+            hidden_size=hidden_size,
+            intermediate_size=4 * hidden_size,
+            dropout_ratio=0.1,
+            initializer_range=0.02,
+        )
+        loss = paddle.nn.CrossEntropyLoss()
+        metrics = paddle.metric.Accuracy()
+
+        proxy_layer = ProxyLayer(mlp, loss, metrics)
+
+        self.assertFalse(is_paddle_func(proxy_layer._train))
+        self.assertFalse(is_paddle_func(proxy_layer._eval))
+        self.assertFalse(is_paddle_func(proxy_layer._predict))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index 1a933d96d169133e7a98cfeb527196314f2d50e8..496e75914b2a3c428fb911f0b8efdae3fe3baaaf 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -296,21 +296,51 @@ def is_paddle_api(node):
     return is_api_in_module(node, PADDLE_MODULE_PREFIX)
 
 
-def is_paddle_func(func):
+# NOTE(Aurelius84): The Paddle inner APIs listed here still receive the usual
+# @to_static code transformation, because they contain user-defined layers,
+# e.g. paddle.distributed.auto_parallel.helper.ProxyLayer.
+AS_NOT_INNER_FUNC_LIST = set()
+
+
+def as_not_paddle_func(path):
+    """
+    Append an API or class path to the ignore list of is_paddle_func;
+    is_paddle_func(func) will then return False for it.
+    """
+    global AS_NOT_INNER_FUNC_LIST
+    AS_NOT_INNER_FUNC_LIST.add(path)
+
+
+def is_paddle_func(func, ignore_white_list=True):
+    """
+    Return True if the function is defined in a Paddle module.
+    Skip checking APIs in the white list if ignore_white_list is True.
+    """
+
+    def in_white_list(module, func_name):
+        if func_name is None:
+            return False
+        return (module.__name__ + '.' + func_name) in AS_NOT_INNER_FUNC_LIST
+
     try:
         if isinstance(func, functools.partial):
             func = func.func
 
+        func_name = getattr(func, '__name__', None)
         # In case of dynamically monkey patch customised function
         # into paddle class obj, so we consider its class module
         # path as prefix.
         if hasattr(func, "__self__"):
             func = func.__self__
+            func_name = func.__class__.__name__
         elif inspect.ismethod(func):
             func = func.__func__
 
         m = inspect.getmodule(func)
-        return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
+        flag = m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
+        if ignore_white_list:
+            flag = flag and not in_white_list(m, func_name)
+        return flag
     except Exception:
         return False
 