diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index a5e3183774f990b37ffe216920fea1e325e4a482..248d899aa2d5d6876c5e31ccc75a342a44a6d73e 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -100,12 +100,12 @@ class WhileOp : public framework::OperatorBase { auto &cond = scope.FindVar(Input(kCondition))->Get(); PADDLE_ENFORCE_EQ( - cond.dims(), - phi::make_ddim({1}), + cond.numel(), + 1, platform::errors::InvalidArgument( - "The shape of Input(Condition) of WhileOp must be 1. But now " - "the Condition's shape is ", - cond.dims().to_str(), + "The numel of Input(Condition) of WhileOp must be 1. But now " + "the Condition's numel is ", + cond.numel(), ".\n")); #ifdef PADDLE_WITH_MKLDNN diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 1ddf9f897dbc3b99432ba32e27b9a45d218ed9ec..b00d305895a5c37c382f1479b16c7bf007555120 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -1590,6 +1590,32 @@ class TestSundryAPI(unittest.TestCase): self.assertEqual(out2.grad.shape, []) self.assertEqual(x.grad.shape, []) + def test_while_loop(self): + def cond(i, x): + return paddle.less_than(i, eleven) + + def body(i, x): + x = x + i + i = i + 1 + return [i, x] + + i = paddle.full([], 1.0, dtype='float32') + i.stop_gradient = False + eleven = paddle.full([], 11, dtype='float32') + x = paddle.full([], 0.0, dtype='float32') + x.stop_gradient = False + out_i, out_x = paddle.static.nn.while_loop(cond, body, [i, x]) + out_x.backward() + + self.assertEqual(out_i.shape, []) + np.testing.assert_allclose(out_i, np.array(11)) + self.assertEqual(out_x.shape, []) + np.testing.assert_allclose(out_x, np.array(55)) + self.assertEqual(i.grad.shape, []) + np.testing.assert_allclose(i.grad, np.array(10)) + self.assertEqual(x.grad.shape, []) + np.testing.assert_allclose(x.grad, np.array(1.0)) + class TestSundryAPIStatic(unittest.TestCase): def setUp(self): @@ -2592,6 +2618,43 @@ class TestSundryAPIStatic(unittest.TestCase): self.assertEqual(res[4].shape, ()) self.assertEqual(res[5].shape, ()) + @prog_scope() + def test_while_loop(self): + def cond(i, x): + return paddle.less_than(i, eleven) + + def body(i, x): + x = x + i + i = i + 1 + return [i, x] + + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program, paddle.static.Program()): + i = paddle.static.data(name='i', shape=[], dtype='float32') + i.stop_gradient = False + eleven = paddle.full([], 11, 'float32') + x = paddle.static.data(name='x', shape=[], dtype='float32') + x.stop_gradient = False + out_i, out_x = paddle.static.nn.while_loop(cond, body, [i, x]) + paddle.static.append_backward(out_x) + + res = self.exe.run( + main_program, + feed={ + 'i': np.array(1.0, dtype='float32'), + 'x': np.array(0.0, dtype='float32'), + }, + fetch_list=[out_i.name, out_x.name, i.grad_name, x.grad_name], + ) + self.assertEqual(res[0].shape, ()) + np.testing.assert_allclose(res[0], np.array(11)) + self.assertEqual(res[1].shape, ()) + np.testing.assert_allclose(res[1], np.array(55)) + self.assertEqual(res[2].shape, ()) + np.testing.assert_allclose(res[2], np.array(10)) + self.assertEqual(res[3].shape, ()) + np.testing.assert_allclose(res[3], np.array(1.0)) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index d46c0c7c189b7508f2356e5ce24c56d74033e612..0be5f23a0368bf45b08a8162968c586d29aa8e88 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -467,7 +467,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): ) if _non_static_mode(): - now_cond = pre_cond.numpy()[0] + now_cond = pre_cond.numpy().item() while now_cond: output_vars = body(*loop_vars) if not isinstance(output_vars, (list, tuple)): @@ -477,7 +477,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): "body in while_loop should return the same arity " "(length and structure) and types as loop_vars" ) - now_cond = cond(*output_vars).numpy()[0] + now_cond = cond(*output_vars).numpy().item() map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars) return loop_vars