diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py index 8442403e04c83ef1653572dba12f196e97970d1a..299b5faa55402af0941baecca17041c1e7d62ccc 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/decorator_transformer.py @@ -22,6 +22,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, as import warnings import re +from paddle.fluid.dygraph.dygraph_to_static.utils import RE_PYNAME, RE_PYMODULE IGNORE_NAMES = [ 'declarative', 'to_static', 'dygraph_to_static_func', 'wraps', @@ -65,17 +66,23 @@ class DecoratorTransformer(BaseTransformer): for deco in reversed(deco_list): # skip INGNORE_NAMES - if isinstance(deco, gast.Attribute): - deco_name = deco.attr - elif isinstance(deco, gast.Call): - if hasattr(deco.func, 'args'): - deco_name = deco.func.args[0].id - elif hasattr(deco.func, 'attr'): - deco_name = deco.func.attr - else: - deco_name = deco.func.id + deco_full_name = ast_to_source_code(deco).strip() + if isinstance(deco, gast.Call): + # match case like : + # 1: @_jst.Call(a.b.c.d.deco)() + # 2: @q.w.e.r.deco() + re_tmp = re.match( + r'({module})*({name}\(){{0,1}}({module})*({name})(\)){{0,1}}\(.*$' + .format(name=RE_PYNAME, module=RE_PYMODULE), deco_full_name) + deco_name = re_tmp.group(4) else: - deco_name = deco.id + # match case like: + # @a.d.g.deco + re_tmp = re.match( + r'({module})*({name})$'.format(name=RE_PYNAME, + module=RE_PYMODULE), + deco_full_name) + deco_name = re_tmp.group(2) if deco_name in IGNORE_NAMES: continue elif deco_name == 'contextmanager': @@ -83,7 +90,6 @@ class DecoratorTransformer(BaseTransformer): "Dy2Static : A context manager decorator is used, this may not work correctly after transform." ) - deco_full_name = ast_to_source_code(deco).strip() decoed_func = '_decoedby_' + deco_name # get function after decoration diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 4a6e16e09eaa84cb9e7b59581bac87e82ceefa41..05938aa4b0f7ff233af77225a9167d150d0ea66c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -93,6 +93,9 @@ FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len' FOR_ITER_VAR_NAME_PREFIX = '__for_loop_iter_var' FOR_ITER_ZIP_TO_LIST_PREFIX = '__for_loop_iter_zip' +RE_PYNAME = '[a-zA-Z0-9_]+' +RE_PYMODULE = '[a-zA-Z0-9_]+\.' + # FullArgSpec is valid from Python3. Defined a Namedtuple to # to make it available in Python2. FullArgSpec = collections.namedtuple('FullArgSpec', [