From db599258e536a5efa2766b4980cce0fc6396fcc3 Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Wed, 22 Mar 2023 11:35:58 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Eager=E3=80=91Allow=20return=20none=20?= =?UTF-8?q?when=20stop=5Fgradient=3DFalse=20(#51740)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * allow return none when stop_gradient=True * remove useless code * refine code * refine code * fix test cast * change more test * add more tests --- paddle/fluid/eager/pylayer/py_layer_node.cc | 16 +++++++- paddle/fluid/pybind/eager_utils.cc | 18 ++++----- paddle/fluid/pybind/eager_utils.h | 6 +-- .../fluid/tests/unittests/test_pylayer_op.py | 38 +++++++++++++++++-- 4 files changed, 60 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/eager/pylayer/py_layer_node.cc b/paddle/fluid/eager/pylayer/py_layer_node.cc index 3be77e59843..df65ea730f2 100644 --- a/paddle/fluid/eager/pylayer/py_layer_node.cc +++ b/paddle/fluid/eager/pylayer/py_layer_node.cc @@ -155,9 +155,21 @@ GradNodePyLayer::operator()( grad_out.push_back({}); } else { if (ctx->forward_input_tensor_is_duplicable[i]) { - grad_out.push_back(paddle::pybind::GetTensorListFromPyObject(obj)); + grad_out.push_back( + paddle::pybind::GetTensorListFromPyObject(obj, true)); } else { - grad_out.push_back({paddle::pybind::GetTensorFromPyObject(obj)}); + if (paddle::pybind::PyCheckTensor(obj)) { + grad_out.push_back( + {paddle::pybind::UnSafeGetTensorFromPyObject(obj)}); + } else if (obj == Py_None) { + VLOG(4) << "Got None for Tensor with pos: " << i; + grad_out.push_back({}); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "argument must be " + "Tensor or None, but got %s", + reinterpret_cast(obj->ob_type)->tp_name)); + } } } } else { diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 7b7e46c5414..a0d0250949f 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1219,7 +1219,8 @@ std::vector GetTensorPtrListFromPyObject(PyObject* obj) { return result; } -std::vector GetTensorListFromPyObject(PyObject* obj) { +std::vector GetTensorListFromPyObject(PyObject* obj, + bool allow_none) { std::vector result; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); @@ -1229,6 +1230,9 @@ std::vector GetTensorListFromPyObject(PyObject* obj) { if (PyObject_IsInstance(item, reinterpret_cast(p_tensor_type))) { result.emplace_back(reinterpret_cast(item)->tensor); + } else if (allow_none && (item == Py_None)) { + VLOG(4) << "Got None in Tensor list: " << i; + result.emplace_back(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument must be " @@ -1245,6 +1249,9 @@ std::vector GetTensorListFromPyObject(PyObject* obj) { if (PyObject_IsInstance(item, reinterpret_cast(p_tensor_type))) { result.emplace_back(reinterpret_cast(item)->tensor); + } else if (allow_none && (item == Py_None)) { + VLOG(4) << "Got None in Tensor list: " << i; + result.emplace_back(); } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument must be " @@ -1262,16 +1269,9 @@ std::vector GetTensorListFromPyObject(PyObject* obj) { return result; } -paddle::Tensor& GetTensorFromPyObject(PyObject* obj) { - if (!PyCheckTensor(obj)) { - PADDLE_THROW(platform::errors::InvalidArgument( - "argument must be " - "Tensor, but got %s", - reinterpret_cast(obj->ob_type)->tp_name)); - } +paddle::Tensor& UnSafeGetTensorFromPyObject(PyObject* obj) { return reinterpret_cast(obj)->tensor; } - paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 9c1c0b8ba52..4f0b3fa6370 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -327,9 +327,9 @@ std::vector GetTensorPtrListFromArgs( std::vector GetTensorPtrListFromPyObject(PyObject* obj); -std::vector GetTensorListFromPyObject(PyObject* obj); - -paddle::Tensor& GetTensorFromPyObject(PyObject* obj); +std::vector GetTensorListFromPyObject(PyObject* obj, + bool allow_none = false); +paddle::Tensor& UnSafeGetTensorFromPyObject(PyObject* obj); // end of Slice related methods diff --git a/python/paddle/fluid/tests/unittests/test_pylayer_op.py b/python/paddle/fluid/tests/unittests/test_pylayer_op.py index 52c050579c9..b270a8b0d7f 100644 --- a/python/paddle/fluid/tests/unittests/test_pylayer_op.py +++ b/python/paddle/fluid/tests/unittests/test_pylayer_op.py @@ -122,6 +122,36 @@ class TestPyLayer(unittest.TestCase): np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10 ) + def test_simple_pylayer_multi_output(self): + class tanh(PyLayer): + @staticmethod + def forward(ctx, x1, func1, func2=paddle.split): + ctx.func = func2 + y1 = func1(x1) + ctx.save_for_backward(y1) + return y1 + + @staticmethod + def backward(ctx, dy1): + (y1,) = ctx.saved_tensor() + re1 = ctx.func(dy1, 3) + return re1 + + input1 = paddle.randn([2, 3]).astype("float64") + input2 = paddle.randn([2, 3]).astype("float64") + input3 = paddle.randn([2, 3]).astype("float64") + input1.stop_gradient = False + input2.stop_gradient = False + input3.stop_gradient = False + z = tanh.apply(x1=[input1, input2, input3], func1=paddle.concat) + z.mean().backward() + z2 = paddle.concat([input1, input2, input3]) + z2.mean().backward() + + self.assertTrue( + np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10 + ) + def test_pylayer_num_output_match(self): class tanh(PyLayer): @staticmethod @@ -269,8 +299,8 @@ class TestPyLayer(unittest.TestCase): input2.stop_gradient = False z = Layer_bk_none1.apply(input2) - with self.assertRaises(ValueError): - z.sum().backward() + z.sum().backward() + self.assertEqual(input2.grad, None) class Layer_bk_none2(PyLayer): @staticmethod @@ -285,8 +315,8 @@ class TestPyLayer(unittest.TestCase): input1.stop_gradient = False z = Layer_bk_none2.apply(input1, input1) - with self.assertRaises(ValueError): - z.mean().backward() + z.mean().backward() + self.assertIsNone(z.grad) class Layer_bk_one1(PyLayer): @staticmethod -- GitLab