diff --git a/imperative/python/megengine/core/ops/special.py b/imperative/python/megengine/core/ops/special.py index e378b8f075dc58249928d1698a1cfae019d38a60..4b2de494bdfaaa3c3871388534f550f024b45f04 100644 --- a/imperative/python/megengine/core/ops/special.py +++ b/imperative/python/megengine/core/ops/special.py @@ -8,7 +8,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np -from .._imperative_rt.core2 import Tensor +# from .._imperative_rt.core2 import Tensor from ..tensor.core import OpBase, TensorBase, apply @@ -19,5 +19,10 @@ class Const: self.device = device def __call__(self, *reference): - Wrapper = type(reference[0]) - return (Wrapper(self.value, self.dtype, self.device, True),) + from ...tensor import Tensor + + device = self.device + if device is None: + device = reference[0].device + + return (Tensor(self.value, self.dtype, self.device, True),) diff --git a/imperative/python/megengine/core/tensor/dtype.py b/imperative/python/megengine/core/tensor/dtype.py index 89a84a5a0954c6d5bcd5d27f72f9d8dfe717a5e4..0fbcab9ff2e042eb56ecb8a4aa794d605a76b6f0 100644 --- a/imperative/python/megengine/core/tensor/dtype.py +++ b/imperative/python/megengine/core/tensor/dtype.py @@ -13,6 +13,12 @@ import numpy as np # normal dtype related from .._imperative_rt import bfloat16, intb1, intb2, intb4 +from .._imperative_rt.common import ( + get_scale, + get_zero_point, + is_dtype_equal, + is_quantize, +) def is_lowbit(dtype): @@ -42,41 +48,6 @@ _metadata_dict = { } -def is_quantize(dtype): - return ( - hasattr(dtype, "metadata") - and dtype.metadata is not None - and "mgb_dtype" in dtype.metadata - ) - - -def get_scale(dtype): - assert is_quantize(dtype) - return dtype.metadata["mgb_dtype"]["scale"] - - -def get_zero_point(dtype): - assert is_quantize(dtype) - metadata = dtype.metadata["mgb_dtype"] - assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") - return metadata["zero_point"] - - -def is_equal(dt0, dt1): - def _get_zero_point(dtype): - assert is_quantize(dtype) - metadata = dtype.metadata["mgb_dtype"] - return metadata.get("zero_point") - - if is_quantize(dt0) and is_quantize(dt1): - return get_scale(dt0) == get_scale(dt1) and _get_zero_point( - dt0 - ) == _get_zero_point(dt1) - if not (is_quantize(dt0) or is_quantize(dt1)): - return dt0 == dt1 - return False - - def _check_zero_point(zp: int, dtype_str: str): qmin = _metadata_dict[dtype_str].qmin qmax = _metadata_dict[dtype_str].qmax diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py index 8e952048ff9f7b09e9ddb40c56598ce286ce8005..8bc4a75f71ef9197661072c6f15052cb7bff6f2b 100644 --- a/imperative/python/megengine/core/tensor/indexing.py +++ b/imperative/python/megengine/core/tensor/indexing.py @@ -151,9 +151,9 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): def get_index(i): if not isinstance(i, (Tensor)): if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: - (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) + (i,) = Const(i, dtype=np.bool_, device=inp.device)() else: - (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) + (i,) = Const(i, dtype=np.int32, device=inp.device)() return i assert isinstance(i, Tensor) if i.dtype != np.bool_: @@ -197,7 +197,7 @@ def try_condtake(tensor, index): ): return [] if isinstance(index, np.ndarray): - (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) + (index,) = Const(index, dtype=np.bool_, device=tensor.device)() assert isinstance(index, Tensor) if not isinstance(tensor, Tensor): raise TypeError("input must be a tensor") @@ -217,9 +217,7 @@ def getitem(tensor, index): if isinstance(v.shape, v.__class__): break if len(v.shape) > 0 and v.shape[0] == 0: - (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( - tensor - ) + (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() return empty_tensor if use_subtensor: op = builtin.Subtensor(items=items) @@ -240,8 +238,7 @@ def setitem(tensor, index, value): return tensor tensor = tensor.reshape(-1) if not isinstance(value, Tensor): - op = Const(value, dtype=tensor.dtype, device=tensor.device) - (value,) = op(tensor) + (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) for v in tensors: if len(v.shape) > 0 and v.shape[0] == 0: diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 52ef77318e559b7c5f75d3fceff3d42b827b1e6e..8ff86f00980a0bdeabd15f125c7f7b37d41352a1 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -11,10 +11,10 @@ from typing import Iterable, Union import numpy as np -from .._imperative_rt.core2 import Tensor, apply +from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device from ..ops import builtin from ..ops.special import Const -from .dtype import is_equal, is_quantize +from .dtype import is_dtype_equal, is_quantize from .megbrain_graph import VarNode _enable_convert_inputs = True @@ -37,94 +37,12 @@ def set_convert_inputs(flag): return backup -def dtype_promotion(inputs): - """ - Returns the dtype that would result from performing an arithmetic - operation on the provided input tensors and scalars. - """ - # map numpy.dtype.kind to priority - category_priority = { - "f": 3, # floating-point - "i": 2, # signed integer - "u": 2, # unsigned integer - "b": 1, # boolean - } - - def scalar2dtype(x): - """ - For scalar `x`, returns its corresponding type. A floating point scalar - has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'. - A boolean scalar has dtype 'bool'. - """ - if isinstance(x, bool): - return np.bool_ - if isinstance(x, int): - return np.int32 - if isinstance(x, float): - return np.float32 - - def promote_types(types, cat): - """ - Returns the data type with sufficient size to hold all types of - category `cat` in the list `types`. - """ - used_types = [ - i for i in types if category_priority.get(np.dtype(i).kind, 0) == cat - ] - assert len(used_types) > 0 - res = used_types[0] - for i in used_types: - res = np.promote_types(res, i) - return res - - def max_priority(types): - """ - Returns the maximum value of the priority of each type in the list - `types`. - """ - if not types: - return 0 - else: - return max([category_priority.get(np.dtype(i).kind, 0) for i in types]) - - scalars = [] - tensors = [] - - for data in inputs: - if hasattr(data, "dtype"): - tensors.append(data.dtype) - elif isinstance(data, (float, int, bool)): - scalars.append(scalar2dtype(data)) - - max_pri_scalars = max_priority(scalars) - max_pri_tensors = max_priority(tensors) - - assert max_pri_scalars > 0 or max_pri_tensors > 0 - - if max_pri_scalars > max_pri_tensors: - return promote_types(scalars, max_pri_scalars) - else: - return promote_types(tensors, max_pri_tensors) - - -def get_device(inputs): - device = None - for i in inputs: - if isinstance(i, (Tensor, VarNode)): - if device is None: - device = i.device - elif device != i.device: - raise ValueError("ambiguous device: {} vs {}".format(device, i.device)) - assert device is not None - return device - - def concatenate(inputs, axis=0, *, device=None): dtype = dtype_promotion(inputs) device = get_device(inputs) def convert(x): - return convert_single_value(x, inputs, dtype=dtype) + return convert_single_value(x, dtype=dtype, device=device) inputs = tuple(map(convert, inputs)) (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) @@ -133,7 +51,7 @@ def concatenate(inputs, axis=0, *, device=None): def astype(x, dtype): dtype = np.dtype(dtype) - if not is_equal(x.dtype, dtype): + if not is_dtype_equal(x.dtype, dtype): isscalar = x.isscalar() (x,) = apply(builtin.TypeCvt(dtype=dtype), x) if isscalar: @@ -141,13 +59,12 @@ def astype(x, dtype): return x -def convert_single_value(v, inputs, *, dtype=None, device=None): - tensors = [i for i in inputs if isinstance(i, (Tensor, VarNode))] - assert len(tensors) > 0 +def convert_single_value(v, *, dtype=None, device=None): if isinstance(v, (Tensor, VarNode)): - v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) + if not is_quantize(v.dtype): + v = astype(v, dtype) else: - (v,) = Const(v, dtype=dtype, device=device)(*tensors) + (v,) = Const(v, dtype=dtype, device=device)() return v @@ -161,7 +78,7 @@ def convert_inputs(*args: Tensor): def convert(value): if value is None: return value - return convert_single_value(value, args, dtype=dtype, device=device) + return convert_single_value(value, dtype=dtype, device=device) return tuple(map(convert, args)) diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 0316a636eb6d280c976b18485adcfb607f47c80f..e40c18adc5785a6f71b5216716c0ba2e7a59592e 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -703,7 +703,7 @@ def topk( op = builtin.TopK(mode=mode) if not isinstance(k, Tensor): - (k,) = Const(k, dtype="int32", device=inp.device)(inp) + (k,) = Const(k, dtype="int32", device=inp.device)() if len(inp.shape) == 1: inp = inp.reshape(1, -1) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index b431a6c62ca308ab473dd8dd48580698a8e1e773..b9a5b4f75538bc758f167e191fc44bc89a80cf10 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -658,7 +658,7 @@ def batch_norm( def make_full_if_none(x, value): if x is None: - (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) + (x,) = Const(value, dtype=inp.dtype, device=inp.device)() shape = utils.astensor1d( (1, C, 1, 1), inp, dtype="int32", device=inp.device ) @@ -1567,7 +1567,7 @@ def indexing_one_hot( """ assert isinstance(src, Tensor), "src must be of Tensor type" op = builtin.IndexingOneHot(axis=axis) - index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) + index = utils.convert_single_value(index, dtype="int32", device=src.device) (result,) = apply(op, src, index) if not keepdims: result = squeeze(result, axis) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 8cd60b8ea39ee8a4b9a928c6ec4fd67f9d7e748f..307d5f49145eadfde869646cb5c52fa4c2f25dcf 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -107,9 +107,7 @@ def full(shape, value, dtype="float32", device=None): shape = (shape,) if device is None: device = get_default_device() - (x,) = Const(value, dtype=dtype, device=device)( - Tensor(value, dtype=dtype, device=device) - ) + (x,) = Const(value, dtype=dtype, device=device)() if len(shape) == 0: # scalar return x return broadcast_to(x, shape) @@ -265,7 +263,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: device = as_device(device) def convert(x): - return convert_single_value(x, inps, dtype=dtype) + return convert_single_value(x, dtype=dtype, device=device) inps = tuple(map(convert, inps)) (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index fe3bc714db845c603fdc25e6c3ef5e05b23a7bdc..aadd7a9fa1e1d968c50b28eea27ea0994464af72 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -37,8 +37,10 @@ class Tensor(_Tensor, ArrayMethodMixin): else: cn = CompNode(device) else: - assert isinstance(device, CompNode) - cn = device + if isinstance(device, CompNode): + cn = device + else: + cn = device._cn # import pdb; pdb.set_trace() if isinstance(data, _Tensor): diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp index fd5c660267fd489b6af70310e01a86164a95c6f1..35099e7f983aca764c5600504f22d57e4a3ff5bd 100644 --- a/imperative/python/src/common.cpp +++ b/imperative/python/src/common.cpp @@ -179,4 +179,5 @@ void init_common(py::module m) { init_npy_num_bfloat16(m); init_npy_num_intbx(m); + init_dtypes(m); } diff --git a/imperative/python/src/helper.cpp b/imperative/python/src/helper.cpp index 0d4bd70285cb9e62690f8b5d3abbbb2bb21010b5..b7f9215f605547c03b185cd8371a8af8c70fc42b 100644 --- a/imperative/python/src/helper.cpp +++ b/imperative/python/src/helper.cpp @@ -158,7 +158,7 @@ void PyExceptionForward::throw_() { /* ============== namespace npy ============== */ -namespace { +namespace npy { int to_mgb_supported_dtype_raw(int dtype) { if (dtype == NPY_INT64) @@ -199,12 +199,6 @@ int dtype_mgb2np_raw(DType dtype) { "can not convert dtype %s to numpy dtype", dtype.name())); } -struct PyArrayDescrDeleter { - void operator()(PyArray_Descr* obj) { - Py_XDECREF(obj); - } -}; - //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new //! reference to the descriptor. std::unique_ptr dtype_mgb2np_descr( @@ -585,9 +579,7 @@ void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) { HostTensorNDRefHolder::free(static_cast(ptr)); } -} // anonymous namespace - -PyObject* npy::ndarray_from_tensor( +PyObject* ndarray_from_tensor( const HostTensorND &val, ShareType share_type) { if (!val.layout().is_contiguous() && !val.shape().is_empty()) { mgb_assert(share_type != ShareType::MUST_SHARE); @@ -634,7 +626,7 @@ PyObject* npy::ndarray_from_tensor( return ret; } -HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { +HostTensorND np2tensor(PyObject* obj, const Meth& meth, DType dtype) { auto ret_full = np2tensor_try_borrow(obj, meth, dtype); if (meth.must_borrow_) { mgb_assert(ret_full.second, @@ -645,7 +637,7 @@ HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { return ret_full.first; } -PyObject* npy::dtype_mgb2np(mgb::DType dtype) { +PyObject* dtype_mgb2np(mgb::DType dtype) { PYTHON_GIL; // According to // https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType @@ -668,7 +660,7 @@ PyObject* npy::dtype_mgb2np(mgb::DType dtype) { return typeobj; } -mgb::DType npy::dtype_np2mgb(PyObject *obj) { +mgb::DType dtype_np2mgb(PyObject *obj) { mgb_assert(obj && obj != Py_None, "can not convert null PyObject to numpy dtype"); // see @@ -686,7 +678,7 @@ mgb::DType npy::dtype_np2mgb(PyObject *obj) { return result; } -PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { +PyObject* to_mgb_supported_dtype(PyObject* dtype) { PYTHON_GIL; PyArray_Descr* descr; @@ -702,7 +694,7 @@ PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { return PyArray_TypeObjectFromType(type_num); } -TensorShape npy::vec2shape(const std::vector &vec) { +TensorShape vec2shape(const std::vector &vec) { TensorShape shape; mgb_assert(vec.size() <= TensorShape::MAX_NDIM, "dim too large: %zd (max %zd)", @@ -718,3 +710,5 @@ TensorShape npy::vec2shape(const std::vector &vec) { mgb_assert(shape.ndim, "shape should not be empty"); return shape; } + +} // namespace npy diff --git a/imperative/python/src/helper.h b/imperative/python/src/helper.h index ab29206fdd87a434e8fa6ee3ba0f33d627d38bf2..6a2fb1c938c6d15b2064fa11a336860397410f9a 100644 --- a/imperative/python/src/helper.h +++ b/imperative/python/src/helper.h @@ -11,7 +11,7 @@ #pragma once -#include "megbrain/graph.h" +#include "megbrain/common.h" #include "megbrain/utils/persistent_cache.h" #include "megbrain/imperative/op_def.h" @@ -26,6 +26,8 @@ #include #include +#include "./numpy_dtypes.h" + pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr); pybind11::module rel_import(pybind11::str name, pybind11::module m, int level); @@ -182,6 +184,18 @@ namespace npy { //! convert raw vector to tensor shape mgb::TensorShape vec2shape(const std::vector &vec); + struct PyArrayDescrDeleter { + void operator()(PyArray_Descr* obj) { + Py_XDECREF(obj); + } + }; + + //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new + //! reference to the descriptor. + std::unique_ptr dtype_mgb2np_descr(mgb::DType dtype); + + mgb::DType dtype_np2mgb_descr(PyArray_Descr* descr); + //! convert megbrain dtype to numpy dtype object; return new reference PyObject* dtype_mgb2np(mgb::DType dtype); diff --git a/imperative/python/src/numpy_dtypes.cpp b/imperative/python/src/numpy_dtypes.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89310aea9ed4b3c65a7d156dc16a6374dad0fd14 --- /dev/null +++ b/imperative/python/src/numpy_dtypes.cpp @@ -0,0 +1,179 @@ +/** + * \file imperative/python/src/numpy_dtypes.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "./numpy_dtypes.h" +#include "./helper.h" +#include "./pyext17.h" + +#include "pybind11/pybind11.h" + +#include + +namespace py = pybind11; + +namespace mgb { +namespace { + +inline bool _is_quantize(PyArray_Descr* dtype) { + static PyObject* PY_MGB_DTYPE_KEY = PyUnicode_FromString("mgb_dtype"); + return dtype->metadata && + PyDict_CheckExact(dtype->metadata) && + PyDict_Contains(dtype->metadata, PY_MGB_DTYPE_KEY) == 1; +} + +PyObject* _get_mgb_dtype(PyArray_Descr* dtype) { + // Return value: New reference. + if (!_is_quantize(dtype)) { + throw py::type_error("expact quantize dtype"); + } + PyObject* ob = PyDict_GetItemString(dtype->metadata, "mgb_dtype"); + if (!PyDict_CheckExact(ob)) { + throw py::type_error("mgb_dtype is not dict"); + } + Py_INCREF(ob); + return ob; +} + +double _get_scale(PyArray_Descr* dtype) { + PyObject* ob = _get_mgb_dtype(dtype); + PyObject* scale = PyDict_GetItemString(ob, "scale"); + if (!scale) { + Py_DECREF(ob); + throw py::key_error("scale"); + } + if (!PyFloat_Check(scale)) { + Py_DECREF(ob); + throw py::type_error("scale is not float"); + } + double ret = PyFloat_AsDouble(scale); + Py_DECREF(ob); + return ret; +} + +long _get_zero_point(PyArray_Descr* dtype) { + PyObject* ob = _get_mgb_dtype(dtype); + PyObject* name = PyDict_GetItemString(ob, "name"); + if (!name) { + Py_DECREF(ob); + throw py::key_error("name"); + } + const char* s = PyUnicode_AsUTF8(name); + if (strcmp(s, "Quantized8Asymm") != 0 && strcmp(s, "Quantized4Asymm") != 0) { + Py_DECREF(ob); + throw py::value_error(ssprintf("expect name to be \"Quantized8Asymm\" or \"Quantized4Asymm\", got %s", s)); + } + PyObject* zp = PyDict_GetItemString(ob, "zero_point"); + if (!zp) { + Py_DECREF(ob); + throw py::key_error("zero_point"); + } + long ret = PyLong_AsLong(zp); + Py_DECREF(ob); + return ret; +} + +bool _is_dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) { + bool q1 = _is_quantize(dt1), + q2 = _is_quantize(dt2); + if (q1 && q2) { + if (_get_scale(dt1) != _get_scale(dt2)) { + return false; + } + PyObject* zp1 = PyDict_GetItemString( + PyDict_GetItemString(dt1->metadata, "mgb_dtype"), "zero_point"); + PyObject* zp2 = PyDict_GetItemString( + PyDict_GetItemString(dt2->metadata, "mgb_dtype"), "zero_point"); + if (!zp1 || !zp2) { + throw py::key_error("zero_point"); + } + return PyLong_AsLong(zp1) == PyLong_AsLong(zp2); + } + if (!q1 && !q2) { + return dt1->type_num == dt2->type_num; + } + return false; +} + +template +struct _wrap { + static constexpr size_t n_args = []() { + using F = decltype(f); + using T = PyArray_Descr*; + static_assert(std::is_pointer::value); + if constexpr (std::is_invocable::value) { + return 1; + } else if constexpr (std::is_invocable::value) { + return 2; + } else { + static_assert(!std::is_same_v, "unreachable"); + } + }(); + + static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargs) { + if (nargs != n_args) { + PyErr_Format(PyExc_ValueError, "expected %lu arguments", n_args); + return nullptr; + } + for (size_t i=0; iob_type->tp_name)); + } + if constexpr (n_args == 1) { + auto res = (*f)(dt1); + Py_DECREF(dt1); + return py::cast(res).release().ptr(); + } else { + PyArray_Descr *dt2; + if(!PyArray_DescrConverter(args[1], &dt2)) { + Py_DECREF(dt1); + throw ConversionError(ssprintf("can not convert to numpy.dtype from %s", + args[1]->ob_type->tp_name)); + } + auto&& res = (*f)(dt1, dt2); + Py_DECREF(dt1); + Py_DECREF(dt2); + return py::cast(res).release().ptr(); + } + } catch (std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } + } +}; + +} // anonymous namespace + +void init_dtypes(py::module m) { + static PyMethodDef method_defs[] = { + {"is_quantize", (PyCFunction)_wrap<&_is_quantize>::impl, METH_FASTCALL, nullptr}, + {"get_scale", (PyCFunction)_wrap<&_get_scale>::impl, METH_FASTCALL, nullptr}, + {"get_zero_point", (PyCFunction)_wrap<&_get_zero_point>::impl, METH_FASTCALL, nullptr}, + {"is_dtype_equal", (PyCFunction)_wrap<&_is_dtype_equal>::impl, METH_FASTCALL, nullptr}, + {nullptr, nullptr, 0, nullptr} + }; + for (auto&& def: method_defs) { + if (def.ml_meth != nullptr) { + auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); + if (!func) throw py::error_already_set(); + py::setattr(m, def.ml_name, func); + } + } +} + +} // namespace mgb diff --git a/imperative/python/src/numpy_dtypes.h b/imperative/python/src/numpy_dtypes.h index 09a23a5abcc9fc2c4fe7b089fc7c0f6b4fa660fc..5c6fc86470354f06a6bcaf00d6c30d1219dc9c83 100644 --- a/imperative/python/src/numpy_dtypes.h +++ b/imperative/python/src/numpy_dtypes.h @@ -36,6 +36,7 @@ namespace mgb { int npy_num_intb##n(); FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) #undef DEFINE_NPY_INTBX + void init_dtypes(pybind11::module m); void init_npy_num_intbx(pybind11::module m); //! numpy type num for bfloat16 type diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 98e7b134da50c8f5ba0e1f7b9cb179be4343926f..f924ddd97d9ba13ed5179ea6adb16837c9052330 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -9,16 +9,22 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "megbrain/dtype.h" +#include "megbrain/common.h" + #include "./tensor.h" #include "./grad.h" #include "./trace.h" #include "./common.h" #include "./numpy_dtypes.h" #include "./graph_rt.h" +#include "./helper.h" #include #include -#include "./helper.h" + +#include + namespace py = pybind11; namespace mgb::imperative::python { @@ -413,6 +419,198 @@ struct TensorWeakRef { } }; +/* ============== convert inputs ============== */ + +// map numpy.dtype.kind to priority +inline uint8_t category_priority(char c) { + switch (c) { + case 'f': return 3; // floating-point + case 'i': return 2; // signed integer + case 'u': return 2; // unsigned integer + case 'b': return 1; // boolean + default: return 0; + } +} + +// Returns the maximum value of the priority of each type in the list `types`. +uint8_t max_priority(SmallVector types) { + if (types.size() == 0) { + return 0; + } else { + uint8_t max_p = 0; + for (auto&& desc: types) { + max_p = std::max(max_p, category_priority(desc->kind)); + } + return max_p; + } +} + +// Returns the data type with sufficient size to hold all types of +// category `cat` in the list `types`. +PyArray_Descr* promote_types(SmallVector types, uint8_t cat) { + // Return value: New reference + SmallVector used_types; + for (auto&& desc: types) { + auto&& v = category_priority(desc->kind); + if (v == cat) { + used_types.emplace_back(desc); + } + } + mgb_assert(used_types.size() > 0, "size of used_types is 0"); + PyArray_Descr* res = used_types[0]; + Py_INCREF(res); + + for (size_t i = 1; i < used_types.size(); ++i) { + PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res); + Py_DECREF(res); + res = tmp; + } + return res; +} + +PyArray_Descr* scalar2dtype(PyObject* arg) { + // Return value: New reference + if (PyBool_Check(arg)) { + auto&& descr = PyArray_DescrFromType(NPY_BOOL); + return descr; + } + if (PyLong_CheckExact(arg)) { + auto&& descr = PyArray_DescrFromType(NPY_INT32); + return descr; + } + if (PyFloat_CheckExact(arg)) { + auto&& descr = PyArray_DescrFromType(NPY_FLOAT32); + return descr; + } + return nullptr; +} + +PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { + // Return value: New reference + SmallVector tensors; + SmallVector scalars; + + bool is_tuple = false; + PyObject* tuple; + if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { + if (PyList_Check(args[0])) { + tuple = PyList_AsTuple(args[0]); + } else { + tuple = args[0]; + Py_INCREF(tuple); + } + nargs = PyTuple_Size(tuple); + is_tuple = true; + } + + for (size_t i = 0; i < nargs; ++i) { + PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; + if (handle == Py_None) continue; + TensorWrapper* tw = TensorWrapper::cast_safe(handle); + if (tw) { + mgb::DType type = tw->m_tensor->dtype(); + auto&& descr = npy::dtype_mgb2np_descr(type); + Py_INCREF(descr.get()); + tensors.emplace_back(descr.get()); + }else{ + if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) { + auto&& descr = PyArray_DescrFromObject(handle, nullptr); + tensors.emplace_back(descr); + continue; + } + PyArray_Descr* descr = scalar2dtype(handle); + if (descr) { + scalars.emplace_back(descr); + continue; + } + } + } + + auto max_pri_scalars = max_priority(scalars); + auto max_pri_tensors = max_priority(tensors); + + if (max_pri_scalars <= 0 && max_pri_tensors <= 0) { + throw py::value_error("invalid input, no dtype avaliable"); + } + PyArray_Descr* res; + if (max_pri_scalars > max_pri_tensors) { + res = promote_types(scalars, max_pri_scalars); + }else{ + res = promote_types(tensors, max_pri_tensors); + } + for (auto *p: tensors) { Py_DECREF(p); } + for (auto *p: scalars) { Py_DECREF(p); } + Py_DECREF(tuple); + return res; +} + +CompNode _get_device(PyObject*const* args, size_t nargs) { + bool is_tuple = false; + PyObject* tuple; + if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { + if (PyList_Check(args[0])) { + tuple = PyList_AsTuple(args[0]); + } else { + tuple = args[0]; + Py_INCREF(tuple); + } + nargs = PyTuple_Size(tuple); + is_tuple = true; + } + bool valid = false; + CompNode cn; + for (size_t i = 0; i < nargs; ++i) { + PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; + TensorWrapper* tw = TensorWrapper::cast_safe(handle); + if (tw) { + if (!valid) { + cn = tw->m_tensor->comp_node(); + valid = true; + } else { + CompNode cn1 = tw->m_tensor->comp_node(); + if (cn1 != cn) { + throw py::value_error(ssprintf("ambiguous device: %s vs %s", + cn.to_string().c_str(), cn1.to_string().c_str())); + } + } + } + } + if (!valid) { + mgb_assert(0, "expact at least 1 device"); + } + Py_DECREF(tuple); + return cn; +} + +// Returns the dtype that would result from performing an arithmetic +// operation on the provided input tensors and scalars. +PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) { + if (!nargs) { + PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); + return nullptr; + } + try { + PyArray_Descr* res = _dtype_promotion(args, nargs); + return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr(); + } catch (std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } +} + +PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) { + if (!nargs) { + PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); + return nullptr; + } + try { + CompNode cn = _get_device(args, nargs); + return py::cast(cn).release().ptr(); + } catch (std::exception& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + } +} void init_tensor(py::module m) { interpreter_for_py = interpreter::Interpreter::inst().create_channel(); @@ -444,10 +642,19 @@ void init_tensor(py::module m) { .def(py::init()) .def("__call__", &TensorWeakRef::operator()); - static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; - auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr); - if (!apply_func) throw py::error_already_set(); - py::setattr(m, "apply", apply_func); + static PyMethodDef method_defs[] = { + {"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}, + {"dtype_promotion", (PyCFunction)dtype_promotion, METH_FASTCALL, nullptr}, + {"get_device", (PyCFunction)get_device, METH_FASTCALL, nullptr}, + {nullptr, nullptr, 0, nullptr} + }; + for (auto&& def: method_defs) { + if (def.ml_meth != nullptr) { + auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); + if (!func) throw py::error_already_set(); + py::setattr(m, def.ml_name, func); + } + } m.def("_set_swap_flag", [](bool flag) { interpreter_for_py->set_swap_flag(flag); }); diff --git a/imperative/python/test/unit/core/test_dtype_quant.py b/imperative/python/test/unit/core/test_dtype_quant.py index 0ddf01e9d6aa8b997006610b616625b161d49675..b9a17972fcc8bc25ae53407d68cd2b59c10c8a10 100644 --- a/imperative/python/test/unit/core/test_dtype_quant.py +++ b/imperative/python/test/unit/core/test_dtype_quant.py @@ -113,7 +113,7 @@ def test_quint8_typecvt(): data = np.random.random(shape).astype(np.float32) * 5 - 1 def typecvt(x, dt=None): - (y,) = apply(ops.TypeCvt(dtype=dt), x) + (y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) return y # convert to quint8 @@ -194,7 +194,7 @@ def test_quint4_typecvt(): data = np.random.random(shape).astype(np.float32) * 5 - 1 def typecvt(x, dt=None): - (y,) = apply(ops.TypeCvt(dtype=dt), x) + (y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) return y # convert to quint4