From ba65e4ebef4a3346a973317ac522a438adf262a1 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 10 Mar 2020 10:02:35 +0800 Subject: [PATCH] support Tensor.shape in control_flow_if test=develop (#22916) --- .../dygraph_to_static/ast_transformer.py | 9 +-- .../dygraph/dygraph_to_static/ast_utils.py | 69 +++++++++++++++--- .../fluid/tests/unittests/test_ast_util.py | 71 ++++++++++++++++--- 3 files changed, 126 insertions(+), 23 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index b175dde68ea..c0f7b30ca9a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -46,11 +46,11 @@ class IfElseTransformer(gast.NodeTransformer): wrapper_root, AstNodeWrapper ), "Type of input node should be AstNodeWrapper, but received %s ." % type( wrapper_root) - self.wrapper_root = wrapper_root self.root = wrapper_root.node + self.static_analysis_visitor = StaticAnalysisVisitor(self.root) self.new_func_nodes = {} - def ast_visit(self): + def transform(self): """ Main function to transform AST. """ @@ -59,7 +59,8 @@ class IfElseTransformer(gast.NodeTransformer): def visit_If(self, node): assert isinstance(node, gast.If) - need_transform = is_control_flow_if(node.test) + need_transform = is_control_flow_if(node.test, + self.static_analysis_visitor) self.generic_visit(node) if need_transform: pred_node = node.test @@ -143,7 +144,7 @@ class DygraphToStaticAst(gast.NodeTransformer): self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id() # Transform all if/else statement of Dygraph into Static Graph. - IfElseTransformer(node_wrapper).ast_visit() + IfElseTransformer(node_wrapper).transform() LoopTransformer(node_wrapper).transform() diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py index 3f8a1699739..42321101c30 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py @@ -26,6 +26,8 @@ import atexit from collections import defaultdict from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType, StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api TRUE_FUNC_PREFIX = 'true_fn' FALSE_FUNC_PREFIX = 'false_fn' @@ -49,23 +51,36 @@ class IsControlFlowIfVisitor(gast.NodeTransformer): because reshape_op may be called before this statement. """ - def __init__(self, node): - self.node = node + def __init__(self, static_analysis_visitor): + self.static_analysis_visitor = static_analysis_visitor + self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( + ) self.is_control_flow = False - def ast_visit(self): - self.visit(self.node) + def transform(self, node): + if self._is_candidate_node(node): + self.visit(node) return self.is_control_flow + def visit_BoolOp(self, node): + for child in node.values: + if not self._is_candidate_node(child): + continue + self.generic_visit(node) + return node + def visit_Compare(self, node): - for child in gast.walk(node): - if isinstance(child, gast.Subscript): - self._visit_Subscript(child) + # Ignores child node with `if x` or `if x is None` + if not self._compare_with_none(node): + self.generic_visit(node) + for child in gast.walk(node): + if isinstance(child, gast.Subscript): + self._visit_Subscript(child) return node def _visit_Subscript(self, node): self.generic_visit(node) - if isinstance(node.value, gast.Call): + if hasattr(node, 'value') and isinstance(node.value, gast.Call): self._visit_Call(node.value) return node @@ -73,10 +88,40 @@ class IsControlFlowIfVisitor(gast.NodeTransformer): assert isinstance(node, gast.Call) if isinstance(node.func, gast.Attribute): attr_node = node.func - self.is_control_flow = (attr_node.attr == 'numpy') + if attr_node.attr == 'numpy': + self.is_control_flow = True + + def visit_Call(self, node): + if is_paddle_api(node): + self.is_control_flow = True + return node + + def visit_Name(self, node): + wrapper_node = self.node_to_wrapper_map.get(node, None) + if wrapper_node is not None: + if wrapper_node.node_var_type & { + NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES + }: + self.is_control_flow = True + return node + + def _is_candidate_node(self, node): + return isinstance(node, (gast.Compare, gast.BoolOp)) + + def _compare_with_none(self, node): + if isinstance(node, gast.Compare): + for child in [node.left, node.comparators]: + # node.comparators is a list. + if isinstance(child, list): child = child[0] + if (isinstance(child, gast.Constant) and + child.value is None) or ( + isinstance(child, gast.Name) and + child.id == 'None'): + return True + return False -def is_control_flow_if(node): +def is_control_flow_if(node, static_analysis_visitor=None): """ Determine whether the node is a plain python `if statement` or control flow in Paddle. @@ -84,7 +129,9 @@ def is_control_flow_if(node): assert isinstance( node, gast.AST ), "Type of input node should be gast.AST, but received %s." % type(node) - return IsControlFlowIfVisitor(node).ast_visit() + if static_analysis_visitor is None: + static_analysis_visitor = StaticAnalysisVisitor(node) + return IsControlFlowIfVisitor(static_analysis_visitor).transform(node) def get_name_ids(nodes, not_name_set=None, node_black_list=None): diff --git a/python/paddle/fluid/tests/unittests/test_ast_util.py b/python/paddle/fluid/tests/unittests/test_ast_util.py index 27f4f2e5cef..2d984a0a521 100644 --- a/python/paddle/fluid/tests/unittests/test_ast_util.py +++ b/python/paddle/fluid/tests/unittests/test_ast_util.py @@ -21,6 +21,7 @@ import inspect import numpy as np import paddle.fluid as fluid from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func, is_control_flow_if +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from test_dygraph_to_static_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else @@ -98,35 +99,89 @@ class TestIsControlFlowIf(unittest.TestCase): def test_expr(self): # node is not ast.Compare node = gast.parse("a + b") - self.assertFalse(is_control_flow_if(node)) + self.assertFalse(is_control_flow_if(node.body[0].value)) def test_expr2(self): node = gast.parse("a + x.numpy()[1]") - self.assertFalse(is_control_flow_if(node)) + self.assertFalse(is_control_flow_if(node.body[0].value)) def test_is_None(self): node = gast.parse("x is None") - self.assertFalse(is_control_flow_if(node)) + self.assertFalse(is_control_flow_if(node.body[0].value)) def test_is_None2(self): node = gast.parse("fluid.layers.sum(x) is None") - self.assertFalse(is_control_flow_if(node)) + self.assertFalse(is_control_flow_if(node.body[0].value)) def test_is_None3(self): node = gast.parse("fluid.layers.sum(x).numpy() != None") - self.assertFalse(is_control_flow_if(node)) + self.assertFalse(is_control_flow_if(node.body[0].value)) def test_if(self): node = gast.parse("x.numpy()[1] > 1") - self.assertTrue(is_control_flow_if(node)) + self.assertTrue(is_control_flow_if(node.body[0].value)) def test_if_with_and(self): node = gast.parse("x is not None and 1 < x.numpy()[1]") - self.assertTrue(is_control_flow_if(node)) + self.assertTrue(is_control_flow_if(node.body[0].value)) def test_if_with_or(self): node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0") - self.assertTrue(is_control_flow_if(node)) + self.assertTrue(is_control_flow_if(node.body[0].value)) + + def test_shape(self): + code = """ + def foo(x): + batch_size = fluid.layers.shape(x) + if batch_size[0] > 16: + x = x + 1 + return x + """ + code = textwrap.dedent(code) + node = gast.parse(code) + visitor = StaticAnalysisVisitor(node) + test_node = node.body[0].body[1].test + self.assertTrue(is_control_flow_if(test_node, visitor)) + + def test_shape_with_andOr(self): + code = """ + def foo(x): + batch_size = fluid.layers.shape(x) + if x is not None and batch_size[0] > 16 or 2 > 1: + x = x + 1 + return x + """ + code = textwrap.dedent(code) + node = gast.parse(code) + visitor = StaticAnalysisVisitor(node) + test_node = node.body[0].body[1].test + self.assertTrue(is_control_flow_if(test_node, visitor)) + + def test_paddle_api(self): + code = """ + def foo(x): + if fluid.layers.shape(x)[0] > 16: + x = x + 1 + return x + """ + code = textwrap.dedent(code) + node = gast.parse(code) + visitor = StaticAnalysisVisitor(node) + test_node = node.body[0].body[0].test + self.assertTrue(is_control_flow_if(test_node, visitor)) + + def test_paddle_api_with_andOr(self): + code = """ + def foo(x): + if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None : + x = x + 1 + return x + """ + code = textwrap.dedent(code) + node = gast.parse(code) + visitor = StaticAnalysisVisitor(node) + test_node = node.body[0].body[0].test + self.assertTrue(is_control_flow_if(test_node, visitor)) def test_raise_error(self): node = "a + b" -- GitLab