From 4955c97ee893ee3ff57cca48058bf61ef91d2726 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Sat, 4 Apr 2020 10:03:41 +0800 Subject: [PATCH] Add unitTest for `Tensor==constant` for ifElse in dygraph2static (#23407) * Add unitTest for `Tensor==constant` for ifElse in dygraph2static test=develop --- .../dygraph_to_static/ifelse_transformer.py | 3 ++- .../dygraph_to_static/test_ifelse.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index c7c4d48126..22c77dc854 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -85,6 +85,7 @@ class IfElseTransformer(gast.NodeTransformer): attribute = node.func if attribute.attr == 'numpy': node = attribute.value + self.generic_visit(node) return node def visit_IfExp(self, node): @@ -292,12 +293,12 @@ class NodeTestTransformer(gast.NodeTransformer): return self.visit(self.ast_root) def visit_Call(self, node): - # self.generic_visit(node) # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]` if isinstance(node.func, gast.Attribute): attribute = node.func if attribute.attr == 'numpy': node = attribute.value + self.generic_visit(node) return node def visit_UnaryOp(self, node): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py index 7f7b72d0c9..1bd1fb2263 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py @@ -116,6 +116,30 @@ class TestDygraphIfElse6(TestDygraphIfElse): self.dyfunc = dyfunc_ifExp_with_while +def dyfunc_ifExp_with_while2(x): + y = [x] + + def add_fn(x): + x = x + 1 + return x + + def map_func(func, tensor_list): + return [func(x) for x in tensor_list] + + i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) + # It will be converted into `layers.cond` as followed. + # map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y) + # `i (Tensor) == 0` is supported in dygraph. + y = map_func(lambda x: x if i == 0 else add_fn(x), y) + return y[0] + + +class TestDygraphIfElse7(TestDygraphIfElse): + def setUp(self): + self.x = np.random.random([10, 16]).astype('float32') + self.dyfunc = dyfunc_ifExp_with_while2 + + class TestDygraphIfElseWithAndOr(TestDygraphIfElse): def setUp(self): self.x = np.random.random([10, 16]).astype('float32') -- GitLab