未验证 提交 0cad1152 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] 1. Fix bug of for-range stmts. 2. Support that step value is...

[Dy2Stat] 1. Fix bug of for-range stmts. 2. Support that step value is negative in for-range stmts (#29519)

1. Fix error in _build_cond_stmt of for-range stmts. 

2. Support that step value is negative in for-range stmts

3. Fix code because of the diff between Py2 and Py3
上级 831e9135
...@@ -1028,17 +1028,39 @@ class ForNodeVisitor(object): ...@@ -1028,17 +1028,39 @@ class ForNodeVisitor(object):
return step_node return step_node
def _build_cond_stmt(self, step_node, compare_node): def _build_cond_stmt(self, step_node, compare_node):
if not isinstance(step_node, (gast.Constant, gast.UnaryOp)):
raise NotImplementedError(
"Dynamic-to-Static only supports the step value is a constant or negative constant in 'for-range' statements, "
"such as '2', '-3'. But received: '{}'. Please fix code to be compatible with Dynamic-to-Static."
.format(ast_to_source_code(step_node).strip()))
if isinstance(step_node, gast.UnaryOp) or step_node.value < 0:
# eg:
# range(max, min, -2)
# ->
# i > min
return gast.Compare( return gast.Compare(
left=gast.BinOp(
left=gast.Name( left=gast.Name(
id=self.iter_var_name id=self.iter_var_name
if self.is_for_range_iter() else self.iter_idx_name, if self.is_for_range_iter() else self.iter_idx_name,
ctx=gast.Load(), ctx=gast.Load(),
annotation=None, annotation=None,
type_comment=None), type_comment=None),
op=gast.Add(), ops=[gast.Gt()],
right=step_node), comparators=[compare_node])
ops=[gast.LtE()], else:
# eg:
# range(min, max, 2)
# ->
# i < max
return gast.Compare(
left=gast.Name(
id=self.iter_var_name
if self.is_for_range_iter() else self.iter_idx_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
ops=[gast.Lt()],
comparators=[compare_node]) comparators=[compare_node])
def _build_index_increase_node(self, step_node): def _build_index_increase_node(self, step_node):
......
...@@ -94,6 +94,28 @@ def for_loop_dyfunc2(max_len): ...@@ -94,6 +94,28 @@ def for_loop_dyfunc2(max_len):
return ret return ret
def for_loop_dyfunc3(max_len):
ret = fluid.layers.zeros(shape=[1], dtype='float32')
for i in range(1, 10, 2):
fluid.layers.increment(ret, value=2.0, in_place=True)
return ret
def for_loop_dyfunc4(max_len):
ret = fluid.layers.zeros(shape=[1], dtype='float32')
for i in range(10, 1, -2):
fluid.layers.increment(ret, value=2.0, in_place=True)
return ret
def for_loop_dyfunc_not_support(max_len):
ret = fluid.layers.zeros(shape=[1], dtype='float32')
a = -2
for i in range(10, 1, a):
fluid.layers.increment(ret, value=2.0, in_place=True)
return ret
def while_loop_bool_op(x): def while_loop_bool_op(x):
i = fluid.dygraph.to_variable(x) i = fluid.dygraph.to_variable(x)
...@@ -333,6 +355,16 @@ class TestTransformForLoop2(TestTransformForLoop): ...@@ -333,6 +355,16 @@ class TestTransformForLoop2(TestTransformForLoop):
self.dyfunc = for_loop_dyfunc2 self.dyfunc = for_loop_dyfunc2
class TestTransformForLoop3(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc3
class TestTransformForLoop4(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc4
class TestClassVarInForLoop(TestTransformForLoop): class TestClassVarInForLoop(TestTransformForLoop):
def _init_dyfunc(self): def _init_dyfunc(self):
self.dyfunc = for_loop_class_var self.dyfunc = for_loop_class_var
...@@ -343,5 +375,17 @@ class TestVarCreateInForLoop(TestTransformForLoop): ...@@ -343,5 +375,17 @@ class TestVarCreateInForLoop(TestTransformForLoop):
self.dyfunc = var_create_in_for_loop self.dyfunc = var_create_in_for_loop
class TestErrorInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc_not_support
def test_ast_to_func(self):
with self.assertRaisesRegexp(
NotImplementedError,
"Dynamic-to-Static only supports the step value is a constant or negative constant "
):
self._run_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册