未验证 提交 fc9d68a5 编写于 作者: A Alexandre Passos 提交者: GitHub

Merge pull request #41017 from VoVAllen/fix_dlpack_r2.3

Cherry-pick dlpack fix #40843 into r2.3
...@@ -221,8 +221,7 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype, ...@@ -221,8 +221,7 @@ Status TfDataTypeFormDlDataType(const DLDataType& dtype,
// Wraps the deleter function of DLManagedTensor to match the function signature // Wraps the deleter function of DLManagedTensor to match the function signature
// TFE_NewTensorHandleFromDeviceMemory. // TFE_NewTensorHandleFromDeviceMemory.
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlmt_vptr); TFE_CallDLManagedTensorDeleter(dlmt_vptr);
dlmt->deleter(const_cast<DLManagedTensor*>(dlmt));
} }
// Checks whether the stride array matches the layout of compact, row-majored // Checks whether the stride array matches the layout of compact, row-majored
...@@ -324,7 +323,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status, ...@@ -324,7 +323,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
ctx, device_name.value().c_str(), dtype, dims, num_dims, data, ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
total_bytes, &DeallocatorWrapperFunc, &dlmt, status); total_bytes, &DeallocatorWrapperFunc, dlmt, status);
return handle; return handle;
} }
......
...@@ -1169,7 +1169,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) { ...@@ -1169,7 +1169,9 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
PyCapsule_SetName(pycapsule.ptr(), "used_dltensor"); PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
return py::handle(EagerTensorFromHandle(thandle));
PyObject* pyhandle = EagerTensorFromHandle(thandle);
return tensorflow::PyoOrThrow(pyhandle);
}); });
m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context, m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册