From a8b6dd86b3f134ed072acfe9c45731089002b16e Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Mon, 28 Dec 2020 19:04:06 +0800 Subject: [PATCH] [Cherry-Pick 2.0][Dy2Stat] 1. Fix bug of for-range stmts. 2. Support that step value is negative in for-range stmts (#29519) (#29874) 1. Fix error in _build_cond_stmt of for-range stmts. 2. Support that step value is negative in for-range stmts 3. Fix code because of the diff between Py2 and Py3 --- .../fluid/dygraph/dygraph_to_static/utils.py | 34 +++++++++++--- .../unittests/dygraph_to_static/test_loop.py | 44 +++++++++++++++++++ 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 2c2611ff4f..6e44a26e05 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -1029,18 +1029,40 @@ class ForNodeVisitor(object): return step_node def _build_cond_stmt(self, step_node, compare_node): - return gast.Compare( - left=gast.BinOp( + if not isinstance(step_node, (gast.Constant, gast.UnaryOp)): + raise NotImplementedError( + "Dynamic-to-Static only supports the step value is a constant or negative constant in 'for-range' statements, " + "such as '2', '-3'. But received: '{}'. Please fix code to be compatible with Dynamic-to-Static." + .format(ast_to_source_code(step_node).strip())) + + if isinstance(step_node, gast.UnaryOp) or step_node.value < 0: + # eg: + # range(max, min, -2) + # -> + # i > min + return gast.Compare( left=gast.Name( id=self.iter_var_name if self.is_for_range_iter() else self.iter_idx_name, ctx=gast.Load(), annotation=None, type_comment=None), - op=gast.Add(), - right=step_node), - ops=[gast.LtE()], - comparators=[compare_node]) + ops=[gast.Gt()], + comparators=[compare_node]) + else: + # eg: + # range(min, max, 2) + # -> + # i < max + return gast.Compare( + left=gast.Name( + id=self.iter_var_name + if self.is_for_range_iter() else self.iter_idx_name, + ctx=gast.Load(), + annotation=None, + type_comment=None), + ops=[gast.Lt()], + comparators=[compare_node]) def _build_index_increase_node(self, step_node): return gast.AugAssign( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py index 2f107e53ab..b6aa73d376 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_loop.py @@ -94,6 +94,28 @@ def for_loop_dyfunc2(max_len): return ret +def for_loop_dyfunc3(max_len): + ret = fluid.layers.zeros(shape=[1], dtype='float32') + for i in range(1, 10, 2): + fluid.layers.increment(ret, value=2.0, in_place=True) + return ret + + +def for_loop_dyfunc4(max_len): + ret = fluid.layers.zeros(shape=[1], dtype='float32') + for i in range(10, 1, -2): + fluid.layers.increment(ret, value=2.0, in_place=True) + return ret + + +def for_loop_dyfunc_not_support(max_len): + ret = fluid.layers.zeros(shape=[1], dtype='float32') + a = -2 + for i in range(10, 1, a): + fluid.layers.increment(ret, value=2.0, in_place=True) + return ret + + def while_loop_bool_op(x): i = fluid.dygraph.to_variable(x) @@ -333,6 +355,16 @@ class TestTransformForLoop2(TestTransformForLoop): self.dyfunc = for_loop_dyfunc2 +class TestTransformForLoop3(TestTransformForLoop): + def _init_dyfunc(self): + self.dyfunc = for_loop_dyfunc3 + + +class TestTransformForLoop4(TestTransformForLoop): + def _init_dyfunc(self): + self.dyfunc = for_loop_dyfunc4 + + class TestClassVarInForLoop(TestTransformForLoop): def _init_dyfunc(self): self.dyfunc = for_loop_class_var @@ -343,5 +375,17 @@ class TestVarCreateInForLoop(TestTransformForLoop): self.dyfunc = var_create_in_for_loop +class TestErrorInForLoop(TestTransformForLoop): + def _init_dyfunc(self): + self.dyfunc = for_loop_dyfunc_not_support + + def test_ast_to_func(self): + with self.assertRaisesRegexp( + NotImplementedError, + "Dynamic-to-Static only supports the step value is a constant or negative constant " + ): + self._run_static() + + if __name__ == '__main__': unittest.main() -- GitLab