From 9ab335bbd9d18b48b287b391a7dd4dd20fedf328 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Thu, 19 Nov 2020 09:59:31 +0800 Subject: [PATCH] Fix convert_call May be Called Multiple Times, test=develop (#28710) Fix convert_callmMay be called multiple times in Dy2stat. Also strip some strings to make sure no influence from blank spaces. --- .../fluid/dygraph/dygraph_to_static/loop_transformer.py | 6 +++--- python/paddle/fluid/dygraph/dygraph_to_static/utils.py | 9 ++++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index b25ff8360be..8e3ca72788b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -260,9 +260,9 @@ class NameVisitor(gast.NodeVisitor): type_node = node.args[1] if isinstance(type_node, gast.Tuple): for element in type_node.elts: - self.type_vars.add(ast_to_source_code(element)) + self.type_vars.add(ast_to_source_code(element).strip()) else: - self.type_vars.add(ast_to_source_code(type_node)) + self.type_vars.add(ast_to_source_code(type_node).strip()) self.generic_visit(node) def _var_nodes_to_names(self, node_set, ctx_filter_set=None): @@ -381,7 +381,7 @@ class NameVisitor(gast.NodeVisitor): # 3. Remove var type names which are stored in self.type_vars for var in loop_vars: - if ast_to_source_code(var) in self.type_vars: + if ast_to_source_code(var).strip() in self.type_vars: removed_vars.add(var) return loop_vars - removed_vars diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index b44739ca848..cdb4b8e52dc 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -149,7 +149,14 @@ def _is_api_in_module_helper(obj, module_prefix): def is_api_in_module(node, module_prefix): assert isinstance(node, gast.Call), "Input non-Call node for is_dygraph_api" - func_str = astor.to_source(gast.gast_to_ast(node.func)) + + # Python can have gast.Call as function, for example: covert_call(func)(x) + # We only check the most outside function + func_node = node.func + while isinstance(func_node, gast.Call): + func_node = func_node.func + + func_str = astor.to_source(gast.gast_to_ast(func_node)).strip() try: # TODO(liym27): # Consider a better to import modules like: -- GitLab