diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 7e6f3fc34ae38ab6a5a291dbb69d6061ea7812c8..a2e3841291b68ddc33344c87d57f58da5623d15f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -52,8 +52,9 @@ class IfElseTransformer(gast.NodeTransformer): def visit_If(self, node): assert isinstance(node, gast.If) + need_transform = is_control_flow_if(node.test) self.generic_visit(node) - if is_control_flow_if(node.test): + if need_transform: pred_node = node.test true_func_node, false_func_node, return_name_ids = transform_if_else( node, self.root) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py index d472f64e6f20d1c5094f9c4f4103b04819859520..94f891c5f6e6f25ea9219fd5cbd83bf4b7a89582 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_utils.py @@ -27,8 +27,53 @@ from collections import defaultdict from paddle.fluid import unique_name -TRUE_FUNC_PRFIX = 'true_fn' -FALSE_FUNC_PRFIX = 'false_fn' +TRUE_FUNC_PREFIX = 'true_fn' +FALSE_FUNC_PREFIX = 'false_fn' + + +class IsControlFlowIfVisitor(gast.NodeTransformer): + """ + Judge whether the node.test from Dygraph code dependent on paddle Tensor. + If does, it should satisfy: + 1. must involve at least one var whose type is Tensor. + 2. the Tensor var should call `.numpy()[]` interface or Tensor.shape is [1]. + 3. involve Tensor.shape[i] and the shape[i] is unknown in compile time. + The following examples should not be considered as control_flow_if: + 1. `if Tensor_var` or `if Tensor_var is None` + 2. if Tensor.shape[i] is determined with fixed value (not -1 or None) + + Note: pred in ConditionalBlock require variable, which means all vars should be Tensor + or transformed into Tensor, like fill_constant(shape=[1], dtype='int32', value=Tensor.shape[i]). + + TODO: 1. need to deal with `tensor.shape[i]` which need to eval the data of shape[i], + because reshape_op may be called before this statement. + """ + + def __init__(self, node): + self.node = node + self.is_control_flow = False + + def ast_visit(self): + self.visit(self.node) + return self.is_control_flow + + def visit_Compare(self, node): + for child in gast.walk(node): + if isinstance(child, gast.Subscript): + self._visit_Subscript(child) + return node + + def _visit_Subscript(self, node): + self.generic_visit(node) + if isinstance(node.value, gast.Call): + self._visit_Call(node.value) + return node + + def _visit_Call(self, node): + assert isinstance(node, gast.Call) + if isinstance(node.func, gast.Attribute): + attr_node = node.func + self.is_control_flow = (attr_node.attr == 'numpy') def is_control_flow_if(node): @@ -36,7 +81,10 @@ def is_control_flow_if(node): Determine whether the node is a plain python `if statement` or control flow in Paddle. """ - return True + assert isinstance( + node, gast.AST + ), "Type of input node should be gast.AST, but received %s." % type(node) + return IsControlFlowIfVisitor(node).ast_visit() def get_name_ids(nodes, not_name_set=None, node_black_list=None): @@ -228,12 +276,12 @@ def transform_if_else(node, root): true_func_node = create_funcDef_node( node.body, - name=unique_name.generate(TRUE_FUNC_PRFIX), + name=unique_name.generate(TRUE_FUNC_PREFIX), input_args=parse_cond_args(if_name_ids, modified_name_ids), return_name_ids=return_name_ids) false_func_node = create_funcDef_node( node.orelse, - name=unique_name.generate(FALSE_FUNC_PRFIX), + name=unique_name.generate(FALSE_FUNC_PREFIX), input_args=parse_cond_args(else_name_ids, modified_name_ids), return_name_ids=return_name_ids) @@ -309,7 +357,7 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True): f = tempfile.NamedTemporaryFile( mode='w', suffix='.py', delete=False, encoding='utf-8') - # TODO(Aurelius84): more elegent way to transform ast into callable object + # TODO(Aurelius84): more elegant way to transform ast into callable object import_str = "import paddle\n" \ "import paddle.fluid as fluid\n" \ "import paddle.fluid.layers as layers\n" diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index 8fd3ef7266c91b24801b75fbe1309c521eb15a17..0ca36122fbccc3cf86e699fcf740739a1b310164 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -187,7 +187,7 @@ class AstVarEnv(object): def exit_scope(self): assert self.cur_scope.parent_scope is not None, "Call exit_scope in "\ - "AstVarEnv when current scope doens't have parent scope." + "AstVarEnv when current scope doesn't have parent scope." self.cur_scope = self.cur_scope.parent_scope return self.cur_scope diff --git a/python/paddle/fluid/tests/unittests/test_ast_util.py b/python/paddle/fluid/tests/unittests/test_ast_util.py index 9276d663b2f2382a6a23662e09bc978edc6836bc..27f4f2e5cef5d56deb32ae91f2d38fe37caac845 100644 --- a/python/paddle/fluid/tests/unittests/test_ast_util.py +++ b/python/paddle/fluid/tests/unittests/test_ast_util.py @@ -20,7 +20,9 @@ import gast import inspect import numpy as np import paddle.fluid as fluid -from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func +from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func, is_control_flow_if + +from test_dygraph_to_static_basic import dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else class TestGetNameIds(unittest.TestCase): @@ -92,22 +94,46 @@ class TestGetNameIds3(TestGetNameIds): } -def dyfunc_with_if_else(x_v): - if fluid.layers.mean(x_v).numpy()[0] > 5: - x_v = x_v - 1 - else: - x_v = x_v + 1 - return x_v +class TestIsControlFlowIf(unittest.TestCase): + def test_expr(self): + # node is not ast.Compare + node = gast.parse("a + b") + self.assertFalse(is_control_flow_if(node)) + + def test_expr2(self): + node = gast.parse("a + x.numpy()[1]") + self.assertFalse(is_control_flow_if(node)) + + def test_is_None(self): + node = gast.parse("x is None") + self.assertFalse(is_control_flow_if(node)) + + def test_is_None2(self): + node = gast.parse("fluid.layers.sum(x) is None") + self.assertFalse(is_control_flow_if(node)) + def test_is_None3(self): + node = gast.parse("fluid.layers.sum(x).numpy() != None") + self.assertFalse(is_control_flow_if(node)) -def dyfunc_with_if_else2(x): - i, j = 0, 0 - if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]: - y = fluid.layers.relu(x) - else: - x_pow = fluid.layers.pow(x, 2) - y = fluid.layers.tanh(x_pow) - return y + def test_if(self): + node = gast.parse("x.numpy()[1] > 1") + self.assertTrue(is_control_flow_if(node)) + + def test_if_with_and(self): + node = gast.parse("x is not None and 1 < x.numpy()[1]") + self.assertTrue(is_control_flow_if(node)) + + def test_if_with_or(self): + node = gast.parse("1 < fluid.layers.sum(x).numpy()[2] or x+y < 0") + self.assertTrue(is_control_flow_if(node)) + + def test_raise_error(self): + node = "a + b" + with self.assertRaises(Exception) as e: + self.assertRaises(TypeError, is_control_flow_if(node)) + self.assertTrue( + "Type of input node should be gast.AST" in str(e.exception)) class TestAST2Func(unittest.TestCase): @@ -130,13 +156,14 @@ class TestAST2Func(unittest.TestCase): self.assertEqual(func(x, y), self._ast2func(func)(x, y)) def test_ast2func_dygraph(self): - func = dyfunc_with_if_else + funcs = [dyfunc_with_if_else, dyfunc_with_if_else, nested_if_else] x_data = np.random.random([10, 16]).astype('float32') - with fluid.dygraph.guard(): - x_v = fluid.dygraph.to_variable(x_data) - true_ret = func(x_v).numpy() - test_ret = self._ast2func(func)(x_v).numpy() - self.assertTrue((true_ret == test_ret).all()) + for func in funcs: + with fluid.dygraph.guard(): + x_v = fluid.dygraph.to_variable(x_data) + true_ret = func(x_v).numpy() + test_ret = self._ast2func(func)(x_v).numpy() + self.assertTrue((true_ret == test_ret).all()) def test_ast2func_static(self): def func(x): diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py index 1c5e298552d054fb644aaab370d5f16dd5e861de..a74cd10be7d3da2e2952e203d0356066896d701f 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_to_static_basic.py @@ -22,18 +22,30 @@ from paddle.fluid.dygraph.jit import dygraph_to_static_output np.random.seed(1) +if fluid.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) +else: + place = fluid.CPUPlace() -def dyfunc_with_if_else(x_v): + +def dyfunc_with_if_else(x_v, label=None): if fluid.layers.mean(x_v).numpy()[0] > 5: x_v = x_v - 1 else: x_v = x_v + 1 + # plain if in python + if label is not None: + loss = fluid.layers.cross_entropy(x_v, label) + return loss return x_v -def dyfunc_with_if_else2(x): - i, j = 0, 0 - if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]: +def dyfunc_with_if_else2(x, col=100): + row = 0 + # plain if in python + if abs(col) > x.shape[-1]: + col = -1 + if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[row][col]: y = fluid.layers.relu(x) else: x_pow = fluid.layers.pow(x, 2) @@ -42,9 +54,12 @@ def dyfunc_with_if_else2(x): def nested_if_else(x_v): - batch_size = x_v.shape[0] + batch_size = 16 feat_size = x_v.shape[-1] bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1) + # plain if in python + if x_v.shape[0] != batch_size: + batch_size = x_v.shape[0] if fluid.layers.mean(x_v).numpy()[0] < 0: y = x_v + bias w = fluid.layers.fill_constant([feat_size], dtype='float32', value=10) @@ -78,12 +93,12 @@ class TestDygraphIfElse(unittest.TestCase): x_v = fluid.layers.assign(self.x) # Transform into static graph out = dygraph_to_static_output(self.dyfunc)(x_v) - exe = fluid.Executor(fluid.CPUPlace()) + exe = fluid.Executor(place) ret = exe.run(main_program, fetch_list=out) return ret def _run_dygraph(self): - with fluid.dygraph.guard(): + with fluid.dygraph.guard(place): x_v = fluid.dygraph.to_variable(self.x) ret = self.dyfunc(x_v) return ret.numpy()