未验证 提交 be83e1ee 编写于 作者: T TeFeng Chen 提交者: GitHub

[Zero-Dim] support 0D Tensor for while_loop op (#49780)

* support 0D Tensor for while_loop op

* update

* clean unit test

* revert test_while_loop_op.py

* test again

* remove invalid check

* fix error

* change fluid to paddle.static

* fix paddle.full

* merge forward and backward test

* simplify code

* add precision check

* fix condition var check

* add dygraph test
上级 7f6ccdf3
......@@ -100,12 +100,12 @@ class WhileOp : public framework::OperatorBase {
auto &cond = scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
cond.dims(),
phi::make_ddim({1}),
cond.numel(),
1,
platform::errors::InvalidArgument(
"The shape of Input(Condition) of WhileOp must be 1. But now "
"the Condition's shape is ",
cond.dims().to_str(),
"The numel of Input(Condition) of WhileOp must be 1. But now "
"the Condition's numel is ",
cond.numel(),
".\n"));
#ifdef PADDLE_WITH_MKLDNN
......
......@@ -1590,6 +1590,32 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out2.grad.shape, [])
self.assertEqual(x.grad.shape, [])
def test_while_loop(self):
def cond(i, x):
return paddle.less_than(i, eleven)
def body(i, x):
x = x + i
i = i + 1
return [i, x]
i = paddle.full([], 1.0, dtype='float32')
i.stop_gradient = False
eleven = paddle.full([], 11, dtype='float32')
x = paddle.full([], 0.0, dtype='float32')
x.stop_gradient = False
out_i, out_x = paddle.static.nn.while_loop(cond, body, [i, x])
out_x.backward()
self.assertEqual(out_i.shape, [])
np.testing.assert_allclose(out_i, np.array(11))
self.assertEqual(out_x.shape, [])
np.testing.assert_allclose(out_x, np.array(55))
self.assertEqual(i.grad.shape, [])
np.testing.assert_allclose(i.grad, np.array(10))
self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, np.array(1.0))
class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
......@@ -2592,6 +2618,43 @@ class TestSundryAPIStatic(unittest.TestCase):
self.assertEqual(res[4].shape, ())
self.assertEqual(res[5].shape, ())
@prog_scope()
def test_while_loop(self):
def cond(i, x):
return paddle.less_than(i, eleven)
def body(i, x):
x = x + i
i = i + 1
return [i, x]
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, paddle.static.Program()):
i = paddle.static.data(name='i', shape=[], dtype='float32')
i.stop_gradient = False
eleven = paddle.full([], 11, 'float32')
x = paddle.static.data(name='x', shape=[], dtype='float32')
x.stop_gradient = False
out_i, out_x = paddle.static.nn.while_loop(cond, body, [i, x])
paddle.static.append_backward(out_x)
res = self.exe.run(
main_program,
feed={
'i': np.array(1.0, dtype='float32'),
'x': np.array(0.0, dtype='float32'),
},
fetch_list=[out_i.name, out_x.name, i.grad_name, x.grad_name],
)
self.assertEqual(res[0].shape, ())
np.testing.assert_allclose(res[0], np.array(11))
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose(res[1], np.array(55))
self.assertEqual(res[2].shape, ())
np.testing.assert_allclose(res[2], np.array(10))
self.assertEqual(res[3].shape, ())
np.testing.assert_allclose(res[3], np.array(1.0))
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
......
......@@ -467,7 +467,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
)
if _non_static_mode():
now_cond = pre_cond.numpy()[0]
now_cond = pre_cond.numpy().item()
while now_cond:
output_vars = body(*loop_vars)
if not isinstance(output_vars, (list, tuple)):
......@@ -477,7 +477,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
"body in while_loop should return the same arity "
"(length and structure) and types as loop_vars"
)
now_cond = cond(*output_vars).numpy()[0]
now_cond = cond(*output_vars).numpy().item()
map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
return loop_vars
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册