未验证 提交 2c02a580 编写于 作者: C chentianyu03 提交者: GitHub

add check whether tensor is inplace and leaf when calcute gradient (#37931)

上级 d1ab323f
......@@ -226,6 +226,15 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
}
}
if (if_inplace) {
// when pylayer forward is inplace strategy, check whether tensor is leaf
for (auto& t : input_vars) {
PADDLE_ENFORCE_EQ(t->IsLeaf() && !t->OverridedStopGradient(), false,
platform::errors::InvalidArgument(
"Leaf Var (%s) that doesn't stop gradient can't "
"use inplace strategy.",
t->Name()));
}
inplace_map["X"] = "Out";
}
......
......@@ -406,6 +406,32 @@ class TestPyLayer(unittest.TestCase):
z.backward()
self.assertTrue(data.grad is not None)
def test_pylayer_inplace_and_leaf_exception(self):
class cus_pylayer_op(PyLayer):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, dy):
return dy
class Layer(paddle.nn.Layer):
def __init__(self):
super(Layer, self).__init__()
def forward(self, data):
z = cus_pylayer_op.apply(data)
return z.mean()
for i in range(2):
data = paddle.ones([2, 3], dtype="float64") / (i + 1)
data.stop_gradient = False
layer = Layer()
with self.assertRaises(ValueError):
z = layer(data)
def test_backward_in_backward(self):
class cus_tanh(PyLayer):
@staticmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册