未验证 提交 60aaa715 编写于 作者: S songyouwei 提交者: GitHub

dygraph support while_loop op (#22714)

* dygraph support while_loop op
test=develop

* refine assign
test=develop
上级 f97f3f93
...@@ -18,7 +18,7 @@ from ..wrapped_decorator import signature_safe_contextmanager ...@@ -18,7 +18,7 @@ from ..wrapped_decorator import signature_safe_contextmanager
from .layer_function_generator import autodoc, templatedoc from .layer_function_generator import autodoc, templatedoc
from .tensor import assign, cast, fill_constant from .tensor import assign, cast, fill_constant
from .. import core from .. import core
from ..framework import Program, Variable, Operator from ..framework import Program, Variable, Operator, in_dygraph_mode
from ..layer_helper import LayerHelper, unique_name from ..layer_helper import LayerHelper, unique_name
from .nn import logical_and, logical_not, logical_or from .nn import logical_and, logical_not, logical_or
from .utils import assert_same_structure, map_structure from .utils import assert_same_structure, map_structure
...@@ -999,6 +999,20 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): ...@@ -999,6 +999,20 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
"the shape of the variable returned by cond should be []," "the shape of the variable returned by cond should be [],"
"but given shape as {0}.".format(list(pre_cond.shape))) "but given shape as {0}.".format(list(pre_cond.shape)))
if in_dygraph_mode():
now_cond = pre_cond.numpy()[0]
while (now_cond):
output_vars = body(*loop_vars)
if not isinstance(output_vars, (list, tuple)):
output_vars = [output_vars]
if len(output_vars) != len(loop_vars):
raise ValueError(
"body in while_loop should return the same arity "
"(length and structure) and types as loop_vars")
now_cond = cond(*output_vars).numpy()[0]
loop_vars = output_vars
return loop_vars
while_loop_block = While(pre_cond, is_test, name) while_loop_block = While(pre_cond, is_test, name)
with while_loop_block.block(): with while_loop_block.block():
output_vars = body(*loop_vars) output_vars = body(*loop_vars)
......
...@@ -1393,6 +1393,41 @@ class TestLayer(LayerTest): ...@@ -1393,6 +1393,41 @@ class TestLayer(LayerTest):
self.assertTrue(np.allclose(static_ret, dy_ret_rlt)) self.assertTrue(np.allclose(static_ret, dy_ret_rlt))
def test_while_loop(self):
with self.static_graph():
i = layers.fill_constant(shape=[1], dtype='int64', value=0)
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
def cond(i):
return layers.less_than(i, ten)
def body(i):
return i + 1
out = layers.while_loop(cond, body, [i])
static_ret = self.get_static_graph_result(feed={}, fetch_list=out)
with self.dynamic_graph():
i = layers.fill_constant(shape=[1], dtype='int64', value=0)
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
def cond(i):
return layers.less_than(i, ten)
def body(i):
return i + 1
dy_ret = layers.while_loop(cond, body, [i])
with self.assertRaises(ValueError):
j = layers.fill_constant(shape=[1], dtype='int64', value=0)
def body2(i):
return i + 1, i + 2
layers.while_loop(cond, body2, [j])
self.assertTrue(np.array_equal(static_ret[0], dy_ret[0].numpy()))
def test_compare(self): def test_compare(self):
value_a = np.arange(3) value_a = np.arange(3)
value_b = np.arange(3) value_b = np.arange(3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册