提交 d7d490e4 编写于 作者: D DesmonDay

fix to_dlpack for loop

上级 86a22ad4
......@@ -475,20 +475,14 @@ void BindTensor(pybind11::module &m) { // NOLINT
[](phi::DenseTensor &self) {
DLPackTensor dlpack_tensor(self, 1);
DLManagedTensor *dmt = dlpack_tensor.ToDLManagedTensor();
auto capsule = py::capsule(
auto capsule = pybind11::capsule(
static_cast<void *>(dmt), "dltensor", [](PyObject *ptr) {
if (ptr) {
auto dltensor = new DLManagedTensor;
try {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "used_dltensor"));
return;
} catch (...) {
dltensor = reinterpret_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "dltensor"));
}
dltensor->deleter(dltensor);
if (!PyCapsule_IsValid(ptr, "dltensor")) {
return;
}
DLManagedTensor *dmt = static_cast<DLManagedTensor *>(
PyCapsule_GetPointer(ptr, "dltensor"));
dmt->deleter(dmt);
});
return capsule;
})
......
......@@ -116,6 +116,12 @@ class TestDLPack(unittest.TestCase):
dlpack = paddle.utils.dlpack.to_dlpack(a)
b = paddle.utils.dlpack.from_dlpack(dlpack)
def test_to_dlpack_for_loop(self):
# See Paddle issue 50120
for i in range(10):
x = paddle.rand([3, 5])
dlpack = paddle.utils.dlpack.to_dlpack(x)
class TestRaiseError(unittest.TestCase):
def test_from_dlpack_raise_type_error(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册