未验证 提交 8e66046b 编写于 作者: W WeiXin 提交者: GitHub

support backward return None, when corresponding input tensor without gradient (#32494)

* support backward return None.

* edit unittest.

* edit code according to CI

* Improve error information
上级 fd85a4af
......@@ -60,33 +60,51 @@ void RunPyObject(py::object *py_object,
outs->size(), result_tuple.size()));
}
for (size_t i = 0; i < result_tuple.size(); i++) {
if (Py_None != result_tuple[i].ptr()) {
if ((*outs)[i] != nullptr) {
if (Py_None != result_tuple[i].ptr()) {
try {
auto result_var =
result_tuple[i].cast<std::shared_ptr<imperative::VarBase>>();
*(*outs)[i] = result_var->Var();
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The output of `PyLayer.backward` should be `Tensor`."));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %dth input tensor of forward needs gradient and the "
"corresponding gradient cannot be None.",
i));
}
} else {
if (Py_None != result_tuple[i].ptr()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %dth input tensor of forward do not need gradient and the "
"corresponding gradient should be `None`.",
i));
}
}
}
} else {
if ((*outs)[0] != nullptr) {
if (Py_None != py_result.ptr()) {
try {
auto result_var =
result_tuple[i].cast<std::shared_ptr<imperative::VarBase>>();
*(*outs)[i] = result_var->Var();
py_result.cast<std::shared_ptr<imperative::VarBase>>();
*((*outs)[0]) = result_var->Var();
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::Unimplemented(
PADDLE_THROW(platform::errors::InvalidArgument(
"The output of `PyLayer.backward` should be `Tensor`."));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.backward` can not be `None`."));
}
}
} else {
if (Py_None != py_result.ptr()) {
try {
auto result_var =
py_result.cast<std::shared_ptr<imperative::VarBase>>();
*((*outs)[0]) = result_var->Var();
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.backward` should be `Tensor`."));
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor of forward needs gradient, so the output of "
"`PyLayer.backward` can not be `None`."));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.backward` can not be `None`."));
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor of forward do not need gradient, so the output of "
"`PyLayer.backward` should be `None`."));
}
}
}
......
......@@ -52,6 +52,40 @@ class TestPyLayer(unittest.TestCase):
self.assertTrue(np.max(np.abs((input1.grad - input2.grad))) < 1e-10)
def test_simple_pylayer_return_none_with_no_grad(self):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x1, x2, func1, func2=paddle.square):
ctx.func = func2
y1 = func1(x1)
y2 = func1(x2)
ctx.save_for_backward(y1, y2)
return y1, y2
@staticmethod
def backward(ctx, dy1, dy2):
y1, y2 = ctx.saved_tensor()
re1 = dy1 * (1 - ctx.func(y1))
re2 = dy2 * (1 - paddle.square(y2))
return re1, None
input1 = paddle.randn([2, 3]).astype("float64")
input2 = input1.detach().clone()
input3 = input1.detach().clone()
input4 = input1.detach().clone()
input1.stop_gradient = False
input2.stop_gradient = False
input3.stop_gradient = True
input4.stop_gradient = True
z = tanh.apply(input1, input3, paddle.tanh, paddle.square)
z = z[0] + z[1]
z.mean().backward()
z2 = paddle.tanh(input2) + paddle.tanh(input4)
z2.mean().backward()
self.assertTrue(np.max(np.abs((input1.grad - input2.grad))) < 1e-10)
def test_simple_pylayer_single_output(self):
class tanh(PyLayer):
@staticmethod
......@@ -196,7 +230,7 @@ class TestPyLayer(unittest.TestCase):
input2.stop_gradient = False
z = Layer_bk_none1.apply(input2)
with self.assertRaises(NotImplementedError):
with self.assertRaises(ValueError):
with paddle.fluid.dygraph.guard():
z.sum().backward()
......@@ -212,7 +246,7 @@ class TestPyLayer(unittest.TestCase):
input1 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = False
z = Layer_bk_none2.apply(input1, input1)
with self.assertRaises(NotImplementedError):
with self.assertRaises(ValueError):
with paddle.fluid.dygraph.guard():
z.mean().backward()
......@@ -228,14 +262,14 @@ class TestPyLayer(unittest.TestCase):
input1 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = False
z = Layer_bk_one1.apply(input1)
with self.assertRaises(NotImplementedError):
with self.assertRaises(ValueError):
with paddle.fluid.dygraph.guard():
z.mean().backward()
class Layer_bk_one2(PyLayer):
@staticmethod
def forward(ctx, x):
return x * 2, x * 5
def forward(ctx, x1, x2):
return x1 * 2, x2 * 5
@staticmethod
def backward(ctx, *args):
......@@ -243,8 +277,9 @@ class TestPyLayer(unittest.TestCase):
input1 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = False
z = Layer_bk_one1.apply(input1)
with self.assertRaises(NotImplementedError):
y = Layer_bk_one2.apply(input1, input1)
z = y[0] + y[1]
with self.assertRaises(ValueError):
with paddle.fluid.dygraph.guard():
z.mean().backward()
......@@ -279,6 +314,45 @@ class TestPyLayer(unittest.TestCase):
z = z[0] + z[1]
z.mean().backward()
def test_pylayer_bk_return_none(self):
class Layer_bk_none1(PyLayer):
@staticmethod
def forward(ctx, x1, x2):
return x1 + x2
@staticmethod
def backward(ctx, dy):
return 1
input1 = paddle.randn([2, 3]).astype("float64")
input2 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = True
input2.stop_gradient = False
z = Layer_bk_none1.apply(input1, input2)
with self.assertRaises(ValueError):
with paddle.fluid.dygraph.guard():
z.mean().backward()
class Layer_bk_none2(PyLayer):
@staticmethod
def forward(ctx, x1, x2):
return x1 * 2, x2 * 5
@staticmethod
def backward(ctx, *args):
return 1, 1
input1 = paddle.randn([2, 3]).astype("float64")
input2 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = True
input2.stop_gradient = False
z = Layer_bk_none2.apply(input1, input2)
z = z[0] + z[1]
with self.assertRaises(ValueError):
with paddle.fluid.dygraph.guard():
z.mean().backward()
def test_pylayer_inplace(self):
class cus_tanh(PyLayer):
@staticmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册