diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index c7c4d48126c12b838ac26d3da4ba2a7799407611..22c77dc854ed4ae5e137579968e9b5a341b1cdaf 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -85,6 +85,7 @@ class IfElseTransformer(gast.NodeTransformer): attribute = node.func if attribute.attr == 'numpy': node = attribute.value + self.generic_visit(node) return node def visit_IfExp(self, node): @@ -292,12 +293,12 @@ class NodeTestTransformer(gast.NodeTransformer): return self.visit(self.ast_root) def visit_Call(self, node): - # self.generic_visit(node) # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]` if isinstance(node.func, gast.Attribute): attribute = node.func if attribute.attr == 'numpy': node = attribute.value + self.generic_visit(node) return node def visit_UnaryOp(self, node): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 7f7b72d0c927ff09b92d29bd36219a31f630002e..1bd1fb22631bfa6f2d89f705098effda72023566 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -116,6 +116,30 @@ class TestDygraphIfElse6(TestDygraphIfElse): self.dyfunc = dyfunc_ifExp_with_while +def dyfunc_ifExp_with_while2(x): + y = [x] + + def add_fn(x): + x = x + 1 + return x + + def map_func(func, tensor_list): + return [func(x) for x in tensor_list] + + i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) + # It will be converted into `layers.cond` as followed. + # map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y) + # `i (Tensor) == 0` is supported in dygraph. + y = map_func(lambda x: x if i == 0 else add_fn(x), y) + return y[0] + + +class TestDygraphIfElse7(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = dyfunc_ifExp_with_while2 + + class TestDygraphIfElseWithAndOr(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32')