未验证 提交 db599258 编写于 作者: J Jiabin Yang 提交者: GitHub

【Eager】Allow return none when stop_gradient=False (#51740)

* allow return none when stop_gradient=True

* remove useless code

* refine code

* refine code

* fix test case

* change more test

* add more tests
上级 8c61a95a
......@@ -155,9 +155,21 @@ GradNodePyLayer::operator()(
grad_out.push_back({});
} else {
if (ctx->forward_input_tensor_is_duplicable[i]) {
grad_out.push_back(paddle::pybind::GetTensorListFromPyObject(obj));
grad_out.push_back(
paddle::pybind::GetTensorListFromPyObject(obj, true));
} else {
grad_out.push_back({paddle::pybind::GetTensorFromPyObject(obj)});
if (paddle::pybind::PyCheckTensor(obj)) {
grad_out.push_back(
{paddle::pybind::UnSafeGetTensorFromPyObject(obj)});
} else if (obj == Py_None) {
VLOG(4) << "Got None for Tensor with pos: " << i;
grad_out.push_back({});
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"argument must be "
"Tensor or None, but got %s",
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
}
}
} else {
......
......@@ -1219,7 +1219,8 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromPyObject(PyObject* obj) {
return result;
}
std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj) {
std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
bool allow_none) {
std::vector<paddle::Tensor> result;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
......@@ -1229,6 +1230,9 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj) {
if (PyObject_IsInstance(item,
reinterpret_cast<PyObject*>(p_tensor_type))) {
result.emplace_back(reinterpret_cast<TensorObject*>(item)->tensor);
} else if (allow_none && (item == Py_None)) {
VLOG(4) << "Got None in Tensor list: " << i;
result.emplace_back();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument must be "
......@@ -1245,6 +1249,9 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj) {
if (PyObject_IsInstance(item,
reinterpret_cast<PyObject*>(p_tensor_type))) {
result.emplace_back(reinterpret_cast<TensorObject*>(item)->tensor);
} else if (allow_none && (item == Py_None)) {
VLOG(4) << "Got None in Tensor list: " << i;
result.emplace_back();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument must be "
......@@ -1262,16 +1269,9 @@ std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj) {
return result;
}
paddle::Tensor& GetTensorFromPyObject(PyObject* obj) {
if (!PyCheckTensor(obj)) {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument must be "
"Tensor, but got %s",
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
// Unchecked extraction of the underlying paddle::Tensor from a Python object.
// "UnSafe" because no type check is performed: the caller MUST have already
// verified that `obj` wraps a Tensor (e.g. via PyCheckTensor); passing any
// other object type is undefined behavior (the reinterpret_cast would read
// garbage). Returns a mutable reference into the TensorObject, valid only
// while `obj` is alive.
paddle::Tensor& UnSafeGetTensorFromPyObject(PyObject* obj) {
  return reinterpret_cast<TensorObject*>(obj)->tensor;
}
paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
......
......@@ -327,9 +327,9 @@ std::vector<paddle::Tensor*> GetTensorPtrListFromArgs(
std::vector<paddle::Tensor*> GetTensorPtrListFromPyObject(PyObject* obj);
std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj);
paddle::Tensor& GetTensorFromPyObject(PyObject* obj);
std::vector<paddle::Tensor> GetTensorListFromPyObject(PyObject* obj,
bool allow_none = false);
paddle::Tensor& UnSafeGetTensorFromPyObject(PyObject* obj);
// end of Slice related methods
......
......@@ -122,6 +122,36 @@ class TestPyLayer(unittest.TestCase):
np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10
)
def test_simple_pylayer_multi_output(self):
    # A PyLayer whose forward concatenates several inputs into one tensor
    # and whose backward splits the upstream gradient back into one piece
    # per forward input (multi-output backward).
    class tanh(PyLayer):
        @staticmethod
        def forward(ctx, x1, func1, func2=paddle.split):
            # Remember the splitting function so backward can invert concat.
            ctx.func = func2
            y1 = func1(x1)
            ctx.save_for_backward(y1)
            return y1

        @staticmethod
        def backward(ctx, dy1):
            (y1,) = ctx.saved_tensor()
            re1 = ctx.func(dy1, 3)
            return re1

    # Three independent float64 inputs, all requiring gradients.
    tensors = []
    for _ in range(3):
        t = paddle.randn([2, 3]).astype("float64")
        t.stop_gradient = False
        tensors.append(t)
    input1, input2, input3 = tensors

    # Backward through the custom PyLayer.
    z = tanh.apply(x1=[input1, input2, input3], func1=paddle.concat)
    z.mean().backward()

    # Reference: plain concat followed by backward on the same inputs.
    z2 = paddle.concat([input1, input2, input3])
    z2.mean().backward()

    # Gradients of symmetric inputs through mean(concat(...)) must agree.
    self.assertTrue(
        np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10
    )
def test_pylayer_num_output_match(self):
class tanh(PyLayer):
@staticmethod
......@@ -269,8 +299,8 @@ class TestPyLayer(unittest.TestCase):
input2.stop_gradient = False
z = Layer_bk_none1.apply(input2)
with self.assertRaises(ValueError):
z.sum().backward()
z.sum().backward()
self.assertEqual(input2.grad, None)
class Layer_bk_none2(PyLayer):
@staticmethod
......@@ -285,8 +315,8 @@ class TestPyLayer(unittest.TestCase):
input1.stop_gradient = False
z = Layer_bk_none2.apply(input1, input1)
with self.assertRaises(ValueError):
z.mean().backward()
z.mean().backward()
self.assertIsNone(z.grad)
class Layer_bk_one1(PyLayer):
@staticmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册