From be83e1ee15e7ba99dd5d64e7c53a8bde0b1457e6 Mon Sep 17 00:00:00 2001 From: TeFeng Chen Date: Tue, 7 Feb 2023 11:06:30 +0800 Subject: [PATCH] [Zero-Dim] support 0D Tensor for while_loop op (#49780) * support 0D Tensor for while_loop op * update * clean unit test * revert test_while_loop_op.py * test again * remove invalid check * fix error * change fluid to paddle.static * fix paddle.full * merge forward and backward test * simplify code * add precision check * fix condition var check * add dygraph test --- .../fluid/operators/controlflow/while_op.cc | 10 +-- .../tests/unittests/test_zero_dim_tensor.py | 63 +++++++++++++++++++ python/paddle/static/nn/control_flow.py | 4 +- 3 files changed, 70 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index a5e3183774..248d899aa2 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -100,12 +100,12 @@ class WhileOp : public framework::OperatorBase { auto &cond = scope.FindVar(Input(kCondition))->Get(); PADDLE_ENFORCE_EQ( - cond.dims(), - phi::make_ddim({1}), + cond.numel(), + 1, platform::errors::InvalidArgument( - "The shape of Input(Condition) of WhileOp must be 1. But now " - "the Condition's shape is ", - cond.dims().to_str(), + "The numel of Input(Condition) of WhileOp must be 1. But now " + "the Condition's numel is ", + cond.numel(), ".\n")); #ifdef PADDLE_WITH_MKLDNN diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 1ddf9f897d..b00d305895 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -1590,6 +1590,32 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out2.grad.shape, []) self.assertEqual(x.grad.shape, []) + def test_while_loop(self): + def cond(i, x): + return paddle.less_than(i, eleven) + + def body(i, x): + x = x + i + i = i + 1 + return [i, x] + + i = paddle.full([], 1.0, dtype='float32') + i.stop_gradient = False + eleven = paddle.full([], 11, dtype='float32') + x = paddle.full([], 0.0, dtype='float32') + x.stop_gradient = False + out_i, out_x = paddle.static.nn.while_loop(cond, body, [i, x]) + out_x.backward() + + self.assertEqual(out_i.shape, []) + np.testing.assert_allclose(out_i, np.array(11)) + self.assertEqual(out_x.shape, []) + np.testing.assert_allclose(out_x, np.array(55)) + self.assertEqual(i.grad.shape, []) + np.testing.assert_allclose(i.grad, np.array(10)) + self.assertEqual(x.grad.shape, []) + np.testing.assert_allclose(x.grad, np.array(1.0)) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -2592,6 +2618,43 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[4].shape, ()) self.assertEqual(res[5].shape, ()) + @prog_scope() + def test_while_loop(self): + def cond(i, x): + return paddle.less_than(i, eleven) + + def body(i, x): + x = x + i + i = i + 1 + return [i, x] + + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, paddle.static.Program()): + i = paddle.static.data(name='i', shape=[], dtype='float32') + i.stop_gradient = False + eleven = paddle.full([], 11, 'float32') + x = paddle.static.data(name='x', shape=[], dtype='float32') + x.stop_gradient = False + out_i, out_x = paddle.static.nn.while_loop(cond, body, [i, x]) + paddle.static.append_backward(out_x) + + res = self.exe.run( + main_program, + feed={ + 'i': np.array(1.0, dtype='float32'), + 'x': np.array(0.0, dtype='float32'), + }, + fetch_list=[out_i.name, out_x.name, i.grad_name, x.grad_name], + ) + self.assertEqual(res[0].shape, ()) + np.testing.assert_allclose(res[0], np.array(11)) + self.assertEqual(res[1].shape, ()) + np.testing.assert_allclose(res[1], np.array(55)) + self.assertEqual(res[2].shape, ()) + np.testing.assert_allclose(res[2], np.array(10)) + self.assertEqual(res[3].shape, ()) + np.testing.assert_allclose(res[3], np.array(1.0)) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index d46c0c7c18..0be5f23a03 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -467,7 +467,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): ) if _non_static_mode(): - now_cond = pre_cond.numpy()[0] + now_cond = pre_cond.numpy().item() while now_cond: output_vars = body(*loop_vars) if not isinstance(output_vars, (list, tuple)): @@ -477,7 +477,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): "body in while_loop should return the same arity " "(length and structure) and types as loop_vars" ) - now_cond = cond(*output_vars).numpy()[0] + now_cond = cond(*output_vars).numpy().item() map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars) return loop_vars -- GitLab