提交 87f4b46e 编写于 作者: M Megvii Engine Team

perf(mge/imperative): move convert_inputs from python to C++

GitOrigin-RevId: baef3d348c590d477432c2c45df54835557e7c8d
上级 b310f261
......@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from .._imperative_rt.core2 import Tensor
# from .._imperative_rt.core2 import Tensor
from ..tensor.core import OpBase, TensorBase, apply
......@@ -19,5 +19,10 @@ class Const:
self.device = device
def __call__(self, *reference):
Wrapper = type(reference[0])
return (Wrapper(self.value, self.dtype, self.device, True),)
from ...tensor import Tensor
device = self.device
if device is None:
device = reference[0].device
return (Tensor(self.value, self.dtype, self.device, True),)
......@@ -13,6 +13,12 @@ import numpy as np
# normal dtype related
from .._imperative_rt import bfloat16, intb1, intb2, intb4
from .._imperative_rt.common import (
get_scale,
get_zero_point,
is_dtype_equal,
is_quantize,
)
def is_lowbit(dtype):
......@@ -42,41 +48,6 @@ _metadata_dict = {
}
def is_quantize(dtype):
return (
hasattr(dtype, "metadata")
and dtype.metadata is not None
and "mgb_dtype" in dtype.metadata
)
def get_scale(dtype):
assert is_quantize(dtype)
return dtype.metadata["mgb_dtype"]["scale"]
def get_zero_point(dtype):
assert is_quantize(dtype)
metadata = dtype.metadata["mgb_dtype"]
assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm")
return metadata["zero_point"]
def is_equal(dt0, dt1):
def _get_zero_point(dtype):
assert is_quantize(dtype)
metadata = dtype.metadata["mgb_dtype"]
return metadata.get("zero_point")
if is_quantize(dt0) and is_quantize(dt1):
return get_scale(dt0) == get_scale(dt1) and _get_zero_point(
dt0
) == _get_zero_point(dt1)
if not (is_quantize(dt0) or is_quantize(dt1)):
return dt0 == dt1
return False
def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str].qmin
qmax = _metadata_dict[dtype_str].qmax
......
......@@ -151,9 +151,9 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
def get_index(i):
if not isinstance(i, (Tensor)):
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
(i,) = Const(i, dtype=np.bool_, device=inp.device)()
else:
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
(i,) = Const(i, dtype=np.int32, device=inp.device)()
return i
assert isinstance(i, Tensor)
if i.dtype != np.bool_:
......@@ -197,7 +197,7 @@ def try_condtake(tensor, index):
):
return []
if isinstance(index, np.ndarray):
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
(index,) = Const(index, dtype=np.bool_, device=tensor.device)()
assert isinstance(index, Tensor)
if not isinstance(tensor, Tensor):
raise TypeError("input must be a tensor")
......@@ -217,9 +217,7 @@ def getitem(tensor, index):
if isinstance(v.shape, v.__class__):
break
if len(v.shape) > 0 and v.shape[0] == 0:
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
tensor
)
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)()
return empty_tensor
if use_subtensor:
op = builtin.Subtensor(items=items)
......@@ -240,8 +238,7 @@ def setitem(tensor, index, value):
return tensor
tensor = tensor.reshape(-1)
if not isinstance(value, Tensor):
op = Const(value, dtype=tensor.dtype, device=tensor.device)
(value,) = op(tensor)
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)()
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
for v in tensors:
if len(v.shape) > 0 and v.shape[0] == 0:
......
......@@ -11,10 +11,10 @@ from typing import Iterable, Union
import numpy as np
from .._imperative_rt.core2 import Tensor, apply
from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device
from ..ops import builtin
from ..ops.special import Const
from .dtype import is_equal, is_quantize
from .dtype import is_dtype_equal, is_quantize
from .megbrain_graph import VarNode
_enable_convert_inputs = True
......@@ -37,94 +37,12 @@ def set_convert_inputs(flag):
return backup
def dtype_promotion(inputs):
"""
Returns the dtype that would result from performing an arithmetic
operation on the provided input tensors and scalars.
"""
# map numpy.dtype.kind to priority
category_priority = {
"f": 3, # floating-point
"i": 2, # signed integer
"u": 2, # unsigned integer
"b": 1, # boolean
}
def scalar2dtype(x):
"""
For scalar `x`, returns its corresponding type. A floating point scalar
has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'.
A boolean scalar has dtype 'bool'.
"""
if isinstance(x, bool):
return np.bool_
if isinstance(x, int):
return np.int32
if isinstance(x, float):
return np.float32
def promote_types(types, cat):
"""
Returns the data type with sufficient size to hold all types of
category `cat` in the list `types`.
"""
used_types = [
i for i in types if category_priority.get(np.dtype(i).kind, 0) == cat
]
assert len(used_types) > 0
res = used_types[0]
for i in used_types:
res = np.promote_types(res, i)
return res
def max_priority(types):
"""
Returns the maximum value of the priority of each type in the list
`types`.
"""
if not types:
return 0
else:
return max([category_priority.get(np.dtype(i).kind, 0) for i in types])
scalars = []
tensors = []
for data in inputs:
if hasattr(data, "dtype"):
tensors.append(data.dtype)
elif isinstance(data, (float, int, bool)):
scalars.append(scalar2dtype(data))
max_pri_scalars = max_priority(scalars)
max_pri_tensors = max_priority(tensors)
assert max_pri_scalars > 0 or max_pri_tensors > 0
if max_pri_scalars > max_pri_tensors:
return promote_types(scalars, max_pri_scalars)
else:
return promote_types(tensors, max_pri_tensors)
def get_device(inputs):
device = None
for i in inputs:
if isinstance(i, (Tensor, VarNode)):
if device is None:
device = i.device
elif device != i.device:
raise ValueError("ambiguous device: {} vs {}".format(device, i.device))
assert device is not None
return device
def concatenate(inputs, axis=0, *, device=None):
dtype = dtype_promotion(inputs)
device = get_device(inputs)
def convert(x):
return convert_single_value(x, inputs, dtype=dtype)
return convert_single_value(x, dtype=dtype, device=device)
inputs = tuple(map(convert, inputs))
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs)
......@@ -133,7 +51,7 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype):
dtype = np.dtype(dtype)
if not is_equal(x.dtype, dtype):
if not is_dtype_equal(x.dtype, dtype):
isscalar = x.isscalar()
(x,) = apply(builtin.TypeCvt(dtype=dtype), x)
if isscalar:
......@@ -141,13 +59,12 @@ def astype(x, dtype):
return x
def convert_single_value(v, inputs, *, dtype=None, device=None):
tensors = [i for i in inputs if isinstance(i, (Tensor, VarNode))]
assert len(tensors) > 0
def convert_single_value(v, *, dtype=None, device=None):
if isinstance(v, (Tensor, VarNode)):
v = astype(v, v.dtype if is_quantize(v.dtype) else dtype)
if not is_quantize(v.dtype):
v = astype(v, dtype)
else:
(v,) = Const(v, dtype=dtype, device=device)(*tensors)
(v,) = Const(v, dtype=dtype, device=device)()
return v
......@@ -161,7 +78,7 @@ def convert_inputs(*args: Tensor):
def convert(value):
if value is None:
return value
return convert_single_value(value, args, dtype=dtype, device=device)
return convert_single_value(value, dtype=dtype, device=device)
return tuple(map(convert, args))
......
......@@ -703,7 +703,7 @@ def topk(
op = builtin.TopK(mode=mode)
if not isinstance(k, Tensor):
(k,) = Const(k, dtype="int32", device=inp.device)(inp)
(k,) = Const(k, dtype="int32", device=inp.device)()
if len(inp.shape) == 1:
inp = inp.reshape(1, -1)
......
......@@ -658,7 +658,7 @@ def batch_norm(
def make_full_if_none(x, value):
if x is None:
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
(x,) = Const(value, dtype=inp.dtype, device=inp.device)()
shape = utils.astensor1d(
(1, C, 1, 1), inp, dtype="int32", device=inp.device
)
......@@ -1567,7 +1567,7 @@ def indexing_one_hot(
"""
assert isinstance(src, Tensor), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis)
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
index = utils.convert_single_value(index, dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:
result = squeeze(result, axis)
......
......@@ -107,9 +107,7 @@ def full(shape, value, dtype="float32", device=None):
shape = (shape,)
if device is None:
device = get_default_device()
(x,) = Const(value, dtype=dtype, device=device)(
Tensor(value, dtype=dtype, device=device)
)
(x,) = Const(value, dtype=dtype, device=device)()
if len(shape) == 0: # scalar
return x
return broadcast_to(x, shape)
......@@ -265,7 +263,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
device = as_device(device)
def convert(x):
return convert_single_value(x, inps, dtype=dtype)
return convert_single_value(x, dtype=dtype, device=device)
inps = tuple(map(convert, inps))
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
......
......@@ -37,8 +37,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
else:
cn = CompNode(device)
else:
assert isinstance(device, CompNode)
cn = device
if isinstance(device, CompNode):
cn = device
else:
cn = device._cn
# import pdb; pdb.set_trace()
if isinstance(data, _Tensor):
......
......@@ -179,4 +179,5 @@ void init_common(py::module m) {
init_npy_num_bfloat16(m);
init_npy_num_intbx(m);
init_dtypes(m);
}
......@@ -158,7 +158,7 @@ void PyExceptionForward::throw_() {
/* ============== namespace npy ============== */
namespace {
namespace npy {
int to_mgb_supported_dtype_raw(int dtype) {
if (dtype == NPY_INT64)
......@@ -199,12 +199,6 @@ int dtype_mgb2np_raw(DType dtype) {
"can not convert dtype %s to numpy dtype", dtype.name()));
}
struct PyArrayDescrDeleter {
void operator()(PyArray_Descr* obj) {
Py_XDECREF(obj);
}
};
//! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
//! reference to the descriptor.
std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(
......@@ -585,9 +579,7 @@ void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) {
HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr));
}
} // anonymous namespace
PyObject* npy::ndarray_from_tensor(
PyObject* ndarray_from_tensor(
const HostTensorND &val, ShareType share_type) {
if (!val.layout().is_contiguous() && !val.shape().is_empty()) {
mgb_assert(share_type != ShareType::MUST_SHARE);
......@@ -634,7 +626,7 @@ PyObject* npy::ndarray_from_tensor(
return ret;
}
HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) {
HostTensorND np2tensor(PyObject* obj, const Meth& meth, DType dtype) {
auto ret_full = np2tensor_try_borrow(obj, meth, dtype);
if (meth.must_borrow_) {
mgb_assert(ret_full.second,
......@@ -645,7 +637,7 @@ HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) {
return ret_full.first;
}
PyObject* npy::dtype_mgb2np(mgb::DType dtype) {
PyObject* dtype_mgb2np(mgb::DType dtype) {
PYTHON_GIL;
// According to
// https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType
......@@ -668,7 +660,7 @@ PyObject* npy::dtype_mgb2np(mgb::DType dtype) {
return typeobj;
}
mgb::DType npy::dtype_np2mgb(PyObject *obj) {
mgb::DType dtype_np2mgb(PyObject *obj) {
mgb_assert(obj && obj != Py_None,
"can not convert null PyObject to numpy dtype");
// see
......@@ -686,7 +678,7 @@ mgb::DType npy::dtype_np2mgb(PyObject *obj) {
return result;
}
PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) {
PyObject* to_mgb_supported_dtype(PyObject* dtype) {
PYTHON_GIL;
PyArray_Descr* descr;
......@@ -702,7 +694,7 @@ PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) {
return PyArray_TypeObjectFromType(type_num);
}
TensorShape npy::vec2shape(const std::vector<size_t> &vec) {
TensorShape vec2shape(const std::vector<size_t> &vec) {
TensorShape shape;
mgb_assert(vec.size() <= TensorShape::MAX_NDIM,
"dim too large: %zd (max %zd)",
......@@ -718,3 +710,5 @@ TensorShape npy::vec2shape(const std::vector<size_t> &vec) {
mgb_assert(shape.ndim, "shape should not be empty");
return shape;
}
} // namespace npy
......@@ -11,7 +11,7 @@
#pragma once
#include "megbrain/graph.h"
#include "megbrain/common.h"
#include "megbrain/utils/persistent_cache.h"
#include "megbrain/imperative/op_def.h"
......@@ -26,6 +26,8 @@
#include <pybind11/numpy.h>
#include <pybind11/functional.h>
#include "./numpy_dtypes.h"
pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr);
pybind11::module rel_import(pybind11::str name, pybind11::module m, int level);
......@@ -182,6 +184,18 @@ namespace npy {
//! convert raw vector to tensor shape
mgb::TensorShape vec2shape(const std::vector<size_t> &vec);
struct PyArrayDescrDeleter {
void operator()(PyArray_Descr* obj) {
Py_XDECREF(obj);
}
};
//! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new
//! reference to the descriptor.
std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(mgb::DType dtype);
mgb::DType dtype_np2mgb_descr(PyArray_Descr* descr);
//! convert megbrain dtype to numpy dtype object; return new reference
PyObject* dtype_mgb2np(mgb::DType dtype);
......
/**
* \file imperative/python/src/numpy_dtypes.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./numpy_dtypes.h"
#include "./helper.h"
#include "./pyext17.h"
#include "pybind11/pybind11.h"
#include <cstring>
namespace py = pybind11;
namespace mgb {
namespace {
inline bool _is_quantize(PyArray_Descr* dtype) {
static PyObject* PY_MGB_DTYPE_KEY = PyUnicode_FromString("mgb_dtype");
return dtype->metadata &&
PyDict_CheckExact(dtype->metadata) &&
PyDict_Contains(dtype->metadata, PY_MGB_DTYPE_KEY) == 1;
}
PyObject* _get_mgb_dtype(PyArray_Descr* dtype) {
// Return value: New reference.
if (!_is_quantize(dtype)) {
throw py::type_error("expact quantize dtype");
}
PyObject* ob = PyDict_GetItemString(dtype->metadata, "mgb_dtype");
if (!PyDict_CheckExact(ob)) {
throw py::type_error("mgb_dtype is not dict");
}
Py_INCREF(ob);
return ob;
}
double _get_scale(PyArray_Descr* dtype) {
PyObject* ob = _get_mgb_dtype(dtype);
PyObject* scale = PyDict_GetItemString(ob, "scale");
if (!scale) {
Py_DECREF(ob);
throw py::key_error("scale");
}
if (!PyFloat_Check(scale)) {
Py_DECREF(ob);
throw py::type_error("scale is not float");
}
double ret = PyFloat_AsDouble(scale);
Py_DECREF(ob);
return ret;
}
long _get_zero_point(PyArray_Descr* dtype) {
PyObject* ob = _get_mgb_dtype(dtype);
PyObject* name = PyDict_GetItemString(ob, "name");
if (!name) {
Py_DECREF(ob);
throw py::key_error("name");
}
const char* s = PyUnicode_AsUTF8(name);
if (strcmp(s, "Quantized8Asymm") != 0 && strcmp(s, "Quantized4Asymm") != 0) {
Py_DECREF(ob);
throw py::value_error(ssprintf("expect name to be \"Quantized8Asymm\" or \"Quantized4Asymm\", got %s", s));
}
PyObject* zp = PyDict_GetItemString(ob, "zero_point");
if (!zp) {
Py_DECREF(ob);
throw py::key_error("zero_point");
}
long ret = PyLong_AsLong(zp);
Py_DECREF(ob);
return ret;
}
bool _is_dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) {
bool q1 = _is_quantize(dt1),
q2 = _is_quantize(dt2);
if (q1 && q2) {
if (_get_scale(dt1) != _get_scale(dt2)) {
return false;
}
PyObject* zp1 = PyDict_GetItemString(
PyDict_GetItemString(dt1->metadata, "mgb_dtype"), "zero_point");
PyObject* zp2 = PyDict_GetItemString(
PyDict_GetItemString(dt2->metadata, "mgb_dtype"), "zero_point");
if (!zp1 || !zp2) {
throw py::key_error("zero_point");
}
return PyLong_AsLong(zp1) == PyLong_AsLong(zp2);
}
if (!q1 && !q2) {
return dt1->type_num == dt2->type_num;
}
return false;
}
template<auto f>
struct _wrap {
static constexpr size_t n_args = []() {
using F = decltype(f);
using T = PyArray_Descr*;
static_assert(std::is_pointer<F>::value);
if constexpr (std::is_invocable<F, T>::value) {
return 1;
} else if constexpr (std::is_invocable<F, T, T>::value) {
return 2;
} else {
static_assert(!std::is_same_v<F, F>, "unreachable");
}
}();
static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargs) {
if (nargs != n_args) {
PyErr_Format(PyExc_ValueError, "expected %lu arguments", n_args);
return nullptr;
}
for (size_t i=0; i<nargs; ++i) {
if (args[i] == Py_None) {
PyErr_SetString(PyExc_ValueError, "can not convert null PyObject to numpy dtype");
return nullptr;
}
}
try {
PyArray_Descr *dt1;
if(!PyArray_DescrConverter(args[0], &dt1)) {
throw ConversionError(ssprintf("can not convert to numpy.dtype from %s",
args[0]->ob_type->tp_name));
}
if constexpr (n_args == 1) {
auto res = (*f)(dt1);
Py_DECREF(dt1);
return py::cast(res).release().ptr();
} else {
PyArray_Descr *dt2;
if(!PyArray_DescrConverter(args[1], &dt2)) {
Py_DECREF(dt1);
throw ConversionError(ssprintf("can not convert to numpy.dtype from %s",
args[1]->ob_type->tp_name));
}
auto&& res = (*f)(dt1, dt2);
Py_DECREF(dt1);
Py_DECREF(dt2);
return py::cast(res).release().ptr();
}
} catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
}
}
};
} // anonymous namespace
void init_dtypes(py::module m) {
static PyMethodDef method_defs[] = {
{"is_quantize", (PyCFunction)_wrap<&_is_quantize>::impl, METH_FASTCALL, nullptr},
{"get_scale", (PyCFunction)_wrap<&_get_scale>::impl, METH_FASTCALL, nullptr},
{"get_zero_point", (PyCFunction)_wrap<&_get_zero_point>::impl, METH_FASTCALL, nullptr},
{"is_dtype_equal", (PyCFunction)_wrap<&_is_dtype_equal>::impl, METH_FASTCALL, nullptr},
{nullptr, nullptr, 0, nullptr}
};
for (auto&& def: method_defs) {
if (def.ml_meth != nullptr) {
auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
if (!func) throw py::error_already_set();
py::setattr(m, def.ml_name, func);
}
}
}
} // namespace mgb
......@@ -36,6 +36,7 @@ namespace mgb {
int npy_num_intb##n();
FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX)
#undef DEFINE_NPY_INTBX
void init_dtypes(pybind11::module m);
void init_npy_num_intbx(pybind11::module m);
//! numpy type num for bfloat16 type
......
......@@ -9,16 +9,22 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/dtype.h"
#include "megbrain/common.h"
#include "./tensor.h"
#include "./grad.h"
#include "./trace.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include "./graph_rt.h"
#include "./helper.h"
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include "./helper.h"
#include <unordered_map>
namespace py = pybind11;
namespace mgb::imperative::python {
......@@ -413,6 +419,198 @@ struct TensorWeakRef {
}
};
/* ============== convert inputs ============== */
// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
switch (c) {
case 'f': return 3; // floating-point
case 'i': return 2; // signed integer
case 'u': return 2; // unsigned integer
case 'b': return 1; // boolean
default: return 0;
}
}
// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
if (types.size() == 0) {
return 0;
} else {
uint8_t max_p = 0;
for (auto&& desc: types) {
max_p = std::max(max_p, category_priority(desc->kind));
}
return max_p;
}
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
// Return value: New reference
SmallVector<PyArray_Descr*> used_types;
for (auto&& desc: types) {
auto&& v = category_priority(desc->kind);
if (v == cat) {
used_types.emplace_back(desc);
}
}
mgb_assert(used_types.size() > 0, "size of used_types is 0");
PyArray_Descr* res = used_types[0];
Py_INCREF(res);
for (size_t i = 1; i < used_types.size(); ++i) {
PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
Py_DECREF(res);
res = tmp;
}
return res;
}
PyArray_Descr* scalar2dtype(PyObject* arg) {
// Return value: New reference
if (PyBool_Check(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_BOOL);
return descr;
}
if (PyLong_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_INT32);
return descr;
}
if (PyFloat_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
return descr;
}
return nullptr;
}
PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
// Return value: New reference
SmallVector<PyArray_Descr*> tensors;
SmallVector<PyArray_Descr*> scalars;
bool is_tuple = false;
PyObject* tuple;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
if (handle == Py_None) continue;
TensorWrapper* tw = TensorWrapper::cast_safe(handle);
if (tw) {
mgb::DType type = tw->m_tensor->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
}else{
if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
auto&& descr = PyArray_DescrFromObject(handle, nullptr);
tensors.emplace_back(descr);
continue;
}
PyArray_Descr* descr = scalar2dtype(handle);
if (descr) {
scalars.emplace_back(descr);
continue;
}
}
}
auto max_pri_scalars = max_priority(scalars);
auto max_pri_tensors = max_priority(tensors);
if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
throw py::value_error("invalid input, no dtype avaliable");
}
PyArray_Descr* res;
if (max_pri_scalars > max_pri_tensors) {
res = promote_types(scalars, max_pri_scalars);
}else{
res = promote_types(tensors, max_pri_tensors);
}
for (auto *p: tensors) { Py_DECREF(p); }
for (auto *p: scalars) { Py_DECREF(p); }
Py_DECREF(tuple);
return res;
}
CompNode _get_device(PyObject*const* args, size_t nargs) {
bool is_tuple = false;
PyObject* tuple;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}
bool valid = false;
CompNode cn;
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
TensorWrapper* tw = TensorWrapper::cast_safe(handle);
if (tw) {
if (!valid) {
cn = tw->m_tensor->comp_node();
valid = true;
} else {
CompNode cn1 = tw->m_tensor->comp_node();
if (cn1 != cn) {
throw py::value_error(ssprintf("ambiguous device: %s vs %s",
cn.to_string().c_str(), cn1.to_string().c_str()));
}
}
}
}
if (!valid) {
mgb_assert(0, "expact at least 1 device");
}
Py_DECREF(tuple);
return cn;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
PyArray_Descr* res = _dtype_promotion(args, nargs);
return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
} catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
}
}
PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
CompNode cn = _get_device(args, nargs);
return py::cast(cn).release().ptr();
} catch (std::exception& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
return nullptr;
}
}
void init_tensor(py::module m) {
interpreter_for_py = interpreter::Interpreter::inst().create_channel();
......@@ -444,10 +642,19 @@ void init_tensor(py::module m) {
.def(py::init<const TensorWrapper&>())
.def("__call__", &TensorWeakRef::operator());
static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr};
auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr);
if (!apply_func) throw py::error_already_set();
py::setattr(m, "apply", apply_func);
static PyMethodDef method_defs[] = {
{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr},
{"dtype_promotion", (PyCFunction)dtype_promotion, METH_FASTCALL, nullptr},
{"get_device", (PyCFunction)get_device, METH_FASTCALL, nullptr},
{nullptr, nullptr, 0, nullptr}
};
for (auto&& def: method_defs) {
if (def.ml_meth != nullptr) {
auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
if (!func) throw py::error_already_set();
py::setattr(m, def.ml_name, func);
}
}
m.def("_set_swap_flag",
[](bool flag) { interpreter_for_py->set_swap_flag(flag); });
......
......@@ -113,7 +113,7 @@ def test_quint8_typecvt():
data = np.random.random(shape).astype(np.float32) * 5 - 1
def typecvt(x, dt=None):
(y,) = apply(ops.TypeCvt(dtype=dt), x)
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x)
return y
# convert to quint8
......@@ -194,7 +194,7 @@ def test_quint4_typecvt():
data = np.random.random(shape).astype(np.float32) * 5 - 1
def typecvt(x, dt=None):
(y,) = apply(ops.TypeCvt(dtype=dt), x)
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x)
return y
# convert to quint4
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册