未验证 提交 ba65e4eb 编写于 作者: A Aurelius84 提交者: GitHub

support Tensor.shape in control_flow_if test=develop (#22916)

上级 d33c4343
...@@ -46,11 +46,11 @@ class IfElseTransformer(gast.NodeTransformer): ...@@ -46,11 +46,11 @@ class IfElseTransformer(gast.NodeTransformer):
wrapper_root, AstNodeWrapper wrapper_root, AstNodeWrapper
), "Type of input node should be AstNodeWrapper, but received %s ." % type( ), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root) wrapper_root)
self.wrapper_root = wrapper_root
self.root = wrapper_root.node self.root = wrapper_root.node
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.new_func_nodes = {} self.new_func_nodes = {}
def ast_visit(self): def transform(self):
""" """
Main function to transform AST. Main function to transform AST.
""" """
...@@ -59,7 +59,8 @@ class IfElseTransformer(gast.NodeTransformer): ...@@ -59,7 +59,8 @@ class IfElseTransformer(gast.NodeTransformer):
def visit_If(self, node): def visit_If(self, node):
assert isinstance(node, gast.If) assert isinstance(node, gast.If)
need_transform = is_control_flow_if(node.test) need_transform = is_control_flow_if(node.test,
self.static_analysis_visitor)
self.generic_visit(node) self.generic_visit(node)
if need_transform: if need_transform:
pred_node = node.test pred_node = node.test
...@@ -143,7 +144,7 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -143,7 +144,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id() self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()
# Transform all if/else statement of Dygraph into Static Graph. # Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).ast_visit() IfElseTransformer(node_wrapper).transform()
LoopTransformer(node_wrapper).transform() LoopTransformer(node_wrapper).transform()
......
...@@ -26,6 +26,8 @@ import atexit ...@@ -26,6 +26,8 @@ import atexit
from collections import defaultdict from collections import defaultdict
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
TRUE_FUNC_PREFIX = 'true_fn' TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn' FALSE_FUNC_PREFIX = 'false_fn'
...@@ -49,23 +51,36 @@ class IsControlFlowIfVisitor(gast.NodeTransformer): ...@@ -49,23 +51,36 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
because reshape_op may be called before this statement. because reshape_op may be called before this statement.
""" """
def __init__(self, node): def __init__(self, static_analysis_visitor):
self.node = node self.static_analysis_visitor = static_analysis_visitor
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
self.is_control_flow = False self.is_control_flow = False
def ast_visit(self): def transform(self, node):
self.visit(self.node) if self._is_candidate_node(node):
self.visit(node)
return self.is_control_flow return self.is_control_flow
def visit_BoolOp(self, node):
for child in node.values:
if not self._is_candidate_node(child):
continue
self.generic_visit(node)
return node
def visit_Compare(self, node): def visit_Compare(self, node):
for child in gast.walk(node): # Ignores child node with `if x` or `if x is None`
if isinstance(child, gast.Subscript): if not self._compare_with_none(node):
self._visit_Subscript(child) self.generic_visit(node)
for child in gast.walk(node):
if isinstance(child, gast.Subscript):
self._visit_Subscript(child)
return node return node
def _visit_Subscript(self, node): def _visit_Subscript(self, node):
self.generic_visit(node) self.generic_visit(node)
if isinstance(node.value, gast.Call): if hasattr(node, 'value') and isinstance(node.value, gast.Call):
self._visit_Call(node.value) self._visit_Call(node.value)
return node return node
...@@ -73,10 +88,40 @@ class IsControlFlowIfVisitor(gast.NodeTransformer): ...@@ -73,10 +88,40 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
if isinstance(node.func, gast.Attribute): if isinstance(node.func, gast.Attribute):
attr_node = node.func attr_node = node.func
self.is_control_flow = (attr_node.attr == 'numpy') if attr_node.attr == 'numpy':
self.is_control_flow = True
def visit_Call(self, node):
if is_paddle_api(node):
self.is_control_flow = True
return node
def visit_Name(self, node):
wrapper_node = self.node_to_wrapper_map.get(node, None)
if wrapper_node is not None:
if wrapper_node.node_var_type & {
NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES
}:
self.is_control_flow = True
return node
def _is_candidate_node(self, node):
return isinstance(node, (gast.Compare, gast.BoolOp))
def _compare_with_none(self, node):
if isinstance(node, gast.Compare):
for child in [node.left, node.comparators]:
# node.comparators is a list.
if isinstance(child, list): child = child[0]
if (isinstance(child, gast.Constant) and
child.value is None) or (
isinstance(child, gast.Name) and
child.id == 'None'):
return True
return False
def is_control_flow_if(node): def is_control_flow_if(node, static_analysis_visitor=None):
""" """
Determine whether the node is a plain python `if statement` or Determine whether the node is a plain python `if statement` or
control flow in Paddle. control flow in Paddle.
...@@ -84,7 +129,9 @@ def is_control_flow_if(node): ...@@ -84,7 +129,9 @@ def is_control_flow_if(node):
assert isinstance( assert isinstance(
node, gast.AST node, gast.AST
), "Type of input node should be gast.AST, but received %s." % type(node) ), "Type of input node should be gast.AST, but received %s." % type(node)
return IsControlFlowIfVisitor(node).ast_visit() if static_analysis_visitor is None:
static_analysis_visitor = StaticAnalysisVisitor(node)
return IsControlFlowIfVisitor(static_analysis_visitor).transform(node)
def get_name_ids(nodes, not_name_set=None, node_black_list=None): def get_name_ids(nodes, not_name_set=None, node_black_list=None):
......
...@@ -21,6 +21,7 @@ import inspect ...@@ -21,6 +21,7 @@ import inspect
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func, is_control_flow_if from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func, is_control_flow_if
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from test_dygraph_to_static_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else from test_dygraph_to_static_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else
...@@ -98,35 +99,89 @@ class TestIsControlFlowIf(unittest.TestCase): ...@@ -98,35 +99,89 @@ class TestIsControlFlowIf(unittest.TestCase):
def test_expr(self): def test_expr(self):
# node is not ast.Compare # node is not ast.Compare
node = gast.parse("a + b") node = gast.parse("a + b")
self.assertFalse(is_control_flow_if(node)) self.assertFalse(is_control_flow_if(node.body[0].value))
def test_expr2(self): def test_expr2(self):
node = gast.parse("a + x.numpy()[1]") node = gast.parse("a + x.numpy()[1]")
self.assertFalse(is_control_flow_if(node)) self.assertFalse(is_control_flow_if(node.body[0].value))
def test_is_None(self): def test_is_None(self):
node = gast.parse("x is None") node = gast.parse("x is None")
self.assertFalse(is_control_flow_if(node)) self.assertFalse(is_control_flow_if(node.body[0].value))
def test_is_None2(self): def test_is_None2(self):
node = gast.parse("fluid.layers.sum(x) is None") node = gast.parse("fluid.layers.sum(x) is None")
self.assertFalse(is_control_flow_if(node)) self.assertFalse(is_control_flow_if(node.body[0].value))
def test_is_None3(self): def test_is_None3(self):
node = gast.parse("fluid.layers.sum(x).numpy() != None") node = gast.parse("fluid.layers.sum(x).numpy() != None")
self.assertFalse(is_control_flow_if(node)) self.assertFalse(is_control_flow_if(node.body[0].value))
def test_if(self): def test_if(self):
node = gast.parse("x.numpy()[1] > 1") node = gast.parse("x.numpy()[1] > 1")
self.assertTrue(is_control_flow_if(node)) self.assertTrue(is_control_flow_if(node.body[0].value))
def test_if_with_and(self): def test_if_with_and(self):
node = gast.parse("x is not None and 1 < x.numpy()[1]") node = gast.parse("x is not None and 1 < x.numpy()[1]")
self.assertTrue(is_control_flow_if(node)) self.assertTrue(is_control_flow_if(node.body[0].value))
def test_if_with_or(self): def test_if_with_or(self):
node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0") node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0")
self.assertTrue(is_control_flow_if(node)) self.assertTrue(is_control_flow_if(node.body[0].value))
def test_shape(self):
code = """
def foo(x):
batch_size = fluid.layers.shape(x)
if batch_size[0] > 16:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_shape_with_andOr(self):
code = """
def foo(x):
batch_size = fluid.layers.shape(x)
if x is not None and batch_size[0] > 16 or 2 > 1:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[1].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_paddle_api(self):
code = """
def foo(x):
if fluid.layers.shape(x)[0] > 16:
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_paddle_api_with_andOr(self):
code = """
def foo(x):
if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
self.assertTrue(is_control_flow_if(test_node, visitor))
def test_raise_error(self): def test_raise_error(self):
node = "a + b" node = "a + b"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册