From 080d37a5016c063f60dabdb554a68e1e64eb7051 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 28 May 2020 15:44:41 +0800 Subject: [PATCH] fix bug in LogicalOpTransformer: Create logic node recursively (#24785) --- .../fluid/dygraph/dygraph_to_static/loop_transformer.py | 6 ++++-- .../fluid/tests/unittests/dygraph_to_static/test_loop.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index b9e6eff2f9b..cb1f2b3b1b0 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -109,8 +109,10 @@ class LogicalOpTransformer(gast.NodeTransformer): len(nodes)) if len(nodes) > 2: # Creates logic_and/logic_or node recursively. - pre_assign_node = self._create_bool_op_node(nodes[:2], api_type) - nodes = [pre_assign_node] + nodes[2:] + pre_logic_node = self._create_bool_op_node(nodes[:2], api_type) + post_logic_node = self._create_bool_op_node(nodes[2:], api_type) + nodes = [pre_logic_node] + [post_logic_node] + args = [ast_to_source_code(child) for child in nodes] new_node_str = "fluid.layers.logical_{}(x={}, y={})".format( api_type, args[0], args[1]) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 08b1336152c..00a7cc4f757 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -85,7 +85,7 @@ def while_loop_bool_op(x): # Use `to_variable` so that static analysis can analyze the type of X is Tensor x = fluid.dygraph.to_variable( x) # TODO(liym27): Delete it if the type of parameter x can be resolved - while (x >= 0 and x < 10) or x <= -1 or x < -3 or (x < -7 or x < -5): + while x <= -1 or x < -3 or (x < -7 or x < -5) or (x >= 0 and x < 10): i = i + x x = x + 1 return i -- GitLab