From 6486428f05ae5cb87075ad930b7999d81a67c201 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 9 May 2023 15:13:46 +0800 Subject: [PATCH] feat(imperative): add dlpack support GitOrigin-RevId: e8f2bac6f8f415c20d9aa6ccc02762d69b5e2817 --- imperative/python/megengine/tensor.py | 21 +- imperative/python/megengine/utils/dlpack.py | 17 ++ imperative/python/src/common.cpp | 32 +++ imperative/python/src/dlpack.h | 231 +++++++++++++++++ imperative/python/src/dlpack_convertor.cpp | 270 ++++++++++++++++++++ imperative/python/src/dlpack_convertor.h | 24 ++ imperative/python/src/tensor.cpp | 39 +++ 7 files changed, 632 insertions(+), 2 deletions(-) create mode 100644 imperative/python/megengine/utils/dlpack.py create mode 100644 imperative/python/src/dlpack.h create mode 100644 imperative/python/src/dlpack_convertor.cpp create mode 100644 imperative/python/src/dlpack_convertor.h diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index d3fd96f16..8737aac9b 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- -from typing import Union +import enum +from typing import Tuple, Union import numpy as np from .core._imperative_rt import CompNode from .core._imperative_rt.core2 import FormatType from .core._imperative_rt.core2 import Tensor as _Tensor -from .core._imperative_rt.core2 import apply, set_py_tensor_type +from .core._imperative_rt.core2 import _to_dlpack, apply, set_py_tensor_type from .core._trace_option import use_symbolic_shape from .core._wrap import as_device from .core.ops.builtin import Borrow, Copy, GetVarShape @@ -249,6 +250,22 @@ class Tensor(_Tensor, ArrayMethodMixin): qparams = None self._qparams = qparams + def __dlpack__(self, stream=None): + if stream is not None and not isinstance(stream, int): + raise TypeError("stream must be ``int`` or ``none``") + elif stream is not None and stream != -1: + mdevice, mrank, mstream = self.device.physical_locator + if mdevice == "gpu": + if mstream != stream: + device = "gpu{}:{}".format(mrank, mstream) + self.to(device, _borrow=True) + elif mdevice == "cpu": + device = "cpu{}:{}".format(mrank, mstream) + self.to(device, _borrow=True) + else: + raise ValueError("dlpack not support this device: {}!".format(mdevice)) + return _to_dlpack(self) + set_py_tensor_type(Tensor) diff --git a/imperative/python/megengine/utils/dlpack.py b/imperative/python/megengine/utils/dlpack.py new file mode 100644 index 000000000..8ec8ca8b5 --- /dev/null +++ b/imperative/python/megengine/utils/dlpack.py @@ -0,0 +1,17 @@ +from typing import Any + +from ..core._imperative_rt.core2 import _from_dlpack + + +def to_dlpack(tensor, stream=None): + if stream is not None and stream != -1: + return tensor.__dlpack__(stream) + else: + return tensor.__dlpack__() + + +def from_dlpack(ext_tensor: Any, stream=None): + if isinstance(stream, int): + assert stream >= 0, "device stream should be a positive value" + stream = 0 if stream is None else stream + return _from_dlpack(ext_tensor, stream) diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp index 4d5dc90a4..f681ceae1 100644 --- a/imperative/python/src/common.cpp +++ b/imperative/python/src/common.cpp @@ -66,6 +66,28 @@ std::string get_default_device() { return default_device; } +std::string device_type2str(CompNode::DeviceType type) { + using DT = CompNode::DeviceType; + switch (type) { + case DT::UNSPEC: + return "xpu"; + case DT::CUDA: + return "gpu"; + case DT::CPU: + 
return "cpu"; + case DT::ATLAS: + return "atlas"; + case DT::ROCM: + return "rocm"; + case DT::CAMBRICON: + return "cambricon"; + case DT::MULTITHREAD: + return "multithread"; + default: + mgb_throw(MegBrainError, "bad device type"); + } +} + py::handle py_comp_node_type; void init_common(py::module m) { @@ -80,6 +102,16 @@ void init_common(py::module m) { .def_property_readonly( "physical_name", [](const CompNode& cn) { return cn.to_string_physical(); }) + .def_property_readonly( + "physical_locator", + [](const CompNode& cn) { + py::list res; + auto locator = cn.locator(); + res.append(device_type2str(locator.type)); + res.append(locator.device); + res.append(locator.stream); + return res; + }) .def_property_readonly( "get_mem_status_bytes", [](const CompNode& cn) { diff --git a/imperative/python/src/dlpack.h b/imperative/python/src/dlpack.h new file mode 100644 index 000000000..9b37b4375 --- /dev/null +++ b/imperative/python/src/dlpack.h @@ -0,0 +1,231 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dlpack.h + * \brief The common header of DLPack. + */ +#ifndef DLPACK_DLPACK_H_ +#define DLPACK_DLPACK_H_ + +/** + * \brief Compatibility with C++ + */ +#ifdef __cplusplus +#define DLPACK_EXTERN_C extern "C" +#else +#define DLPACK_EXTERN_C +#endif + +/*! \brief The current version of dlpack */ +#define DLPACK_VERSION 70 + +/*! \brief The current ABI version of dlpack */ +#define DLPACK_ABI_VERSION 1 + +/*! \brief DLPACK_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef DLPACK_EXPORTS +#define DLPACK_DLL __declspec(dllexport) +#else +#define DLPACK_DLL __declspec(dllimport) +#endif +#else +#define DLPACK_DLL +#endif + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif +/*! + * \brief The device type in DLDevice. + */ +#ifdef __cplusplus +typedef enum : int32_t { +#else +typedef enum { +#endif + /*! \brief CPU device */ + kDLCPU = 1, + /*! \brief CUDA GPU device */ + kDLCUDA = 2, + /*! + * \brief Pinned CUDA CPU memory by cudaMallocHost + */ + kDLCUDAHost = 3, + /*! \brief OpenCL devices. */ + kDLOpenCL = 4, + /*! \brief Vulkan buffer for next generation graphics. */ + kDLVulkan = 7, + /*! \brief Metal for Apple GPU. */ + kDLMetal = 8, + /*! \brief Verilog simulator buffer */ + kDLVPI = 9, + /*! \brief ROCm GPUs for AMD GPUs */ + kDLROCM = 10, + /*! + * \brief Pinned ROCm CPU memory allocated by hipMallocHost + */ + kDLROCMHost = 11, + /*! + * \brief Reserved extension device type, + * used for quickly test extension device + * The semantics can differ depending on the implementation. + */ + kDLExtDev = 12, + /*! + * \brief CUDA managed/unified memory allocated by cudaMallocManaged + */ + kDLCUDAManaged = 13, + /*! + * \brief Unified shared memory allocated on a oneAPI non-partititioned + * device. Call to oneAPI runtime is required to determine the device + * type, the USM allocation type and the sycl context it is bound to. + * + */ + kDLOneAPI = 14, + /*! \brief GPU support for next generation WebGPU standard. */ + kDLWebGPU = 15, +} DLDeviceType; + +/*! + * \brief A Device for Tensor and operator. + */ +// NB: This is the only difference from +// https://github.com/dmlc/dlpack/blob/v0.7/include/dlpack/dlpack.h Required to +// allow forward declaration of DLDevice. +typedef struct DLDevice_ { + /*! \brief The device type used in the device. */ + DLDeviceType device_type; + /*! + * \brief The device index. + * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. + */ + int32_t device_id; +} DLDevice; + +/*! 
+ * \brief The type code options DLDataType. + */ +typedef enum { + /*! \brief signed integer */ + kDLInt = 0U, + /*! \brief unsigned integer */ + kDLUInt = 1U, + /*! \brief IEEE floating point */ + kDLFloat = 2U, + /*! + * \brief Opaque handle type, reserved for testing purposes. + * Frameworks need to agree on the handle data type for the exchange to be + * well-defined. + */ + kDLOpaqueHandle = 3U, + /*! \brief bfloat16 */ + kDLBfloat = 4U, + /*! + * \brief complex number + * (C/C++/Python layout: compact struct per complex number) + */ + kDLComplex = 5U, +} DLDataTypeCode; + +/*! + * \brief The data type the tensor can hold. The data type is assumed to follow + * the native endian-ness. An explicit error message should be raised when + * attempting to export an array with non-native endianness + * + * Examples + * - float: type_code = 2, bits = 32, lanes=1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 + * - int8: type_code = 0, bits = 8, lanes=1 + * - std::complex: type_code = 5, bits = 64, lanes = 1 + */ +typedef struct { + /*! + * \brief Type code of base types. + * We keep it uint8_t instead of DLDataTypeCode for minimal memory + * footprint, but the value should be one of DLDataTypeCode enum values. + * */ + uint8_t code; + /*! + * \brief Number of bits, common choices are 8, 16, 32. + */ + uint8_t bits; + /*! \brief Number of lanes in the type, used for vector types. */ + uint16_t lanes; +} DLDataType; + +/*! + * \brief Plain C Tensor object, does not manage memory. + */ +typedef struct { + /*! + * \brief The data pointer points to the allocated data. This will be CUDA + * device pointer or cl_mem handle in OpenCL. It may be opaque on some device + * types. This pointer is always aligned to 256 bytes as in CUDA. The + * `byte_offset` field should be used to point to the beginning of the data. + * + * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, + * TVM, perhaps others) do not adhere to this 256 byte aligment requirement + * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed + * (after which this note will be updated); at the moment it is recommended + * to not rely on the data pointer being correctly aligned. + * + * For given DLTensor, the size of memory required to store the contents of + * data is calculated as follows: + * + * \code{.c} + * static inline size_t GetDataSize(const DLTensor* t) { + * size_t size = 1; + * for (tvm_index_t i = 0; i < t->ndim; ++i) { + * size *= t->shape[i]; + * } + * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; + * return size; + * } + * \endcode + */ + void* data; + /*! \brief The device of the tensor */ + DLDevice device; + /*! \brief Number of dimensions */ + int32_t ndim; + /*! \brief The data type of the pointer*/ + DLDataType dtype; + /*! \brief The shape of the tensor */ + int64_t* shape; + /*! + * \brief strides of the tensor (in number of elements, not bytes) + * can be NULL, indicating tensor is compact and row-majored. + */ + int64_t* strides; + /*! \brief The offset in bytes to the beginning pointer to data */ + uint64_t byte_offset; +} DLTensor; + +/*! + * \brief C Tensor object, manage memory of DLTensor. This data structure is + * intended to facilitate the borrowing of DLTensor by another framework. It is + * not meant to transfer the tensor. When the borrowing framework doesn't need + * the tensor, it should call the deleter to notify the host that the resource + * is no longer needed. + */ +typedef struct DLManagedTensor { + /*! 
\brief DLTensor which is being memory managed */ + DLTensor dl_tensor; + /*! \brief the context of the original host framework of DLManagedTensor in + * which DLManagedTensor is used in the framework. It can also be NULL. + */ + void* manager_ctx; + /*! \brief Destructor signature void (*)(void*) - this should be called + * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL + * if there is no way for the caller to provide a reasonable destructor. + * The destructors deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensor* self); +} DLManagedTensor; +#ifdef __cplusplus +} // DLPACK_EXTERN_C +#endif +#endif // DLPACK_DLPACK_H_ diff --git a/imperative/python/src/dlpack_convertor.cpp b/imperative/python/src/dlpack_convertor.cpp new file mode 100644 index 000000000..a86c7cd48 --- /dev/null +++ b/imperative/python/src/dlpack_convertor.cpp @@ -0,0 +1,270 @@ +#include "./dlpack_convertor.h" +#include +#include "./helper.h" +#include "megbrain/comp_node_env.h" +#include "megbrain/imperative/basic_operators.h" +#include "megbrain/imperative/dispatch.h" +#include "megbrain/tensor.h" + +using namespace mgb::imperative; +using namespace mgb; + +DLDataType mgb::imperative::get_dl_datatype(const DeviceTensorND& dv) { + DLDataType dtype; + dtype.lanes = 1; + dtype.bits = dv.dtype().size() * 8; + switch (dv.dtype().enumv()) { + case DTypeEnum::Byte: + dtype.code = DLDataTypeCode::kDLUInt; + break; + case DTypeEnum::Uint8: + dtype.code = DLDataTypeCode::kDLUInt; + break; + case DTypeEnum::Uint16: + dtype.code = DLDataTypeCode::kDLUInt; + break; + case DTypeEnum::Int32: + dtype.code = DLDataTypeCode::kDLInt; + break; + case DTypeEnum::Int16: + dtype.code = DLDataTypeCode::kDLInt; + break; + case DTypeEnum::Int8: + dtype.code = DLDataTypeCode::kDLInt; + break; + case DTypeEnum::Float32: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case DTypeEnum::Float16: + dtype.code = DLDataTypeCode::kDLFloat; + break; + case DTypeEnum::Bool: + mgb_throw(MegBrainError, "Bool type is not supported by dlpack"); + break; + case DTypeEnum::BFloat16: + dtype.code = DLDataTypeCode::kDLBfloat; + break; + case DTypeEnum::Complex64: + dtype.code = DLDataTypeCode::kDLComplex; + break; + default: + mgb_throw(MegBrainError, "type is not supported by dlpack"); + } + return dtype; +} + +DLDevice mgb::imperative::get_dl_device(const DeviceTensorND& dv) { + auto cn = dv.comp_node(); + DLDevice ctx; + switch (cn.device_type()) { + case CompNode::DeviceType::CPU: { + ctx.device_id = 0; + ctx.device_type = DLDeviceType::kDLCPU; + break; + } + case CompNode::DeviceType::CUDA: { +#if MGB_CUDA + auto&& env = CompNodeEnv::from_comp_node(cn).cuda_env(); + ctx.device_id = env.device; + ctx.device_type = DLDeviceType::kDLCUDA; +#else + mgb_throw(MegBrainError, "CUDA device is not available"); +#endif + break; + } + default: + mgb_throw( + MegBrainError, "Cannot pack tensors on %s", cn.to_string().c_str()); + } + return ctx; +} + +CompNode as_comp_node(const std::string& name) { + thread_local struct { + std::string name; + CompNode cn; + } dlpack_cncached; + if (dlpack_cncached.name != name) { + dlpack_cncached.name = name; + dlpack_cncached.cn = CompNode::load(name); + } + return dlpack_cncached.cn; +} + +CompNode mgb::imperative::get_tensor_device(const DLDevice& ctx, int stream) { + int id = ctx.device_id; + switch (ctx.device_type) { + case DLDeviceType::kDLCPU: { + auto device = "cpu" + std::to_string(id); + return as_comp_node(device); + } + case DLDeviceType::kDLCUDA: { + auto device = 
"gpu" + std::to_string(id) + ":" + std::to_string(stream); + return as_comp_node(device); + } + default: + mgb_throw(MegBrainError, "Unsupported device_type"); + } +} + +DType mgb::imperative::get_tensor_type(const DLDataType& dtype) { + DType tensortype; + switch (dtype.code) { + case DLDataTypeCode::kDLUInt: + switch (dtype.bits) { + case 8: + tensortype = DType::from_enum(DTypeEnum::Uint8); + break; + case 16: + tensortype = DType::from_enum(DTypeEnum::Uint16); + break; + default: + mgb_throw( + MegBrainError, "Unsupported kUInt bits: %s", + std::to_string(dtype.bits).c_str()); + } + break; + + case DLDataTypeCode::kDLInt: + switch (dtype.bits) { + case 8: + tensortype = DType::from_enum(DTypeEnum::Int8); + break; + case 16: + tensortype = DType::from_enum(DTypeEnum::Int16); + break; + case 32: + tensortype = DType::from_enum(DTypeEnum::Int32); + break; + default: + mgb_throw( + MegBrainError, "Unsupported kInt bits: %s", + std::to_string(dtype.bits).c_str()); + } + break; + case DLDataTypeCode::kDLFloat: + switch (dtype.bits) { + case 16: + tensortype = DType::from_enum(DTypeEnum::Float16); + break; + case 32: + tensortype = DType::from_enum(DTypeEnum::Float32); + break; + default: + mgb_throw( + MegBrainError, "Unsupported kFloat bits: %s", + std::to_string(dtype.bits).c_str()); + } + break; + case DLDataTypeCode::kDLBfloat: + switch (dtype.bits) { + case 16: + tensortype = DType::from_enum(DTypeEnum::BFloat16); + break; + default: + mgb_throw( + MegBrainError, "Unsupported kBFloat bits: %s", + std::to_string(dtype.bits).c_str()); + } + break; + case DLDataTypeCode::kDLComplex: + switch (dtype.bits) { + case 64: + tensortype = DType::from_enum(DTypeEnum::Complex64); + default: + mgb_throw( + MegBrainError, "Unsupported Complex bits: %s", + std::to_string(dtype.bits).c_str()); + } + break; + } + return tensortype; +} + +struct DLMTensor { + DeviceTensorND value; + DLManagedTensor tensor; + int64_t shape[MEGDNN_MAX_NDIM]; + int64_t stride[MEGDNN_MAX_NDIM]; +}; + +void deleter(DLManagedTensor* arg) { + delete static_cast(arg->manager_ctx); +} + +DLManagedTensor* mgb::imperative::to_dlpack(const ValueRef src) { + DeviceTensorND dv = src.dev_tensor()->as_nd(true); + DLMTensor* TensorHandler(new DLMTensor); + size_t ndim = dv.shape().ndim; + TensorHandler->value = dv; + TensorHandler->tensor.manager_ctx = TensorHandler; + TensorHandler->tensor.deleter = &deleter; + TensorHandler->tensor.dl_tensor.data = TensorHandler->value.raw_ptr(); + TensorHandler->tensor.dl_tensor.device = get_dl_device(dv); + TensorHandler->tensor.dl_tensor.ndim = ndim; + TensorHandler->tensor.dl_tensor.dtype = get_dl_datatype(dv); + + auto src_shape = TensorHandler->value.layout().shape; + auto src_stride = TensorHandler->value.layout().stride; + for (size_t i = 0; i < ndim; i++) { + if (src_shape[i] > std::numeric_limits::max()) { + mgb_throw( + MegBrainError, "unsupported input shape: %s", + TensorHandler->value.layout().to_string().c_str()); + } + TensorHandler->shape[i] = static_cast(src_shape[i]); + TensorHandler->stride[i] = static_cast(src_stride[i]); + } + TensorHandler->tensor.dl_tensor.shape = TensorHandler->shape; + TensorHandler->tensor.dl_tensor.strides = TensorHandler->stride; + TensorHandler->tensor.dl_tensor.byte_offset = 0; + return &(TensorHandler->tensor); +} + +TensorShape ptr2shape(const int64_t* ptr, size_t ndim) { + TensorShape shape; + mgb_assert( + ndim <= TensorShape::MAX_NDIM, "dim too large: %zd (max %zd)", ndim, + TensorShape::MAX_NDIM); + shape.ndim = ndim; + for (size_t i = 0; i < ndim; 
i++) { + if (ptr[i] < 0 || ptr[i] > std::numeric_limits::max()) { + std::string error_msg = ""; + for (size_t idx = 0; idx < ndim; idx++) { + auto shape_i = " " + std::to_string(ptr[i]); + error_msg += shape_i; + } + mgb_throw( + MegBrainError, "unsupported dlpack input shape: %s", + error_msg.c_str()); + } + shape[i] = ptr[i]; + } + return shape; +} + +ValueRef mgb::imperative::from_dlpack(DLManagedTensor* dlMTensor, int stream = 0) { + std::function deleter_dispatch = [dlMTensor](void*) { + if (dlMTensor->deleter) { + dlMTensor->deleter(dlMTensor); + } + }; + + DType tensor_type = get_tensor_type(dlMTensor->dl_tensor.dtype); + CompNode tensor_device = get_tensor_device(dlMTensor->dl_tensor.device, stream); + DeviceTensorStorage storage; + size_t dtype_size = tensor_type.size(); + size_t ndim = dlMTensor->dl_tensor.ndim; + TensorShape tensor_shape = ptr2shape(dlMTensor->dl_tensor.shape, ndim); + + storage.reset( + tensor_device, tensor_shape.total_nr_elems() * dtype_size, + {static_cast(dlMTensor->dl_tensor.data), deleter_dispatch}); + + ValueShape shapevalue = ValueShape::from(tensor_shape); + ValueRef val = imperative::apply( + CreateTensor( + CreateTensor::Common, tensor_device, tensor_type, shapevalue, {}), + DeviceStorage::make(storage))[0]; + return val; +}; \ No newline at end of file diff --git a/imperative/python/src/dlpack_convertor.h b/imperative/python/src/dlpack_convertor.h new file mode 100644 index 000000000..54077ea37 --- /dev/null +++ b/imperative/python/src/dlpack_convertor.h @@ -0,0 +1,24 @@ +#include "./dlpack.h" +#include "./tensor.h" + +#include "megbrain/imperative/value.h" +#include "megbrain/tensor.h" + +namespace mgb { +namespace imperative { + +DLManagedTensor* to_dlpack(const ValueRef src); + +DLDevice get_dl_device(const DeviceTensorND& dv); + +DLDataType get_dl_datatype(const DeviceTensorND& dv); + +ValueRef from_dlpack(DLManagedTensor* dlMTensor, int stream); + +CompNode get_tensor_device(const DLDevice& ctx, int stream); + +mgb::DType get_tensor_type(const DLDataType& dtype); + +} // namespace imperative + +} // namespace mgb \ No newline at end of file diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 63de260e3..fa71f3774 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -25,6 +25,8 @@ #include "megdnn/algorithm_cache.h" #include "./common.h" +#include "./dlpack.h" +#include "./dlpack_convertor.h" #include "./grad.h" #include "./graph_rt.h" #include "./helper.h" @@ -741,6 +743,35 @@ PyObject* TensorWrapper::_graph() { return py::cast(graph).release().ptr(); } +void dlpack_capsule_destructor(PyObject* data) { + if (!PyCapsule_IsValid(data, "dltensor")) { + // early out, see DLPack spec: if a consuming library sets the capsule + // name to something else, they own it and we don't need to do anything + return; + } + DLManagedTensor* dlMTensor = + (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + dlMTensor->deleter(const_cast(dlMTensor)); +} + +PyObject* tensor_to_dlpack(PyObject* tensor) { + TensorWrapper* wrapper = TensorWrapper::try_cast(tensor); + DLManagedTensor* dlMTensor = to_dlpack(wrapper->m_tensor->data()); + return PyCapsule_New(dlMTensor, "dltensor", dlpack_capsule_destructor); +} + +PyObject* tensor_from_dlpack(PyObject* data, PyObject* stream) { + DLManagedTensor* dlMTensor = + (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor"); + if (!PyLong_Check(stream)) { + throw py::type_error("expect int"); + } + int sid = PyLong_AsLong(stream); + 
+    PyCapsule_SetName(data, "used_dltensor");
+    auto tensor = from_dlpack(dlMTensor, sid);
+    return TensorWrapper::make(py_tensor_type, std::move(tensor)).release().ptr();
+}
+
 struct TensorWeakRef {
     ValueWeakRef data;
 
@@ -1465,6 +1496,14 @@ void init_tensor(py::module m) {
     m.def("get_auto_format_convert",
           [format_trans]() { return format_trans->get_auto_convert(); });
 
+    m.def("_to_dlpack", [](py::object tensor) {
+        return py::reinterpret_steal<py::object>(tensor_to_dlpack(tensor.ptr()));
+    });
+
+    m.def("_from_dlpack", [](py::object data, py::object stream) {
+        return py::reinterpret_steal<py::object>(
+                tensor_from_dlpack(data.ptr(), stream.ptr()));
+    });
     py::register_exception<TraceError>(m, "TraceError");
 
     m.def("create_complex", [](py::object real, py::object imag) {
-- 
GitLab
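
Usage sketch: a minimal example of how the entry points added above are expected to be driven from Python. `to_dlpack`, `from_dlpack` and `Tensor.__dlpack__` are the APIs introduced by this patch; the final PyTorch line is illustrative interop through the standard DLPack capsule protocol and assumes PyTorch is installed. It is a sketch, not part of the patch itself.

    import numpy as np
    import megengine as mge
    from megengine.utils.dlpack import to_dlpack, from_dlpack

    x = mge.tensor(np.arange(6, dtype="float32").reshape(2, 3))

    # Producer side: wrap the tensor in a "dltensor" capsule. The capsule owns a
    # DLManagedTensor whose deleter drops the reference that keeps the MegEngine
    # storage alive.
    capsule = to_dlpack(x)

    # Consumer side: re-import the capsule as a MegEngine tensor. The import is
    # zero-copy (the storage is wrapped, not copied), and the capsule is renamed
    # to "used_dltensor", so it can be consumed only once.
    y = from_dlpack(capsule)

    # Other frameworks can consume Tensor.__dlpack__ directly, e.g. with PyTorch:
    # z = torch.utils.dlpack.from_dlpack(x.__dlpack__())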