From 16e74f118572ac04fafac12c5f9977a12d412124 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Sun, 29 Mar 2020 20:59:28 +0800 Subject: [PATCH] fix is_controw_flow bug with `if Tensor.numpy()` (#23251) --- .../dygraph_to_static/ifelse_transformer.py | 109 +++++++++++------- .../dygraph_to_static/ifelse_simple_func.py | 3 +- .../unittests/dygraph_to_static/test_dict.py | 8 +- .../dygraph_to_static/test_ifelse.py | 32 +++++ .../dygraph_to_static/test_ifelse_basic.py | 10 +- 5 files changed, 119 insertions(+), 43 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index cce8067805..c7c4d48126 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -62,7 +62,6 @@ class IfElseTransformer(gast.NodeTransformer): self.after_visit(self.root) def visit_If(self, node): - assert isinstance(node, gast.If) if_condition_visitor = IfConditionVisitor(node.test, self.static_analysis_visitor) need_transform = if_condition_visitor.is_control_flow() @@ -88,6 +87,22 @@ class IfElseTransformer(gast.NodeTransformer): node = attribute.value return node + def visit_IfExp(self, node): + """ + Transformation with `true_fn(x) if Tensor > 0 else false_fn(x)` + """ + if_condition_visitor = IfConditionVisitor(node.test, + self.static_analysis_visitor) + need_transform = if_condition_visitor.is_control_flow() + self.generic_visit(node) + if need_transform: + pred_node, new_assign_nodes = if_condition_visitor.transform() + new_node = create_cond_node(None, pred_node, node.body, node.orelse, + True) + return new_node + else: + return node + def after_visit(self, node): """ This function will add some postprocessing operations with node. @@ -130,7 +145,12 @@ def is_candidate_node(node): """ Nodes with specified type will be dependent on tensor. """ - return isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp)) + is_compare_node = isinstance(node, + (gast.Compare, gast.BoolOp, gast.UnaryOp)) + # TODO(Aurelius84): `.numpy()` may be an customized function, + # and should consider a more elegant way to solve this problem. + has_numpy_attr = ".numpy()" in ast_to_source_code(node) + return is_compare_node or has_numpy_attr def compare_with_none(node): @@ -223,6 +243,7 @@ class IsControlFlowVisitor(gast.NodeVisitor): self.is_control_flow_num += 1 def visit_Call(self, node): + self._visit_Call(node) if is_paddle_api(node): self.is_control_flow_num += 1 return node @@ -238,8 +259,7 @@ class IsControlFlowVisitor(gast.NodeVisitor): return node def _is_node_with_tensor(self, node, name_id): - tensor_types = set( - [NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES]) + tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES} # Look up the node_var_type_map by name_id. if self.node_var_type_map: if name_id and isinstance(name_id, six.string_types): @@ -261,7 +281,9 @@ class IsControlFlowVisitor(gast.NodeVisitor): class NodeTestTransformer(gast.NodeTransformer): - def __init__(self, ast_node, compare_nodes_with_tensor=set()): + def __init__(self, ast_node, compare_nodes_with_tensor=None): + if compare_nodes_with_tensor is None: + compare_nodes_with_tensor = set() self.ast_root = ast_node self._compare_nodes_with_tensor = compare_nodes_with_tensor self._new_assign_nodes = [] @@ -269,6 +291,15 @@ class NodeTestTransformer(gast.NodeTransformer): def transform(self): return self.visit(self.ast_root) + def visit_Call(self, node): + # self.generic_visit(node) + # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]` + if isinstance(node.func, gast.Attribute): + attribute = node.func + if attribute.attr == 'numpy': + node = attribute.value + return node + def visit_UnaryOp(self, node): self.generic_visit(node) if isinstance(node.op, gast.Not): @@ -297,6 +328,7 @@ class NodeTestTransformer(gast.NodeTransformer): if compare_with_none( node) or node not in self._compare_nodes_with_tensor: return self._create_bool_node(node) + self.generic_visit(node) return node def _create_bool_node(self, node): @@ -656,46 +688,43 @@ def transform_if_else(node, root): return true_func_node, false_func_node, return_name_ids -def create_cond_node(return_name_ids, pred, true_func, false_func): +def create_cond_node(return_name_ids, + pred, + true_func, + false_func, + is_if_expr=False): """ Create `fluid.layers.cond(pred, true_fn, false_fn)` to replace original `python if/else` statement. """ + + def create_lambda_node(func_or_expr_node, is_if_expr=False): + body = func_or_expr_node + if not is_if_expr: + body = gast.Call( + func=gast.Name( + id=func_or_expr_node.name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + args=[func_or_expr_node.args], + keywords=[]) + + lambda_node = gast.Lambda( + args=gast.arguments( + args=[], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=None, + kwarg=None, + defaults=[]), + body=body) + return lambda_node + cond_api = gast.parse('fluid.layers.cond').body[0].value - true_func_lambda = gast.Lambda( - args=gast.arguments( - args=[], - posonlyargs=[], - vararg=None, - kwonlyargs=[], - kw_defaults=None, - kwarg=None, - defaults=[]), - body=gast.Call( - func=gast.Name( - id=true_func.name, - ctx=gast.Load(), - annotation=None, - type_comment=None), - args=[true_func.args], - keywords=[])) - false_func_lambda = gast.Lambda( - args=gast.arguments( - args=[], - posonlyargs=[], - vararg=None, - kwonlyargs=[], - kw_defaults=None, - kwarg=None, - defaults=[]), - body=gast.Call( - func=gast.Name( - id=false_func.name, - ctx=gast.Load(), - annotation=None, - type_comment=None), - args=[false_func.args], - keywords=[])) + true_func_lambda = create_lambda_node(true_func, is_if_expr) + false_func_lambda = create_lambda_node(false_func, is_if_expr) cond_layer = gast.Call( func=cond_api, args=[pred, true_func_lambda, false_func_lambda], diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index e8609d9080..f46c724491 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -58,7 +58,8 @@ def nested_if_else(x_v): bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1) if x_v.shape[0] != batch_size: batch_size = x_v.shape[0] - if fluid.layers.mean(x_v).numpy()[0] < 0: + # if tensor.shape is [1], now support to compare with numpy. + if fluid.layers.mean(x_v).numpy() < 0: y = x_v + bias w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10) if y.numpy()[0] < 10: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py index ce43052635..959fea7e1a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py @@ -89,7 +89,13 @@ class MainNetWithDict(fluid.dygraph.Layer): dtype='float32', value=0), } - max_len = input.shape[0] if input.shape[0] != max_len else max_len + # TODO(Aurelius84): The following code will be converted into: + # max_len = layers.cond(layers.shape(input)[0] != max_len, + # lambda: layers.shape(input)[0], lambda: max_len) + # But max_len should be wrapped into tensor, which is not supported. + + # Comment out this line of code for now. + # max_len = input.shape[0] if input.shape[0] != max_len else max_len out = input for i in range(max_len): out = self.sub_net(out, cache) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index ed0b2fd021..de89cb8ed2 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -84,6 +84,38 @@ class TestDygraphIfElse5(TestDygraphIfElse): self.dyfunc = nested_if_else_3 +def dyfunc_ifExp_with_while(x): + y = [x] + + def add_fn(x): + x = x + 1 + return x + + def cond(i, ten, y): + return i < ten + + def map_func(func, tensor_list): + return [func(x) for x in tensor_list] + + def body(i, ten, y): + # It will be converted into `layers.cond` as followed. + # map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y) + y = map_func(lambda x: x if i == 0 else add_fn(x), y) + i += 1 + return [i, ten, y] + + i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) + ten = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10) + i, ten, y = fluid.layers.while_loop(cond, body, [i, ten, y]) + return y[0] + + +class TestDygraphIfElse6(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = dyfunc_ifExp_with_while + + class TestDygraphIfElseWithAndOr(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py index 6cc0634996..e40fa00ddf 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse_basic.py @@ -135,7 +135,15 @@ class TestIsControlFlowIf(unittest.TestCase): self.check_false_case("a+b") def test_expr2(self): - self.check_false_case("a + x.numpy()[1]") + # x is a Tensor. + node = gast.parse("a + x.numpy()") + node_test = node.body[0].value + + if_visitor = IfConditionVisitor(node_test) + self.assertTrue(if_visitor.is_control_flow()) + # No transformation will be applied. + new_node, assign_nodes = if_visitor.transform() + self.assertTrue(len(assign_nodes) == 0) def test_is_None(self): self.check_false_case("x is None") -- GitLab