From 787a22a9d6802da204910893733bad688292e3a4 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 12 Apr 2022 19:51:45 +0800
Subject: [PATCH] perf(tensor): implement __new__ in cpp

GitOrigin-RevId: 4defd249c3ca673ec67648c4f8aaa1858eb00447
---
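
This commit moves `Tensor.__new__` argument handling from Python into C++: the six
slots `(data, dtype, device, is_const, no_cache, name)` are now filled by
`parse_args`/`parse_args_and_kwargs` in tensor.cpp below, so the user-visible
constructor is meant to behave exactly as before. A minimal sketch of the call forms
being preserved (`"xpux"`, MegEngine's wildcard device string, is just an
illustrative choice):

```python
import numpy as np
from megengine import Tensor

# the six slots parsed in C++: (data, dtype, device, is_const, no_cache, name)
a = Tensor([1, 2, 3])                            # data only, defaults elsewhere
b = Tensor(np.ones((2, 2)), dtype=np.float32)    # dtype as keyword
c = Tensor([1.0, 2.0], device="xpux", name="c")  # device string plus a name
d = Tensor(a)                                    # first slot may be another Tensor
```
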
 imperative/python/megengine/core/_wrap.py    |   4 +
 imperative/python/megengine/tensor.py        |  33 --
 imperative/python/src/common.cpp             |   5 +
 imperative/python/src/common.h               |   4 +-
 imperative/python/src/tensor.cpp             | 468 +++++++++++++++---
 imperative/python/src/tensor.h               |  23 +-
 .../test/unit/core/test_tensor_wrapper.py    |   7 +
 src/core/include/megbrain/dtype.h            |   2 +-
 8 files changed, 432 insertions(+), 114 deletions(-)

diff --git a/imperative/python/megengine/core/_wrap.py b/imperative/python/megengine/core/_wrap.py
index f1650509..c5118b7e 100644
--- a/imperative/python/megengine/core/_wrap.py
+++ b/imperative/python/megengine/core/_wrap.py
@@ -9,6 +9,7 @@
 import numpy as np
 
 from ._imperative_rt import CompNode
+from ._imperative_rt.core2 import set_py_device_type
 
 
 class Device:
@@ -53,3 +54,6 @@ def as_device(obj):
     if isinstance(obj, Device):
         return obj
     return Device(obj)
+
+
+set_py_device_type(Device)
diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py
index 080106df..1a2bf470 100644
--- a/imperative/python/megengine/tensor.py
+++ b/imperative/python/megengine/tensor.py
@@ -72,39 +72,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
     _short_name = None
     _prefix = None
 
-    def __new__(
-        cls,
-        data: Union["Tensor", np.ndarray, list, int, float] = None,
-        dtype: np.dtype = None,
-        device: str = None,
-        is_const: bool = False,
-        no_cache: bool = False,
-        name: str = None,
-    ):
-        if data is None:
-            data = []
-        if device is None:
-            cn = get_default_device()
-        elif isinstance(device, str):
-            if cls.dmap_callback is not None:
-                cn = CompNode(cls.dmap_callback(device))
-            else:
-                cn = CompNode(device)
-        else:
-            if isinstance(device, CompNode):
-                cn = device
-            else:
-                cn = device._cn
-
-        if isinstance(data, _Tensor):
-            obj = _Tensor.__new__(cls, data)
-        else:
-            if isinstance(data, np.ndarray):
-                if 0 in data.strides:
-                    data = data.squeeze().reshape(data.shape)
-            obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
-        return obj
-
     def __init__(
         self,
         data: Union["Tensor", np.ndarray, list, int, float],
diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp
index 90e09aea..cb71c0c8 100644
--- a/imperative/python/src/common.cpp
+++ b/imperative/python/src/common.cpp
@@ -12,6 +12,7 @@
 #include "./common.h"
 
 #include
+#include
 
 #include "./helper.h"
 #include "./numpy_dtypes.h"
@@ -56,6 +57,8 @@ std::string get_default_device() {
     return default_device;
 }
 
+py::handle py_comp_node_type;
+
 void init_common(py::module m) {
     auto PyCompNode =
             py::class_<CompNode>(m, "CompNode")
@@ -117,6 +120,8 @@
                     },
                     [](py::str cn) { return CompNode::load(cn); }));
 
+    py_comp_node_type = PyCompNode.inc_ref();
+
     py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event")
            .def("record", &CompNode::Event::record)
            .def("wait", &CompNode::Event::host_wait);
diff --git a/imperative/python/src/common.h b/imperative/python/src/common.h
index 9f250e1b..ca10f042 100644
--- a/imperative/python/src/common.h
+++ b/imperative/python/src/common.h
@@ -16,4 +16,6 @@ void init_common(pybind11::module m);
 
 void set_default_device(const std::string& device);
 
-std::string get_default_device();
\ No newline at end of file
+std::string get_default_device();
+
+extern pybind11::handle py_comp_node_type;
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index 5d07b7cc..0f219745 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -42,6 +42,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -108,6 +109,7 @@ struct SymbolVarContext {
 
 interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
 PyTypeObject* py_tensor_type = nullptr;
+pybind11::handle py_device_type = nullptr;
 PyObject* cpp_use_symbolic_shape;
 
 #define REGISTE_APPLY_FUNC(mode) \
@@ -233,70 +235,410 @@ PyObject* py_apply(
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }
 
-TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
-    if (kwargs && PyDict_Size(kwargs)) {
-        throw py::type_error("keyword argument not allowed");
+namespace {
+
+template <typename T>
+py::handle py_type() {
+    if constexpr (std::is_same_v<T, py::int_>) {
+        return (PyObject*)&PyLong_Type;
+    } else if constexpr (std::is_same_v<T, py::float_>) {
+        return (PyObject*)&PyFloat_Type;
+    } else if constexpr (std::is_same_v<T, py::tuple>) {
+        return (PyObject*)&PyTuple_Type;
+    } else if constexpr (std::is_same_v<T, py::list>) {
+        return (PyObject*)&PyList_Type;
+    } else {
+        static_assert(std::is_same_v<T, void>);
     }
-    auto nargs = PyTuple_Size(args);
-    auto tup = py::reinterpret_borrow<py::tuple>(args);
-    if (nargs == 0) {
-        throw py::type_error("too few arguments");
+}
+
+template <typename T>
+auto scalar2storage(T val, CompNode cn, DType dtype) {
+    using max_ctype_t = DTypeScalar::max_ctype;
+    DTypeScalar scalar(dtype);
+    scalar.set_retain_dtype(val);
+    HostTensorStorage storage(cn);
+    auto* raw_ptr = reinterpret_cast<dt_byte*>(new max_ctype_t());
+    std::shared_ptr<dt_byte> raw_storage = {
+            raw_ptr, [](dt_byte* ptr) { delete reinterpret_cast<max_ctype_t*>(ptr); }};
+    storage.only_reset_raw_storage(cn, dtype.size(), raw_storage, 0);
+    std::memcpy(storage.ptr(), scalar.storage(), dtype.size());
+    return HostStorage::make(std::move(storage));
+}
+
+template <typename ctype>
+auto vec2storage(Span<DTypeScalar> vec, CompNode cn, DType dtype) {
+    mgb_assert(vec.size() <= MEGDNN_MAX_NDIM);
+    // TODO: use storage cache and modify ConstTensorCache to return (Host, Device)
+    auto* raw_ptr = new ctype[MEGDNN_MAX_NDIM];
+    for (size_t i = 0; i < vec.size(); ++i) {
+        raw_ptr[i] = vec[i].get_cast<ctype>();
     }
-    if (auto* t = try_cast(tup[0].ptr())) {
-        if (nargs > 1) {
-            throw py::type_error("expect 1 argument");
+    }
+    mgb_assert(sizeof(ctype) == dtype.size());
+    std::shared_ptr<dt_byte> raw_storage = {
+            reinterpret_cast<dt_byte*>(raw_ptr),
+            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
+    HostTensorStorage storage(cn);
+    storage.only_reset_raw_storage(cn, sizeof(ctype) * vec.size(), raw_storage, 0);
+    return HostStorage::make(std::move(storage));
+}
+
+struct HostTensorArgs {
+    ValueShape shape;
+    DType dtype;
+    HostStorage::ref_t storage;
+
+    HostTensorND as_tensor_nd() const {
+        HostTensorND ret(CompNode::default_cpu(), shape.as_tensor_shape(), dtype);
+        ret.only_reset_raw_storage(*storage);
+        return ret;
+    }
+};
+
+template <typename seq_type, typename ctype>
+bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
+    auto size = obj.size();
+    if (size > MEGDNN_MAX_NDIM) {
+        return false;
+    }
+    ctype items[size];
+    for (size_t i = 0; i < size; ++i) {
+        py::handle item = obj[i];
+        if (item.get_type().is(py_type<py::int_>())) {
+            items[i] = (ctype)(dt_int32)item.template cast<py::int_>();
+        } else if (item.get_type().is(py_type<py::float_>())) {
+            items[i] = (ctype)(dt_float32)item.template cast<py::float_>();
+        } else {
+            return false;
         }
-        m_tensor = t->m_tensor->copy();
-    } else {
-        if (nargs == 1) {
-            auto arg0 = PyTuple_GetItem(args, 0);
-            // for DeviceTensorND
-            if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
-                auto dv = py::handle(arg0).cast<DeviceTensorND>();
-                m_tensor = std::make_shared<Tensor>(imperative::apply(
-                        CreateTensor(CreateTensor::Common, dv.comp_node(), dv.layout()),
-                        DeviceStorage::make(dv.storage()))[0]);
-            } else {
-                throw py::type_error(
-                        "single argument is not tensor, varnode or devicetensor");
+    }
+    mgb_assert(sizeof(ctype) == dtype.size());
+    auto* raw_ptr = new ctype[size];
+    std::shared_ptr<dt_byte> raw_storage = {
+            reinterpret_cast<dt_byte*>(raw_ptr),
+            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
+    HostTensorStorage storage(cn);
+    storage.only_reset_raw_storage(cn, sizeof(ctype) * size, raw_storage, 0);
+    std::memcpy(storage.ptr(), items, sizeof(ctype) * size);
+    ret.dtype = dtype;
+    ret.shape = {size};
+    ret.storage = HostStorage::make(std::move(storage));
+    return true;
+}
+
+template <typename seq_type>
+bool pyseq2hval(seq_type obj, CompNode cn, HostTensorArgs& ret) {
+    auto size = obj.size();
+    if (size > MEGDNN_MAX_NDIM) {
+        return false;
+    }
+    DTypeScalar items[size];
+    DType dtype;
+    for (size_t i = 0; i < size; ++i) {
+        auto&& item = obj[i];
+        if (item.get_type().is(py_type<py::int_>())) {
+            items[i] = (dt_int32)item.template cast<py::int_>();
+            if (!dtype.valid()) {
+                dtype = dtype::Int32();
+            } else if (dtype != dtype::Int32() && dtype != dtype::Float32()) {
+                return false;
+            }
+        } else if (item.get_type().is(py_type<py::float_>())) {
+            items[i] = (dt_float32)item.template cast<py::float_>();
+            if (!dtype.valid()) {
+                dtype = dtype::Float32();
+            } else if (dtype == dtype::Int32()) {
+                dtype = dtype::Float32();
+            } else if (dtype != dtype::Float32()) {
+                return false;
+            }
+        } else {
+            return false;
+        }
+    }
+    if (!dtype.valid()) {
+        dtype = dtype::Float32();
+    }
+    ret.dtype = dtype;
+    ret.shape = {size};
+    if (dtype == dtype::Int32()) {
+        ret.storage = vec2storage<dt_int32>({items, size}, cn, dtype);
+    } else if (dtype == dtype::Float32()) {
+        ret.storage = vec2storage<dt_float32>({items, size}, cn, dtype);
+    } else {
+        mgb_assert(false);
+    }
+    return true;
+}
+
+template <typename seq_type>
+bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
+    if (dtype == dtype::Int32()) {
+        return pyseq2hval<seq_type, dt_int32>(obj, cn, dtype, ret);
+    } else if (dtype == dtype::Float32()) {
+        return pyseq2hval<seq_type, dt_float32>(obj, cn, dtype, ret);
+    } else if (!dtype.valid()) {
+        return pyseq2hval<seq_type>(obj, cn, ret);
+    } else {
+        return false;
+    }
+}
+
+bool pyarr2hval(py::array obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
+    auto data = obj.cast<py::array>();
+    auto strides = data.strides();
+    bool need_squeeze = false;
+    for (size_t i = 0; i < data.ndim(); ++i) {
+        if (strides[i] == 0) {
+            need_squeeze = true;
+            break;
+        }
+    }
+    if (need_squeeze) {
+        std::vector<ssize_t> shape;
+        for (size_t i = 0; i < data.ndim(); ++i) {
+            shape.push_back(data.shape(i));
+        }
+        data = data.squeeze();
+        data.resize(shape);
+    }
+    HostTensorND retnd(cn);
+    retnd = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&retnd), dtype);
+    if (!dtype.valid()) {
+        dtype = retnd.dtype();
+    }
+    mgb_assert(
+            retnd.layout().is_empty() || retnd.layout().is_contiguous(),
+            "host value should be continuous");
+    for (size_t i = 0; i < data.ndim(); ++i) {
+        ret.shape[ret.shape.ndim++] = data.shape(i);
+    }
+    ret.dtype = dtype;
+    ret.storage = HostStorage::make(retnd.storage());
+    return true;
+}
+
+bool pyint2hval(py::int_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
+    if (!dtype.valid()) {
+        dtype = dtype::Int32();
+    }
+    ret.dtype = dtype;
+    ret.storage = scalar2storage((dt_int32)obj, cn, dtype);
+    return true;
+}
+
+bool pyfloat2hval(py::float_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
+    if (!dtype.valid()) {
+        dtype = dtype::Float32();
+    }
+    ret.dtype = dtype;
+    ret.storage = scalar2storage((dt_float32)obj, cn, dtype);
+    return true;
+}
+
+HostTensorArgs pyobj2hval(py::object obj, CompNode cn, DType dtype) {
+    HostTensorArgs ret;
+    bool success = false;
+    // check order: float -> int -> tuple(int -> float) -> list(int -> float)
+    // only handle `exact` pytype, isinstance also accepts subtype
+    // for example, isinstance(True, int) == True
+    if (obj.get_type().is(py_type<py::float_>())) {
+        success = pyfloat2hval(py::float_(obj), cn, dtype, ret);
+    } else if (obj.get_type().is(py_type<py::int_>())) {  // py::bool_ is py::int_
+        success = pyint2hval(py::int_(obj), cn, dtype, ret);
+    } else if (obj.get_type().is(py_type<py::tuple>())) {
+        success = pyseq2hval(py::tuple(obj), cn, dtype, ret);
+    } else if (obj.get_type().is(py_type<py::list>())) {
+        success = pyseq2hval(py::list(obj), cn, dtype, ret);
+    } else if (obj.is_none()) {
+        obj = py::list(0);
+    }
+    if (!success) {
+        success = pyarr2hval(obj, cn, dtype, ret);
+    }
+    mgb_assert(success);
+    return ret;
+}
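
Taken together, `pyobj2hval` and the helpers above pin down the implicit dtype rules
for plain Python inputs. A sketch of the intended behaviour, inferred from the
branches above (illustrative, not an exhaustive spec):

```python
import numpy as np
from megengine import Tensor

assert Tensor(1).dtype == np.int32           # exact int scalar -> Int32
assert Tensor(1.0).dtype == np.float32       # exact float scalar -> Float32
assert Tensor([1, 2]).dtype == np.int32      # all-int sequence stays Int32
assert Tensor([1, 2.0]).dtype == np.float32  # mixed sequence promotes to Float32
assert Tensor(()).dtype == np.float32        # empty sequence defaults to Float32
```
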
+
+struct PyArgDesc {
+    const char* name;
+    py::object (*default_value)();
+};
+
+struct PyArgDescs {
+    std::vector<PyArgDesc> items;
+    ssize_t (*name2idx)(const char* name);
+};
+
+py::tuple parse_args(py::tuple args, const PyArgDescs& descs) {
+    size_t nr_args = args.size();
+    size_t nr_items = descs.items.size();
+    mgb_assert(nr_args <= nr_items, "too many args");
+    if (nr_args == nr_items) {
+        return args;
+    }
+    py::tuple ret(nr_items);
+    for (size_t i = 0; i < nr_args; ++i) {
+        ret[i] = args[i];
+    }
+    for (size_t i = nr_args; i < nr_items; ++i) {
+        ret[i] = descs.items[i].default_value();
+    }
+    return ret;
+}
+
+py::tuple parse_args_and_kwargs(
+        py::tuple args, py::dict kwargs, const PyArgDescs& descs) {
+    size_t nr_args = args.size();
+    size_t nr_kwargs = kwargs.size();
+    size_t nr_items = descs.items.size();
+    mgb_assert(nr_args + nr_kwargs <= nr_items, "too many args");
+    if (nr_args == nr_items) {
+        return args;
+    }
+    py::tuple ret(nr_items);
+    for (size_t i = 0; i < nr_args; ++i) {
+        ret[i] = args[i];
+    }
+    bool has_value[nr_items - nr_args];
+    for (size_t i = nr_args; i < nr_items; ++i) {
+        has_value[i - nr_args] = false;
+    }
+    for (auto&& [k, v] : kwargs) {
+        auto key = py::str(k).cast<std::string>();
+        ssize_t index = descs.name2idx(key.c_str());
+        mgb_assert(index >= nr_args);
+        ret[index] = v;
+        has_value[index - nr_args] = true;
+    }
+    for (size_t i = nr_args; i < nr_items; ++i) {
+        if (!has_value[i - nr_args]) {
+            ret[i] = descs.items[i].default_value();
+        }
+    }
+    return ret;
+}
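
The two functions above re-implement, in C++, the usual positional/keyword slot
filling. A simplified Python model of the same logic (illustrative only; `_DESCS`
mirrors the `PyArgDescs` table defined further down, unknown keywords raise, and a
keyword that repeats a positional slot trips the assert, matching the C++
`mgb_assert`):

```python
_DESCS = [  # (name, default factory), same order as the C++ PyArgDescs
    ("data", lambda: None), ("dtype", lambda: None), ("device", lambda: None),
    ("is_const", lambda: False), ("no_cache", lambda: False), ("name", lambda: None),
]

def parse_args_and_kwargs(args, kwargs):
    assert len(args) + len(kwargs) <= len(_DESCS), "too many args"
    slots = list(args) + [None] * (len(_DESCS) - len(args))
    filled = [True] * len(args) + [False] * (len(_DESCS) - len(args))
    for key, value in kwargs.items():
        index = next(i for i, (n, _) in enumerate(_DESCS) if n == key)  # name2idx
        assert index >= len(args), "keyword repeats a positional argument"
        slots[index], filled[index] = value, True
    for i, (_, default) in enumerate(_DESCS):
        if not filled[i]:
            slots[i] = default()  # defaults are built lazily, as in the C++ version
    return tuple(slots)

assert parse_args_and_kwargs(([1, 2],), {"device": "cpu0"}) == (
    [1, 2], None, "cpu0", False, False, None)
```
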
+
+CompNode as_comp_node(const std::string& name) {
+    thread_local struct {
+        std::string name;
+        CompNode cn;
+    } cached;
+    if (cached.name != name) {
+        cached.name = name;
+        cached.cn = CompNode::load(name);
+    }
+    return cached.cn;
+}
+
+CompNode as_comp_node(py::object py_device) {
+    std::optional<std::string> device_name;
+    if (py_device.is_none() || py::str::check_(py_device)) {
+        auto cls = py::handle(reinterpret_cast<PyObject*>(py_tensor_type));
+        auto dmap_callback = cls.attr("dmap_callback");
+        std::string name;
+        if (dmap_callback.is_none() && py_device.is_none()) {
+            name = get_default_device();
         } else {
-            py::detail::loader_life_support life_sup;  // FIXME!!!required to cast DType
-            if (nargs != 5 && nargs != 6) {
-                throw py::type_error("expect 5 or 6 arguments");
+            if (py_device.is_none()) {
+                py_device = py::str(get_default_device());
             }
-            auto data = tup[0].cast<py::array>();
-            DType dtype = tup[1].cast<DType>();
-            CompNode cn = tup[2].cast<CompNode>();
-            bool is_const = tup[3].cast<bool>();
-            bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
-            std::string name;
-            if (tup[nargs - 1].ptr() != Py_None)
-                name = tup[nargs - 1].cast<std::string>();
-
-            // const op
-            {
-                CreateTensor::Kind kind = is_const ? CreateTensor::Const
-                                         : no_cache ? CreateTensor::Unique
-                                                    : CreateTensor::Common;
-                HostTensorND ret(cn);
-                ret = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype);
-                mgb_assert(
-                        ret.layout().is_empty() || ret.layout().is_contiguous(),
-                        "host value should be continuous");
-                ValueShape shape;
-                for (size_t i = 0; i < data.ndim(); ++i) {
-                    shape[shape.ndim++] = data.shape(i);
-                }
-                m_tensor = std::make_shared<Tensor>(imperative::apply(
-                        CreateTensor(kind, cn, ret.dtype(), shape),
-                        HostStorage::make(ret.storage()))[0]);
+            if (!dmap_callback.is_none()) {
+                py_device = dmap_callback(py_device);
             }
+            name = py::str(py_device).cast<std::string>();
+        }
+        return as_comp_node(name);
+    } else {
+        if (py::isinstance(py_device, py_device_type)) {
+            py_device = py_device.attr("_cn");
         }
+        mgb_assert(py::isinstance(py_device, py_comp_node_type));
+        return py_device.cast<CompNode>();
+    }
+}
 
-            if (!name.empty()) {
-                m_tensor->reset(
-                        imperative::apply(RenameValue(name), m_tensor->data())[0]);
-            }
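
`as_comp_node` reproduces the old Python device logic: strings (and `None`) are
routed through the class-level `dmap_callback` hook when one is installed, while
`Device`/`CompNode` objects are unwrapped directly. A hedged usage sketch of the
hook; the `gpu -> xpu` rewrite is an invented mapping, not something this patch
ships:

```python
from megengine import Tensor

# install a device-name rewriter on the class; the C++ side fetches it via
# cls.attr("dmap_callback") on each construction
Tensor.dmap_callback = lambda device: device.replace("gpu", "xpu")

x = Tensor([1, 2], device="gpu0")  # the callback sees "gpu0" and returns "xpu0"
y = Tensor([3, 4])                 # device=None: the default device string is
                                   # also passed through the callback
```
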
+
+template <char... Chars>
+bool compare_cstr(const char* cstr) {
+    return (((*cstr++) == Chars) && ...) && *cstr == '\0';
+}
+
+ssize_t name2idx(const char* name) {
+    const char* ch = name;
+    // TODO: trie
+    // clang-format off
+    switch (*ch++) {
+        case 'd':
+            switch (*ch++) {
+                // data
+                case 'a': return compare_cstr<'t', 'a'>(ch) ? 0 : -1;
+                // dtype
+                case 't': return compare_cstr<'y', 'p', 'e'>(ch) ? 1 : -1;
+                // device
+                case 'e': return compare_cstr<'v', 'i', 'c', 'e'>(ch) ? 2 : -1;
+            }
+        case 'i':
+            // is_const
+            return compare_cstr<'s', '_', 'c', 'o', 'n', 's', 't'>(ch) ? 3 : -1;
+        case 'n':
+            switch (*ch++) {
+                // no_cache
+                case 'o': return compare_cstr<'_', 'c', 'a', 'c', 'h', 'e'>(ch) ? 4 : -1;
+                // name
+                case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1;
+            }
+    }
+    // clang-format on
+    return -1;
+}
+
+}  // namespace
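
`name2idx` is a hand-unrolled keyword lookup: `compare_cstr` expands at compile time
into a chain of character comparisons, and the outer switch dispatches on the first
one or two characters. Conceptually it is nothing more than this table (a Python
restatement; the indices match `tup[0]`…`tup[5]` in the constructor below):

```python
# conceptual equivalent of the C++ name2idx switch; -1 means "unknown keyword"
KEYWORD_SLOT = {"data": 0, "dtype": 1, "device": 2,
                "is_const": 3, "no_cache": 4, "name": 5}

def name2idx(name: str) -> int:
    return KEYWORD_SLOT.get(name, -1)
```
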
+
+TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
+    static PyArgDescs descs = {
+            {
+                    {"data", []() -> py::object { return py::none(); }},
+                    {"dtype", []() -> py::object { return py::none(); }},
+                    {"device", []() -> py::object { return py::none(); }},
+                    {"is_const", []() -> py::object { return py::bool_(false); }},
+                    {"no_cache", []() -> py::object { return py::bool_(false); }},
+                    {"name", []() -> py::object { return py::none(); }},
+            },
+            name2idx};
+    py::detail::loader_life_support life_sup;  // FIXME!!!required to cast DType
+    auto tup = py::reinterpret_borrow<py::tuple>(args);
+    if (kwargs) {
+        tup = parse_args_and_kwargs(
+                tup, py::reinterpret_borrow<py::dict>(kwargs), descs);
+    } else {
+        tup = parse_args(tup, descs);
+    }
+    mgb_assert(tup.size() == 6);
+    if (auto* t = try_cast(tup[0].ptr())) {
+        m_tensor = t->m_tensor->copy();
+    } else {
+        auto data = tup[0];
+        DType dtype = tup[1].cast<DType>();
+        bool is_const = tup[3].cast<bool>();
+        bool no_cache = tup[4].cast<bool>();
+        std::string name;
+        if (!tup[5].is_none()) {
+            name = tup[5].cast<std::string>();
+        }
+        CompNode cn = as_comp_node(tup[2]);
+
+        {
+            CreateTensor::Kind kind = is_const ? CreateTensor::Const
+                                    : no_cache ? CreateTensor::Unique
+                                               : CreateTensor::Common;
+            auto&& hval = pyobj2hval(data, cn, dtype);
+            auto val = imperative::apply(
+                    CreateTensor(kind, cn, hval.dtype, hval.shape), hval.storage)[0];
+            m_tensor.emplace(val);
+        }
+
+        if (!name.empty()) {
+            m_tensor->reset(imperative::apply(RenameValue(name), m_tensor->data())[0]);
+        }
     }
     mgb_assert(m_tensor->data());
 }
@@ -402,17 +744,16 @@ PyObject* TensorWrapper::isscalar() {
 }
 
 struct TensorWeakRef {
-    std::weak_ptr<Tensor> wptr;
+    ValueWeakRef data;
 
-    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}
+    TensorWeakRef(const TensorWrapper& tw) : data(tw.m_tensor->data()) {}
 
     py::object operator()() {
-        if (auto p = wptr.lock()) {
+        if (auto p = data.lock()) {
             return TensorWrapper::make(py_tensor_type, p);
         }
         return py::none();
     }
-    int _use_cnt() { return wptr.use_count(); }
 };
 
 #ifdef METH_FASTCALL
@@ -528,7 +869,6 @@ void init_tensor(py::module m) {
             // TODO: remove this
             .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
             .def<&TensorWrapper::_drop>("_drop")
-            .def<&TensorWrapper::_use_cnt>("_use_cnt")
             .def<&TensorWrapper::_detail>("_detail")
             .def<&TensorWrapper::_set_name>("_set_name")
             .def<&TensorWrapper::_watch>("_watch")
@@ -542,8 +882,7 @@ void init_tensor(py::module m) {
 
     py::class_<TensorWeakRef>(m, "TensorWeakRef")
             .def(py::init<const TensorWrapper&>())
-            .def("__call__", &TensorWeakRef::operator())
-            .def("_use_cnt", &TensorWeakRef::_use_cnt);
+            .def("__call__", &TensorWeakRef::operator());
 
     py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
             .def_property_readonly(
@@ -693,6 +1032,9 @@ void init_tensor(py::module m) {
         py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
     });
 
+    m.def("set_py_device_type",
+          [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });
+
     /**
      * \brief trace proxy
      *
diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h
index 33b616ea..6b920c50 100644
--- a/imperative/python/src/tensor.h
+++ b/imperative/python/src/tensor.h
@@ -38,13 +38,14 @@ namespace mgb::imperative::python {
 
 extern interpreter::Interpreter::Channel* interpreter_for_py;
 extern PyTypeObject* py_tensor_type;
+extern pybind11::handle py_device_type;
 extern PyObject* cpp_use_symbolic_shape;
 extern PyObject* cpp_astensor1d;
 
-struct Tensor : NonCopyableObj {
+struct Tensor {
 private:
-    std::string m_name;
     ValueRef m_data;
+    std::string m_name;
 
 public:
     using Handle = interpreter::Interpreter::Handle;
@@ -53,11 +54,7 @@ public:
 
     ~Tensor() = default;
 
-    inline std::shared_ptr<Tensor> copy() {
-        auto ret = std::make_shared<Tensor>(m_data);
-        ret->m_name = m_name;
-        return ret;
-    }
+    inline Tensor copy() { return *this; }
 
     inline DType dtype() { return *data().dtype(); }
     inline CompNode comp_node() { return *data().device(); }
@@ -75,7 +72,7 @@ public:
             set_name(m_name);
         }
     }
-    inline ValueRef data() { return m_data.unwrap(); }
+    inline ValueRef data() const { return m_data.unwrap(); }
     bool is_scalar() { return data().is_scalar(); }
     inline std::string name() { return m_name; }
     inline void set_name(std::string name) {
@@ -89,14 +86,9 @@ public:
 
 struct TensorWrapper {
 public:
-    std::shared_ptr<Tensor> m_tensor;
-
-    inline TensorWrapper(std::shared_ptr<Tensor> tensor = {})
-            : m_tensor(std::move(tensor)) {
-        mgb_assert(tensor, "empty storage");
-    }
+    std::optional<Tensor> m_tensor;
 
-    inline TensorWrapper(ValueRef value) : m_tensor(std::make_shared<Tensor>(value)) {}
+    inline TensorWrapper(ValueRef value) { m_tensor.emplace(value); }
 
     TensorWrapper(PyObject* args, PyObject* kwargs);
     ~TensorWrapper() = default;
@@ -144,7 +136,6 @@ public:
     PyObject* module_trace_info();
     void set_module_trace_info(PyObject*);
    void _set_name(PyObject*);
-    PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
     PyObject* _detail();
     void _watch();
 };
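
With `Tensor` now copyable and `TensorWrapper` holding a `std::optional<Tensor>`,
there is no `shared_ptr` use count left to report, which is why `_use_cnt`
disappears from both `TensorWrapper` and `TensorWeakRef`. The weak ref keeps its
call semantics; expected behaviour, sketched under the assumption that no other
reference keeps the value alive:

```python
from megengine import Tensor
from megengine.core._imperative_rt.core2 import TensorWeakRef

x = Tensor([1, 2])
ref = TensorWeakRef(x)
assert ref() is not None  # value alive: calling the ref yields a tensor again
del x
assert ref() is None      # value released: the weak ref now returns None
```
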
diff --git a/imperative/python/test/unit/core/test_tensor_wrapper.py b/imperative/python/test/unit/core/test_tensor_wrapper.py
index ebdbf8ed..b984d536 100644
--- a/imperative/python/test/unit/core/test_tensor_wrapper.py
+++ b/imperative/python/test/unit/core/test_tensor_wrapper.py
@@ -220,3 +220,10 @@ def test_tensor_type():
     y1 = x1 + x2
     y2 = x2 + x1
     assert type(y1) == type(y2)
+
+
+def test_tensor_from_bool():
+    x = Tensor(True)
+    assert x.dtype == np.bool_
+    x = Tensor([True, False])
+    assert x.dtype == np.bool_
diff --git a/src/core/include/megbrain/dtype.h b/src/core/include/megbrain/dtype.h
index b4f24e30..0fdfc15f 100644
--- a/src/core/include/megbrain/dtype.h
+++ b/src/core/include/megbrain/dtype.h
@@ -46,7 +46,7 @@ namespace dtype = ::megdnn::dtype;
  * \param nr_elem number of elements to write in *dest*
  */
 template <typename T>
-void static_cast_dtype(
+MGE_WIN_DECLSPEC_FUC void static_cast_dtype(
         T* dest, DType src_type, const void* storage, size_t nr_elem = 1);
 
 /*!
-- 
GitLab
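
A closing remark on the new test: `bool` is a subclass of `int` in Python, so an
`isinstance`-based fast path would have silently coerced `True` to Int32. The
exact-type checks in `pyobj2hval` route bools through the numpy fallback instead,
which preserves `np.bool_`:

```python
import numpy as np
from megengine import Tensor

assert isinstance(True, int) and type(True) is not int  # why exact checks matter
assert Tensor(True).dtype == np.bool_  # numpy fallback keeps the bool dtype
assert Tensor(1).dtype == np.int32     # exact int still takes the scalar fast path
```
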