未验证 提交 100914dd 编写于 作者: A Aurelius84 提交者: GitHub

Fix bug with `if Tensor` in is_control_flow (#24433)

* fix bug with `if Tensor` in is_control_flow test=develop

* remove continue test=develop
上级 1a0d26a4
......@@ -33,6 +33,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
......@@ -145,15 +146,24 @@ class IfElseTransformer(gast.NodeTransformer):
class NodeTestTransformer(gast.NodeTransformer):
def __init__(self, ast_node, compare_nodes_with_tensor=None):
def __init__(self,
ast_node,
compare_nodes_with_tensor=None,
node_to_wrapper_map=None):
if compare_nodes_with_tensor is None:
compare_nodes_with_tensor = set()
self.ast_root = ast_node
self._compare_nodes_with_tensor = compare_nodes_with_tensor
if node_to_wrapper_map is None:
node_to_wrapper_map = {}
self.node_to_wrapper_map = node_to_wrapper_map
self._new_assign_nodes = []
def transform(self):
return self.visit(self.ast_root)
node = self.ast_root
if not is_candidate_node(node):
return self._create_cast_node(node)
return self.visit(node)
def visit_Call(self, node):
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
......@@ -182,8 +192,11 @@ class NodeTestTransformer(gast.NodeTransformer):
def visit_BoolOp(self, node):
for i, child in enumerate(node.values):
if not is_candidate_node(child):
node.values[i] = self._create_bool_node(child)
continue
node_wrapper = self.node_to_wrapper_map.get(child, None)
if node_wrapper and node_wrapper.node_var_type & NodeVarType.TENSOR_TYPES:
node.values[i] = self._create_cast_node(child)
else:
node.values[i] = self._create_bool_node(child)
self.generic_visit(node)
new_node = self._create_logic_node(node)
return new_node
......@@ -195,10 +208,19 @@ class NodeTestTransformer(gast.NodeTransformer):
self.generic_visit(node)
return node
def _create_cast_node(self, node):
template = "fluid.layers.cast(x={}, dtype='bool')"
return self._create_node_with_api_template(node, template)
def _create_bool_node(self, node):
template = "fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool({}))"
return self._create_node_with_api_template(node, template)
def _create_node_with_api_template(self, node, template):
node_code = ast_to_source_code(node)
new_node_str = "fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool({}))".format(
node_code)
new_node_str = template.format(node_code)
# gast.parse return Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value
bool_tensor_name = unique_name.generate(PLAIN_TENSOR_PREFIX)
......@@ -258,7 +280,8 @@ class IfConditionVisitor(object):
self.static_analysis_visitor = static_analysis_visitor
self.visitor = IsControlFlowVisitor(node, static_analysis_visitor,
node_var_type_map)
self.transformer = NodeTestTransformer(node)
self.transformer = NodeTestTransformer(
node, node_to_wrapper_map=self.visitor.node_to_wrapper_map)
self.compare_nodes_with_tensor = set()
self._is_control_flow_if = False
......
......@@ -54,6 +54,9 @@ class NodeVarType(object):
# We use this enum value to denote the type return by a Paddle API
PADDLE_RETURN_TYPES = 304
# If node.node_var_type in TENSOR_TYPES, it can be considered as tensor-dependent.
TENSOR_TYPES = {TENSOR, PADDLE_RETURN_TYPES}
@staticmethod
def binary_op_output_type(in_type1, in_type2):
if in_type1 == in_type2:
......
......@@ -485,15 +485,14 @@ class IsControlFlowVisitor(gast.NodeVisitor):
def transform(self):
node = self.ast_root
if is_candidate_node(node):
if isinstance(node, gast.If):
self._visit_If(node)
if isinstance(node, gast.For):
self._visit_For(node)
elif isinstance(node, gast.While):
self._visit_While(node)
else:
self.visit(node)
if isinstance(node, gast.If):
self._visit_If(node)
elif isinstance(node, gast.For):
self._visit_For(node)
elif isinstance(node, gast.While):
self._visit_While(node)
else:
self.visit(node)
return self.is_control_flow_num > 0
def _visit_If(self, node):
......@@ -548,14 +547,10 @@ class IsControlFlowVisitor(gast.NodeVisitor):
def visit_BoolOp(self, node):
for i, child in enumerate(node.values):
if is_candidate_node(child):
self.visit(child)
self.visit(child)
return node
def visit_Compare(self, node):
# Ignores child node with `if x` or `if x is None`
# TODO(Aurelius84): `if tensor` will be supported in dygraph
# and should be considered as is_control_flow.
pre_control_flow_num = self.is_control_flow_num
if not compare_with_none(node):
self.generic_visit(node)
......@@ -598,19 +593,16 @@ class IsControlFlowVisitor(gast.NodeVisitor):
def _is_node_with_tensor(self, node, name_id):
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
# Look up the node_var_type_map by name_id.
if self.node_var_type_map:
if name_id and isinstance(name_id, six.string_types):
var_type = self.node_var_type_map.get(name_id, None)
if var_type and var_type & tensor_types:
if var_type and var_type & NodeVarType.TENSOR_TYPES:
return True
# if not found, look up the node_to_wrapper_map by node.
node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
wrapper_node = node_to_wrapper_map.get(node, None)
wrapper_node = self.node_to_wrapper_map.get(node, None)
if wrapper_node is not None:
if wrapper_node.node_var_type & tensor_types:
if wrapper_node.node_var_type & NodeVarType.TENSOR_TYPES:
return True
return False
......
......@@ -248,3 +248,31 @@ def if_with_class_var(x, y=None):
else:
x = x - foo.b
return x
def if_tensor_case(x):
x = fluid.dygraph.to_variable(x)
mean = fluid.layers.mean(x)
# It is equivalent to `if mean != 0`
if mean:
for i in range(0, 10):
if i > 5:
x += 1
break
x += 1
else:
for i in range(0, 37):
x += 1
break
x += i
# join `and`/`or`
if fluid.layers.mean(x) + 1 and mean > 1 and x is not None or 2 > 1:
x -= 1
# `not` statement
if not (x[0][0] and (mean * x)[0][0]):
x += 1
return x
......@@ -173,6 +173,12 @@ class TestDygraphIfElseWithClassVar(TestDygraphIfElse):
self.dyfunc = if_with_class_var
class TestDygraphIfTensor(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = if_tensor_case
class TestDygraphIfElseNet(unittest.TestCase):
"""
TestCase for the transformation from control flow `if/else`
......
......@@ -155,7 +155,18 @@ class TestIsControlFlowIf(unittest.TestCase):
self.check_false_case("fluid.layers.sum(x).numpy() != None")
def test_is_None4(self):
self.check_false_case("fluid.layers.sum(x) and 2>1")
node = gast.parse("fluid.layers.sum(x) and 2>1")
node_test = node.body[0].value
if_visitor = IfConditionVisitor(node_test)
self.assertTrue(if_visitor.is_control_flow())
# Transformation result:
# bool_tensor_0 = fluid.layers.cast(x=fluid.layers.sum(x), dtype='bool')
# bool_tensor_1 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(2 > 1))
# logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=bool_tensor_1)
new_node, assign_nodes = if_visitor.transform()
self.assertTrue(len(assign_nodes) == 3)
def test_if(self):
node = gast.parse("x.numpy()[1] > 1")
......@@ -253,34 +264,38 @@ class TestIsControlFlowIf(unittest.TestCase):
self.assertTrue(len(assign_nodes) == 0)
def test_paddle_api_with_andOr(self):
code = """
code_or = """
def foo(x):
if 2 > 1 and fluid.layers.shape(x)[0] > 16 or x is not None :
x = x + 1
return x
"""
code = """
code_and = """
def foo(x):
if 2 > 1 and fluid.layers.shape(x)[0] > 16 and x is not None :
x = x + 1
return x
"""
code = textwrap.dedent(code)
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
if_visitor = IfConditionVisitor(test_node, static_analysis_visitor)
self.assertTrue(if_visitor.is_control_flow())
new_node, assign_nodes = if_visitor.transform()
# Tranformation result:
# bool_tensor_0 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(2 > 1))
# bool_tensor_1 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(x is not None))
# logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=fluid.layers.shape(x)[0] > 16)
# logic_and_1 = fluid.layers.logical_and(x=logic_and_0, y=bool_tensor_1)
self.assertTrue(isinstance(new_node, gast.Name))
self.assertTrue(len(assign_nodes) == 4)
for code in [code_or, code_and]:
code = textwrap.dedent(code)
node = gast.parse(code)
static_analysis_visitor = StaticAnalysisVisitor(node)
test_node = node.body[0].body[0].test
if_visitor = IfConditionVisitor(test_node, static_analysis_visitor)
self.assertTrue(if_visitor.is_control_flow())
new_node, assign_nodes = if_visitor.transform()
# Transformation result:
# bool_tensor_0 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(2 > 1))
# bool_tensor_1 = fluid.layers.fill_constant(shape=[1], dtype='bool', value=bool(x is not None))
# logic_and_0 = fluid.layers.logical_and(x=bool_tensor_0, y=fluid.layers.shape(x)[0] > 16)
# logic_and_1 = fluid.layers.logical_and(x=logic_and_0, y=bool_tensor_1) for code_and
# logic_or_0= fluid.layers.logical_or(x=logic_and_0, y=bool_tensor_1) for code_and
self.assertTrue(isinstance(new_node, gast.Name))
self.assertTrue(len(assign_nodes) == 4)
def test_with_node_var_type_map(self):
node = gast.parse("x > 1")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册