提交 2553e843 编写于 作者: M Mihai Maruseac

Fix multiple vulnerabilities in `tf.experimental.dlpack.to_dlpack`.

We have a use after free caused by memory coruption, a segmentation fault caused by memory corruption, several memory leaks and an undefined behavior when taking the reference of a nullptr.

PiperOrigin-RevId: 332568894
Change-Id: Ife0fc05e103b35325094ae5d822ee5fdea764572
上级 f3535c02
......@@ -250,21 +250,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
}
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
auto tf_dlm_context = GetDlContext(h, status);
if (!status->status.ok()) {
return nullptr;
}
auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
if (!status->status.ok()) {
return nullptr;
}
const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto tf_dlm_type = GetDlDataType(data_type, status);
if (!status->status.ok()) {
return nullptr;
}
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
dlm_tensor->dl_tensor.ctx = tf_dlm_context;
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
dlm_tensor->dl_tensor.data = tf_dlm_data;
dlm_tensor->dl_tensor.dtype = tf_dlm_type;
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
......@@ -277,13 +292,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
dlm_tensor->dl_tensor.shape = shape_arr->data();
// There are two ways to represent compact row-major data
// 1) nullptr indicates tensor is compact and row-majored.
// 2) fill in the strides array as the real case for compact row-major data.
// Here we choose option 2, since some frameworks didn't handle the strides
// argument properly.
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.strides = stride_arr->data();
dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return static_cast<void*>(dlm_tensor);
......
......@@ -20,9 +20,11 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.dlpack import dlpack
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
......@@ -95,6 +97,12 @@ class DLPackTest(parameterized.TestCase, test.TestCase):
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
UnsupportedComplex64)
def testMustPassTensorArgumentToDLPack(self):
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"The argument to `to_dlpack` must be a TF tensor, not Python object"):
dlpack.to_dlpack([1])
if __name__ == "__main__":
ops.enable_eager_execution()
......
......@@ -1051,9 +1051,16 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
// DLPack functions
m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
PyObject* eager_tensor_pyobject_ptr = o.ptr();
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
status->status = tensorflow::errors::InvalidArgument(
"The argument to `to_dlpack` must be a TF tensor, not Python object");
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册