未验证 提交 a8b6dd86 编写于 作者: L liym27 提交者: GitHub

[Cherry-Pick 2.0][Dy2Stat] 1. Fix bug of for-range stmts. 2. Support that step...

[Cherry-Pick 2.0][Dy2Stat] 1. Fix bug of for-range stmts. 2. Support that step value is negative in for-range stmts (#29519) (#29874)

1. Fix error in _build_cond_stmt of for-range stmts.

2. Support that step value is negative in for-range stmts

3. Fix code because of the diff between Py2 and Py3
上级 63939597
......@@ -1029,17 +1029,39 @@ class ForNodeVisitor(object):
return step_node
def _build_cond_stmt(self, step_node, compare_node):
if not isinstance(step_node, (gast.Constant, gast.UnaryOp)):
raise NotImplementedError(
"Dynamic-to-Static only supports the step value is a constant or negative constant in 'for-range' statements, "
"such as '2', '-3'. But received: '{}'. Please fix code to be compatible with Dynamic-to-Static."
.format(ast_to_source_code(step_node).strip()))
if isinstance(step_node, gast.UnaryOp) or step_node.value < 0:
# eg:
# range(max, min, -2)
# ->
# i > min
return gast.Compare(
left=gast.BinOp(
left=gast.Name(
id=self.iter_var_name
if self.is_for_range_iter() else self.iter_idx_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
op=gast.Add(),
right=step_node),
ops=[gast.LtE()],
ops=[gast.Gt()],
comparators=[compare_node])
else:
# eg:
# range(min, max, 2)
# ->
# i < max
return gast.Compare(
left=gast.Name(
id=self.iter_var_name
if self.is_for_range_iter() else self.iter_idx_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
ops=[gast.Lt()],
comparators=[compare_node])
def _build_index_increase_node(self, step_node):
......
......@@ -94,6 +94,28 @@ def for_loop_dyfunc2(max_len):
return ret
def for_loop_dyfunc3(max_len):
ret = fluid.layers.zeros(shape=[1], dtype='float32')
for i in range(1, 10, 2):
fluid.layers.increment(ret, value=2.0, in_place=True)
return ret
def for_loop_dyfunc4(max_len):
ret = fluid.layers.zeros(shape=[1], dtype='float32')
for i in range(10, 1, -2):
fluid.layers.increment(ret, value=2.0, in_place=True)
return ret
def for_loop_dyfunc_not_support(max_len):
ret = fluid.layers.zeros(shape=[1], dtype='float32')
a = -2
for i in range(10, 1, a):
fluid.layers.increment(ret, value=2.0, in_place=True)
return ret
def while_loop_bool_op(x):
i = fluid.dygraph.to_variable(x)
......@@ -333,6 +355,16 @@ class TestTransformForLoop2(TestTransformForLoop):
self.dyfunc = for_loop_dyfunc2
class TestTransformForLoop3(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc3
class TestTransformForLoop4(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc4
class TestClassVarInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_class_var
......@@ -343,5 +375,17 @@ class TestVarCreateInForLoop(TestTransformForLoop):
self.dyfunc = var_create_in_for_loop
class TestErrorInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc_not_support
def test_ast_to_func(self):
with self.assertRaisesRegexp(
NotImplementedError,
"Dynamic-to-Static only supports the step value is a constant or negative constant "
):
self._run_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册