提交 787a22a9 编写于 作者: M Megvii Engine Team

perf(tensor): implement __new__ in cpp

GitOrigin-RevId: 4defd249c3ca673ec67648c4f8aaa1858eb00447
上级 99df4a79
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
import numpy as np import numpy as np
from ._imperative_rt import CompNode from ._imperative_rt import CompNode
from ._imperative_rt.core2 import set_py_device_type
class Device: class Device:
...@@ -53,3 +54,6 @@ def as_device(obj): ...@@ -53,3 +54,6 @@ def as_device(obj):
if isinstance(obj, Device): if isinstance(obj, Device):
return obj return obj
return Device(obj) return Device(obj)
set_py_device_type(Device)
...@@ -72,39 +72,6 @@ class Tensor(_Tensor, ArrayMethodMixin): ...@@ -72,39 +72,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
_short_name = None _short_name = None
_prefix = None _prefix = None
def __new__(
cls,
data: Union["Tensor", np.ndarray, list, int, float] = None,
dtype: np.dtype = None,
device: str = None,
is_const: bool = False,
no_cache: bool = False,
name: str = None,
):
if data is None:
data = []
if device is None:
cn = get_default_device()
elif isinstance(device, str):
if cls.dmap_callback is not None:
cn = CompNode(cls.dmap_callback(device))
else:
cn = CompNode(device)
else:
if isinstance(device, CompNode):
cn = device
else:
cn = device._cn
if isinstance(data, _Tensor):
obj = _Tensor.__new__(cls, data)
else:
if isinstance(data, np.ndarray):
if 0 in data.strides:
data = data.squeeze().reshape(data.shape)
obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
return obj
def __init__( def __init__(
self, self,
data: Union["Tensor", np.ndarray, list, int, float], data: Union["Tensor", np.ndarray, list, int, float],
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "./common.h" #include "./common.h"
#include <pybind11/operators.h> #include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include "./helper.h" #include "./helper.h"
#include "./numpy_dtypes.h" #include "./numpy_dtypes.h"
...@@ -56,6 +57,8 @@ std::string get_default_device() { ...@@ -56,6 +57,8 @@ std::string get_default_device() {
return default_device; return default_device;
} }
py::handle py_comp_node_type;
void init_common(py::module m) { void init_common(py::module m) {
auto PyCompNode = auto PyCompNode =
py::class_<CompNode>(m, "CompNode") py::class_<CompNode>(m, "CompNode")
...@@ -117,6 +120,8 @@ void init_common(py::module m) { ...@@ -117,6 +120,8 @@ void init_common(py::module m) {
}, },
[](py::str cn) { return CompNode::load(cn); })); [](py::str cn) { return CompNode::load(cn); }));
py_comp_node_type = PyCompNode.inc_ref();
py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event") py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event")
.def("record", &CompNode::Event::record) .def("record", &CompNode::Event::record)
.def("wait", &CompNode::Event::host_wait); .def("wait", &CompNode::Event::host_wait);
......
...@@ -17,3 +17,5 @@ void init_common(pybind11::module m); ...@@ -17,3 +17,5 @@ void init_common(pybind11::module m);
void set_default_device(const std::string& device); void set_default_device(const std::string& device);
std::string get_default_device(); std::string get_default_device();
extern pybind11::handle py_comp_node_type;
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include <pybind11/operators.h> #include <pybind11/operators.h>
#include <pybind11/pytypes.h> #include <pybind11/pytypes.h>
#include <pyerrors.h> #include <pyerrors.h>
#include <iterator>
#include <range/v3/all.hpp> #include <range/v3/all.hpp>
#include <string> #include <string>
...@@ -108,6 +109,7 @@ struct SymbolVarContext { ...@@ -108,6 +109,7 @@ struct SymbolVarContext {
interpreter::Interpreter::Channel* interpreter_for_py = nullptr; interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr; PyTypeObject* py_tensor_type = nullptr;
pybind11::handle py_device_type = nullptr;
PyObject* cpp_use_symbolic_shape; PyObject* cpp_use_symbolic_shape;
#define REGISTE_APPLY_FUNC(mode) \ #define REGISTE_APPLY_FUNC(mode) \
...@@ -233,70 +235,410 @@ PyObject* py_apply( ...@@ -233,70 +235,410 @@ PyObject* py_apply(
PYEXT17_TRANSLATE_EXC_RET(nullptr) PYEXT17_TRANSLATE_EXC_RET(nullptr)
} }
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { namespace {
if (kwargs && PyDict_Size(kwargs)) {
throw py::type_error("keyword argument not allowed"); template <typename T>
py::handle py_type() {
if constexpr (std::is_same_v<T, py::int_>) {
return (PyObject*)&PyLong_Type;
} else if constexpr (std::is_same_v<T, py::float_>) {
return (PyObject*)&PyFloat_Type;
} else if constexpr (std::is_same_v<T, py::tuple>) {
return (PyObject*)&PyTuple_Type;
} else if constexpr (std::is_same_v<T, py::list>) {
return (PyObject*)&PyList_Type;
} else {
static_assert(std::is_same_v<T, T>);
} }
auto nargs = PyTuple_Size(args); }
auto tup = py::reinterpret_borrow<py::tuple>(args);
if (nargs == 0) { template <typename T>
throw py::type_error("too few arguments"); auto scalar2storage(T val, CompNode cn, DType dtype) {
using max_ctype_t = DTypeScalar::max_ctype;
DTypeScalar scalar(dtype);
scalar.set_retain_dtype(val);
HostTensorStorage storage(cn);
auto* raw_ptr = reinterpret_cast<dt_byte*>(new max_ctype_t());
std::shared_ptr<dt_byte> raw_storage = {
raw_ptr, [](dt_byte* ptr) { delete reinterpret_cast<max_ctype_t*>(ptr); }};
storage.only_reset_raw_storage(cn, dtype.size(), raw_storage, 0);
std::memcpy(storage.ptr(), scalar.storage(), dtype.size());
return HostStorage::make(std::move(storage));
}
template <typename ctype>
auto vec2storage(Span<DTypeScalar> vec, CompNode cn, DType dtype) {
mgb_assert(vec.size() <= MEGDNN_MAX_NDIM);
// TODO: use storage cache and modify ConstTensorCache to return (Host, Device)
auto* raw_ptr = new ctype[MEGDNN_MAX_NDIM];
for (size_t i = 0; i < vec.size(); ++i) {
raw_ptr[i] = vec[i].get_cast<ctype>();
}
mgb_assert(sizeof(ctype) == dtype.size());
std::shared_ptr<dt_byte> raw_storage = {
reinterpret_cast<dt_byte*>(raw_ptr),
[](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
HostTensorStorage storage(cn);
storage.only_reset_raw_storage(cn, sizeof(ctype) * vec.size(), raw_storage, 0);
return HostStorage::make(std::move(storage));
}
struct HostTensorArgs {
ValueShape shape;
DType dtype;
HostStorage::ref_t storage;
HostTensorND as_tensor_nd() const {
HostTensorND ret(CompNode::default_cpu(), shape.as_tensor_shape(), dtype);
ret.only_reset_raw_storage(*storage);
return ret;
} }
if (auto* t = try_cast(tup[0].ptr())) { };
if (nargs > 1) {
throw py::type_error("expect 1 argument"); template <typename seq_type, typename ctype>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
auto size = obj.size();
if (size > MEGDNN_MAX_NDIM) {
return false;
} }
m_tensor = t->m_tensor->copy(); ctype items[size];
for (size_t i = 0; i < size; ++i) {
py::handle item = obj[i];
if (item.get_type().is(py_type<py::int_>())) {
items[i] = (ctype)(dt_int32)item.template cast<py::int_>();
} else if (item.get_type().is(py_type<py::float_>())) {
items[i] = (ctype)(dt_float32)item.template cast<py::float_>();
} else { } else {
if (nargs == 1) { return false;
auto arg0 = PyTuple_GetItem(args, 0); }
// for DeviceTensorND }
if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) { mgb_assert(sizeof(ctype) == dtype.size());
auto dv = py::handle(arg0).cast<DeviceTensorND>(); auto* raw_ptr = new ctype[size];
m_tensor = std::make_shared<Tensor>(imperative::apply( std::shared_ptr<dt_byte> raw_storage = {
CreateTensor(CreateTensor::Common, dv.comp_node(), dv.layout()), reinterpret_cast<dt_byte*>(raw_ptr),
DeviceStorage::make(dv.storage()))[0]); [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
HostTensorStorage storage(cn);
storage.only_reset_raw_storage(cn, sizeof(ctype) * size, raw_storage, 0);
std::memcpy(storage.ptr(), items, sizeof(ctype) * size);
ret.dtype = dtype;
ret.shape = {size};
ret.storage = HostStorage::make(std::move(storage));
return true;
}
template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, HostTensorArgs& ret) {
auto size = obj.size();
if (size > MEGDNN_MAX_NDIM) {
return false;
}
DTypeScalar items[size];
DType dtype;
for (size_t i = 0; i < size; ++i) {
auto&& item = obj[i];
if (item.get_type().is(py_type<py::int_>())) {
items[i] = (dt_int32)item.template cast<py::int_>();
if (!dtype.valid()) {
dtype = dtype::Int32();
} else if (dtype != dtype::Int32() && dtype != dtype::Float32()) {
return false;
}
} else if (item.get_type().is(py_type<py::float_>())) {
items[i] = (dt_float32)item.template cast<py::float_>();
if (!dtype.valid()) {
dtype = dtype::Float32();
} else if (dtype == dtype::Int32()) {
dtype = dtype::Float32();
} else if (dtype != dtype::Float32()) {
return false;
}
} else {
return false;
}
}
if (!dtype.valid()) {
dtype = dtype::Float32();
}
ret.dtype = dtype;
ret.shape = {size};
if (dtype == dtype::Int32()) {
ret.storage = vec2storage<dt_int32>({items, size}, cn, dtype);
} else if (dtype == dtype::Float32()) {
ret.storage = vec2storage<dt_float32>({items, size}, cn, dtype);
} else {
mgb_assert(false);
}
return true;
}
template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
if (dtype == dtype::Int32()) {
return pyseq2hval<seq_type, dt_int32>(obj, cn, dtype, ret);
} else if (dtype == dtype::Float32()) {
return pyseq2hval<seq_type, dt_float32>(obj, cn, dtype, ret);
} else if (!dtype.valid()) {
return pyseq2hval<seq_type>(obj, cn, ret);
} else {
return false;
}
}
bool pyarr2hval(py::array obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
auto data = obj.cast<py::array>();
auto strides = data.strides();
bool need_squeeze = false;
for (size_t i = 0; i < data.ndim(); ++i) {
if (strides[i] == 0) {
need_squeeze = true;
break;
}
}
if (need_squeeze) {
std::vector<size_t> shape;
for (size_t i = 0; i < data.ndim(); ++i) {
shape.push_back(data.shape(i));
}
data = data.squeeze();
data.resize(shape);
}
HostTensorND retnd(cn);
retnd = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&retnd), dtype);
if (!dtype.valid()) {
dtype = retnd.dtype();
}
mgb_assert(
retnd.layout().is_empty() || retnd.layout().is_contiguous(),
"host value should be continuous");
for (size_t i = 0; i < data.ndim(); ++i) {
ret.shape[ret.shape.ndim++] = data.shape(i);
}
ret.dtype = dtype;
ret.storage = HostStorage::make(retnd.storage());
return true;
}
bool pyint2hval(py::int_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
if (!dtype.valid()) {
dtype = dtype::Int32();
}
ret.dtype = dtype;
ret.storage = scalar2storage((dt_int32)obj, cn, dtype);
return true;
}
bool pyfloat2hval(py::float_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
if (!dtype.valid()) {
dtype = dtype::Float32();
}
ret.dtype = dtype;
ret.storage = scalar2storage((dt_float32)obj, cn, dtype);
return true;
}
HostTensorArgs pyobj2hval(py::object obj, CompNode cn, DType dtype) {
HostTensorArgs ret;
bool success = false;
// check order: float -> int -> tuple(int -> float) -> list(int -> float)
// only handle `exact` pytype, isinstance also accepts subtype
// for example, isinstance(True, int) == True
if (obj.get_type().is(py_type<py::float_>())) {
success = pyfloat2hval(py::float_(obj), cn, dtype, ret);
} else if (obj.get_type().is(py_type<py::int_>())) { // py::bool_ is py::int_
success = pyint2hval(py::int_(obj), cn, dtype, ret);
} else if (obj.get_type().is(py_type<py::tuple>())) {
success = pyseq2hval<py::tuple>(py::tuple(obj), cn, dtype, ret);
} else if (obj.get_type().is(py_type<py::list>())) {
success = pyseq2hval<py::list>(py::list(obj), cn, dtype, ret);
} else if (obj.is_none()) {
obj = py::list(0);
}
if (!success) {
success = pyarr2hval(obj, cn, dtype, ret);
}
mgb_assert(success);
return ret;
}
struct PyArgDesc {
const char* name;
py::object (*default_value)();
};
struct PyArgDescs {
std::vector<PyArgDesc> items;
ssize_t (*name2idx)(const char* name);
};
py::tuple parse_args(py::tuple args, const PyArgDescs& descs) {
size_t nr_args = args.size();
size_t nr_items = descs.items.size();
mgb_assert(nr_args <= nr_items, "too many args");
if (nr_args == nr_items) {
return args;
}
py::tuple ret(nr_items);
for (size_t i = 0; i < nr_args; ++i) {
ret[i] = args[i];
}
for (size_t i = nr_args; i < nr_items; ++i) {
ret[i] = descs.items[i].default_value();
}
return ret;
}
py::tuple parse_args_and_kwargs(
py::tuple args, py::dict kwargs, const PyArgDescs& descs) {
size_t nr_args = args.size();
size_t nr_kwargs = kwargs.size();
size_t nr_items = descs.items.size();
mgb_assert(nr_args + nr_kwargs <= nr_items, "too many args");
if (nr_args == nr_items) {
return args;
}
py::tuple ret(nr_items);
for (size_t i = 0; i < nr_args; ++i) {
ret[i] = args[i];
}
bool has_value[nr_items - nr_args];
for (size_t i = nr_args; i < nr_items; ++i) {
has_value[i - nr_args] = false;
}
for (auto&& [k, v] : kwargs) {
auto key = py::str(k).cast<std::string>();
ssize_t index = descs.name2idx(key.c_str());
mgb_assert(index >= nr_args);
ret[index] = v;
has_value[index - nr_args] = true;
}
for (size_t i = nr_args; i < nr_items; ++i) {
if (!has_value[i - nr_args]) {
ret[i] = descs.items[i].default_value();
}
}
return ret;
}
CompNode as_comp_node(const std::string& name) {
thread_local struct {
std::string name;
CompNode cn;
} cached;
if (cached.name != name) {
cached.name = name;
cached.cn = CompNode::load(name);
}
return cached.cn;
}
CompNode as_comp_node(py::object py_device) {
std::optional<std::string> device_name;
if (py_device.is_none() || py::str::check_(py_device)) {
auto cls = py::handle(reinterpret_cast<PyObject*>(py_tensor_type));
auto dmap_callback = cls.attr("dmap_callback");
std::string name;
if (dmap_callback.is_none() && py_device.is_none()) {
name = get_default_device();
} else { } else {
throw py::type_error( if (py_device.is_none()) {
"single argument is not tensor, varnode or devicetensor"); py_device = py::str(get_default_device());
}
if (!dmap_callback.is_none()) {
py_device = dmap_callback(py_device);
} }
name = py::str(py_device).cast<std::string>();
}
return as_comp_node(name);
} else { } else {
if (py::isinstance(py_device, py_device_type)) {
py_device = py_device.attr("_cn");
}
mgb_assert(py::isinstance(py_device, py_comp_node_type));
return py_device.cast<CompNode>();
}
}
template <char... Chars>
bool compare_cstr(const char* cstr) {
return (((*cstr++) == Chars) && ...) && *cstr == '\0';
}
ssize_t name2idx(const char* name) {
const char* ch = name;
// TODO: trie
// clang-format off
switch (*ch++) {
case 'd':
switch (*ch++) {
// data
case 'a': return compare_cstr<'t', 'a'>(ch) ? 0 : -1;
// dtype
case 't': return compare_cstr<'y', 'p', 'e'>(ch) ? 1 : -1;
// device
case 'e': return compare_cstr<'v', 'i', 'c', 'e'>(ch) ? 2 : -1;
}
case 'i':
// is_const
return compare_cstr<'s', '_', 'c', 'o', 'n', 's', 't'>(ch) ? 3 : -1;
case 'n':
switch (*ch++) {
// no_cache
case 'o': return compare_cstr<'_', 'c', 'a', 'c', 'h', 'e'>(ch) ? 4 : -1;
// name
case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1;
}
}
// clang-format on
return -1;
}
} // namespace
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
static PyArgDescs descs = {
{
{"data", []() -> py::object { return py::none(); }},
{"dtype", []() -> py::object { return py::none(); }},
{"device", []() -> py::object { return py::none(); }},
{"is_const", []() -> py::object { return py::bool_(false); }},
{"no_cache", []() -> py::object { return py::bool_(false); }},
{"name", []() -> py::object { return py::none(); }},
},
name2idx};
py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
if (nargs != 5 && nargs != 6) { auto tup = py::reinterpret_borrow<py::tuple>(args);
throw py::type_error("expect 5 or 6 arguments"); if (kwargs) {
tup = parse_args_and_kwargs(
tup, py::reinterpret_borrow<py::dict>(kwargs), descs);
} else {
tup = parse_args(tup, descs);
} }
auto data = tup[0].cast<py::array>(); mgb_assert(tup.size() == 6);
if (auto* t = try_cast(tup[0].ptr())) {
m_tensor = t->m_tensor->copy();
} else {
auto data = tup[0];
DType dtype = tup[1].cast<DType>(); DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>();
bool is_const = tup[3].cast<bool>(); bool is_const = tup[3].cast<bool>();
bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false; bool no_cache = tup[4].cast<bool>();
std::string name; std::string name;
if (tup[nargs - 1].ptr() != Py_None) if (!tup[5].is_none()) {
name = tup[nargs - 1].cast<std::string>(); name = tup[5].cast<std::string>();
}
CompNode cn = as_comp_node(tup[2]);
// const op
{ {
CreateTensor::Kind kind = is_const ? CreateTensor::Const CreateTensor::Kind kind = is_const ? CreateTensor::Const
: no_cache ? CreateTensor::Unique : no_cache ? CreateTensor::Unique
: CreateTensor::Common; : CreateTensor::Common;
HostTensorND ret(cn); auto&& hval = pyobj2hval(data, cn, dtype);
ret = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype); auto val = imperative::apply(
mgb_assert( CreateTensor(kind, cn, hval.dtype, hval.shape), hval.storage)[0];
ret.layout().is_empty() || ret.layout().is_contiguous(), m_tensor.emplace(val);
"host value should be continuous");
ValueShape shape;
for (size_t i = 0; i < data.ndim(); ++i) {
shape[shape.ndim++] = data.shape(i);
}
m_tensor = std::make_shared<Tensor>(imperative::apply(
CreateTensor(kind, cn, ret.dtype(), shape),
HostStorage::make(ret.storage()))[0]);
} }
if (!name.empty()) { if (!name.empty()) {
m_tensor->reset( m_tensor->reset(imperative::apply(RenameValue(name), m_tensor->data())[0]);
imperative::apply(RenameValue(name), m_tensor->data())[0]);
}
} }
} }
mgb_assert(m_tensor->data()); mgb_assert(m_tensor->data());
...@@ -402,17 +744,16 @@ PyObject* TensorWrapper::isscalar() { ...@@ -402,17 +744,16 @@ PyObject* TensorWrapper::isscalar() {
} }
struct TensorWeakRef { struct TensorWeakRef {
std::weak_ptr<Tensor> wptr; ValueWeakRef data;
TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {} TensorWeakRef(const TensorWrapper& tw) : data(tw.m_tensor->data()) {}
py::object operator()() { py::object operator()() {
if (auto p = wptr.lock()) { if (auto p = data.lock()) {
return TensorWrapper::make(py_tensor_type, p); return TensorWrapper::make(py_tensor_type, p);
} }
return py::none(); return py::none();
} }
int _use_cnt() { return wptr.use_count(); }
}; };
#ifdef METH_FASTCALL #ifdef METH_FASTCALL
...@@ -528,7 +869,6 @@ void init_tensor(py::module m) { ...@@ -528,7 +869,6 @@ void init_tensor(py::module m) {
// TODO: remove this // TODO: remove this
.def<&TensorWrapper::_dev_tensor>("_dev_tensor") .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
.def<&TensorWrapper::_drop>("_drop") .def<&TensorWrapper::_drop>("_drop")
.def<&TensorWrapper::_use_cnt>("_use_cnt")
.def<&TensorWrapper::_detail>("_detail") .def<&TensorWrapper::_detail>("_detail")
.def<&TensorWrapper::_set_name>("_set_name") .def<&TensorWrapper::_set_name>("_set_name")
.def<&TensorWrapper::_watch>("_watch") .def<&TensorWrapper::_watch>("_watch")
...@@ -542,8 +882,7 @@ void init_tensor(py::module m) { ...@@ -542,8 +882,7 @@ void init_tensor(py::module m) {
py::class_<TensorWeakRef>(m, "TensorWeakRef") py::class_<TensorWeakRef>(m, "TensorWeakRef")
.def(py::init<const TensorWrapper&>()) .def(py::init<const TensorWrapper&>())
.def("__call__", &TensorWeakRef::operator()) .def("__call__", &TensorWeakRef::operator());
.def("_use_cnt", &TensorWeakRef::_use_cnt);
py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar") py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
.def_property_readonly( .def_property_readonly(
...@@ -693,6 +1032,9 @@ void init_tensor(py::module m) { ...@@ -693,6 +1032,9 @@ void init_tensor(py::module m) {
py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr()); py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
}); });
m.def("set_py_device_type",
[](py::object type_obj) { py_device_type = type_obj.inc_ref(); });
/** /**
* \brief trace proxy * \brief trace proxy
* *
......
...@@ -38,13 +38,14 @@ namespace mgb::imperative::python { ...@@ -38,13 +38,14 @@ namespace mgb::imperative::python {
extern interpreter::Interpreter::Channel* interpreter_for_py; extern interpreter::Interpreter::Channel* interpreter_for_py;
extern PyTypeObject* py_tensor_type; extern PyTypeObject* py_tensor_type;
extern pybind11::handle py_device_type;
extern PyObject* cpp_use_symbolic_shape; extern PyObject* cpp_use_symbolic_shape;
extern PyObject* cpp_astensor1d; extern PyObject* cpp_astensor1d;
struct Tensor : NonCopyableObj { struct Tensor {
private: private:
std::string m_name;
ValueRef m_data; ValueRef m_data;
std::string m_name;
public: public:
using Handle = interpreter::Interpreter::Handle; using Handle = interpreter::Interpreter::Handle;
...@@ -53,11 +54,7 @@ public: ...@@ -53,11 +54,7 @@ public:
~Tensor() = default; ~Tensor() = default;
inline std::shared_ptr<Tensor> copy() { inline Tensor copy() { return *this; }
auto ret = std::make_shared<Tensor>(m_data);
ret->m_name = m_name;
return ret;
}
inline DType dtype() { return *data().dtype(); } inline DType dtype() { return *data().dtype(); }
inline CompNode comp_node() { return *data().device(); } inline CompNode comp_node() { return *data().device(); }
...@@ -75,7 +72,7 @@ public: ...@@ -75,7 +72,7 @@ public:
set_name(m_name); set_name(m_name);
} }
} }
inline ValueRef data() { return m_data.unwrap(); } inline ValueRef data() const { return m_data.unwrap(); }
bool is_scalar() { return data().is_scalar(); } bool is_scalar() { return data().is_scalar(); }
inline std::string name() { return m_name; } inline std::string name() { return m_name; }
inline void set_name(std::string name) { inline void set_name(std::string name) {
...@@ -89,14 +86,9 @@ public: ...@@ -89,14 +86,9 @@ public:
struct TensorWrapper { struct TensorWrapper {
public: public:
std::shared_ptr<Tensor> m_tensor; std::optional<Tensor> m_tensor;
inline TensorWrapper(std::shared_ptr<Tensor> tensor = {})
: m_tensor(std::move(tensor)) {
mgb_assert(tensor, "empty storage");
}
inline TensorWrapper(ValueRef value) : m_tensor(std::make_shared<Tensor>(value)) {} inline TensorWrapper(ValueRef value) { m_tensor.emplace(value); }
TensorWrapper(PyObject* args, PyObject* kwargs); TensorWrapper(PyObject* args, PyObject* kwargs);
~TensorWrapper() = default; ~TensorWrapper() = default;
...@@ -144,7 +136,6 @@ public: ...@@ -144,7 +136,6 @@ public:
PyObject* module_trace_info(); PyObject* module_trace_info();
void set_module_trace_info(PyObject*); void set_module_trace_info(PyObject*);
void _set_name(PyObject*); void _set_name(PyObject*);
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
PyObject* _detail(); PyObject* _detail();
void _watch(); void _watch();
}; };
......
...@@ -220,3 +220,10 @@ def test_tensor_type(): ...@@ -220,3 +220,10 @@ def test_tensor_type():
y1 = x1 + x2 y1 = x1 + x2
y2 = x2 + x1 y2 = x2 + x1
assert type(y1) == type(y2) assert type(y1) == type(y2)
def test_tensor_from_bool():
x = Tensor(True)
assert x.dtype == np.bool_
x = Tensor([True, False])
assert x.dtype == np.bool_
...@@ -46,7 +46,7 @@ namespace dtype = ::megdnn::dtype; ...@@ -46,7 +46,7 @@ namespace dtype = ::megdnn::dtype;
* \param nr_elem number of elements to write in *dest* * \param nr_elem number of elements to write in *dest*
*/ */
template <typename T> template <typename T>
void static_cast_dtype( MGE_WIN_DECLSPEC_FUC void static_cast_dtype(
T* dest, DType src_type, const void* storage, size_t nr_elem = 1); T* dest, DType src_type, const void* storage, size_t nr_elem = 1);
/*! /*!
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册