提交 fa62f6c0 编写于 作者: M Megvii Engine Team

perf(mge/utils): move convert_input into C++

GitOrigin-RevId: 0d1cd362511d2d423faaeffd9d80710747cf05f2
上级 d98be080
......@@ -19,6 +19,7 @@ from .._imperative_rt.core2 import (
SymbolVar,
Tensor,
apply,
astype_cpp,
broadcast_cpp,
dtype_promotion,
)
......@@ -27,14 +28,7 @@ from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
from ..ops import builtin
from . import amp
from .indexing import getitem, setitem
from .utils import (
_normalize_axis,
astensor1d,
astype,
cast_tensors,
make_shape_tuple,
subgraph,
)
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph
_ElwMod = builtin.Elemwise.Mode
......@@ -605,7 +599,7 @@ class ArrayMethodMixin(abc.ABC):
r"""Returns a :class:`Tensor` with the same data and number of elements
with the specified :attr:`~.Tensor.dtype`.
"""
return astype(self, dtype)
return astype_cpp(self, dtype)
def reshape(self, *args):
r"""See :func:`~.reshape`."""
......
......@@ -20,6 +20,9 @@ from .._imperative_rt.core2 import (
_get_convert_inputs,
_set_convert_inputs,
apply,
astype_cpp,
convert_inputs_cpp,
convert_single_value_cpp,
dtype_promotion,
get_device,
make_shape_tuple,
......@@ -55,53 +58,14 @@ def concatenate(inputs, axis=0, *, device=None):
return result
def astype(x, dtype):
dtype = np.dtype(dtype)
if not is_dtype_equal(x.dtype, dtype):
(x,) = apply(builtin.TypeCvt(dtype=dtype), x)
return x
def convert_single_value(v, *, dtype=None, device=None):
if isinstance(v, (Tensor, SymbolVar)):
if not is_quantize(v.dtype):
v = astype(v, dtype)
else:
v = Const(v, dtype, device, None)
return v
return convert_single_value_cpp(v, dtype, device)
def convert_inputs(*args, device=None):
if not _get_convert_inputs():
return args
dtype = dtype_promotion(args)
if device is None:
device = get_device(args)
device = as_device(device)
graph = None
sym_type = None
for a in args:
if isinstance(a, SymbolVar):
if graph is None:
graph = a.var.graph
sym_type = type(a)
else:
assert graph == a.var.graph
args = list(args)
if graph is not None:
for i in range(len(args)):
if not isinstance(args[i], SymbolVar):
rst = make_const(graph, np.array(args[i]), device.to_c(), dtype)
args[i] = sym_type(rst)
def convert(value):
if value is None:
return value
return convert_single_value(value, dtype=dtype, device=device.to_c())
return tuple(map(convert, args))
return convert_inputs_cpp(*args, device)
def cast_tensors(*args, promote=False):
......@@ -146,7 +110,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
pass
except ValueError:
if dtype is not None and dtype != x.dtype:
x = astype(x, dtype)
x = astype_cpp(x, dtype)
if device is not None:
cn = as_device(device).to_c()
(x,) = apply(builtin.Copy(comp_node=cn), x)
......@@ -164,7 +128,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
x = concatenate(x, device=device) if len(x) > 1 else x[0]
if dtype is not None:
x = astype(x, dtype)
x = astype_cpp(x, dtype)
return x
x = Const(x, dtype, device, reference)
return x
......
......@@ -30,7 +30,6 @@ from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import (
astensor1d,
astype,
cast_tensors,
convert_single_value,
make_shape_tuple,
......
......@@ -170,6 +170,12 @@ struct _wrap {
} // anonymous namespace
namespace imperative::python {
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) {
return _is_dtype_equal(dt1, dt2);
}
} // namespace imperative::python
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUN) \
{ #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr }
......
......@@ -26,6 +26,11 @@
cb(BFloat16, npy_num_bfloat16())
namespace mgb {
namespace imperative::python {
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2);
} // namespace imperative::python
//! numpy type num for intb1/2/4 type
#define DEFINE_NPY_INTBX(n) int npy_num_intb##n();
FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX)
......
......@@ -400,223 +400,6 @@ struct TensorWeakRef {
int _use_cnt() { return wptr.use_count(); }
};
/* ============== convert inputs ============== */
// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
switch (c) {
case 'f':
return 3; // floating-point
case 'i':
return 2; // signed integer
case 'u':
return 2; // unsigned integer
case 'b':
return 1; // boolean
default:
return 0;
}
}
// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
if (types.size() == 0) {
return 0;
} else {
uint8_t max_p = 0;
for (auto&& desc : types) {
max_p = std::max(max_p, category_priority(desc->kind));
}
return max_p;
}
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
// Return value: New reference
SmallVector<PyArray_Descr*> used_types;
for (auto&& desc : types) {
auto&& v = category_priority(desc->kind);
if (v == cat) {
used_types.emplace_back(desc);
}
}
mgb_assert(used_types.size() > 0, "size of used_types is 0");
PyArray_Descr* res = used_types[0];
Py_INCREF(res);
for (size_t i = 1; i < used_types.size(); ++i) {
PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
Py_DECREF(res);
res = tmp;
}
return res;
}
PyArray_Descr* scalar2dtype(PyObject* arg) {
// Return value: New reference
if (PyBool_Check(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_BOOL);
return descr;
}
if (PyLong_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_INT32);
return descr;
}
if (PyFloat_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
return descr;
}
return nullptr;
}
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
// Return value: New reference
SmallVector<PyArray_Descr*> tensors;
SmallVector<PyArray_Descr*> scalars;
bool is_tuple = false;
PyObject* tuple = nullptr;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
if (handle == Py_None)
continue;
TensorWrapper* tw = TensorWrapper::try_cast(handle);
if (tw) {
mgb::DType type = tw->m_tensor->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
} else {
if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
auto&& descr = PyArray_DescrFromObject(handle, nullptr);
tensors.emplace_back(descr);
continue;
}
if (py::isinstance<PySymbolVar>(py::handle(handle))) {
auto var = py::handle(handle).cast<PySymbolVar*>();
mgb::DType type = var->m_node->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
continue;
}
PyArray_Descr* descr = scalar2dtype(handle);
if (descr) {
scalars.emplace_back(descr);
continue;
}
}
}
auto max_pri_scalars = max_priority(scalars);
auto max_pri_tensors = max_priority(tensors);
if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
throw py::value_error("invalid input, no dtype avaliable");
}
PyArray_Descr* res;
if (max_pri_scalars > max_pri_tensors) {
res = promote_types(scalars, max_pri_scalars);
} else {
res = promote_types(tensors, max_pri_tensors);
}
for (auto* p : tensors) {
Py_DECREF(p);
}
for (auto* p : scalars) {
Py_DECREF(p);
}
Py_XDECREF(tuple);
return res;
}
CompNode _get_device(PyObject* const* args, size_t nargs) {
bool is_tuple = false;
PyObject* tuple = nullptr;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}
bool valid = false;
CompNode cn;
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle);
bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
if (tw || is_symvar) {
if (!valid) {
cn = tw ? tw->m_tensor->comp_node()
: py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
valid = true;
} else {
CompNode cn1 = tw ? tw->m_tensor->comp_node()
: py::handle(handle)
.cast<PySymbolVar*>()
->m_node->comp_node();
if (cn1 != cn) {
throw py::value_error(ssprintf(
"ambiguous device: %s (from %s) vs %s (from %s)",
cn.to_string().c_str(), cn.to_string_logical().c_str(),
cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
}
}
}
}
if (!valid) {
return CompNode::load(get_default_device());
}
Py_XDECREF(tuple);
return cn;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
PyArray_Descr* res = _dtype_promotion(args, nargs);
return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
CompNode cn = _get_device(args, nargs);
return py::cast(cn).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
......@@ -640,6 +423,9 @@ WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
......@@ -779,6 +565,9 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
MGE_PY_INTERFACE(Const, Const),
MGE_PY_INTERFACE(astype_cpp, astype_cpp),
MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
{nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) {
......
......@@ -52,6 +52,223 @@ namespace views = ranges::views;
namespace mgb::imperative::python {
/* ============== convert inputs ============== */
// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
switch (c) {
case 'f':
return 3; // floating-point
case 'i':
return 2; // signed integer
case 'u':
return 2; // unsigned integer
case 'b':
return 1; // boolean
default:
return 0;
}
}
// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
if (types.size() == 0) {
return 0;
} else {
uint8_t max_p = 0;
for (auto&& desc : types) {
max_p = std::max(max_p, category_priority(desc->kind));
}
return max_p;
}
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
// Return value: New reference
SmallVector<PyArray_Descr*> used_types;
for (auto&& desc : types) {
auto&& v = category_priority(desc->kind);
if (v == cat) {
used_types.emplace_back(desc);
}
}
mgb_assert(used_types.size() > 0, "size of used_types is 0");
PyArray_Descr* res = used_types[0];
Py_INCREF(res);
for (size_t i = 1; i < used_types.size(); ++i) {
PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
Py_DECREF(res);
res = tmp;
}
return res;
}
PyArray_Descr* scalar2dtype(PyObject* arg) {
// Return value: New reference
if (PyBool_Check(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_BOOL);
return descr;
}
if (PyLong_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_INT32);
return descr;
}
if (PyFloat_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
return descr;
}
return nullptr;
}
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
// Return value: New reference
SmallVector<PyArray_Descr*> tensors;
SmallVector<PyArray_Descr*> scalars;
bool is_tuple = false;
PyObject* tuple = nullptr;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
if (handle == Py_None)
continue;
TensorWrapper* tw = TensorWrapper::try_cast(handle);
if (tw) {
mgb::DType type = tw->m_tensor->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
} else {
if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
auto&& descr = PyArray_DescrFromObject(handle, nullptr);
tensors.emplace_back(descr);
continue;
}
if (py::isinstance<PySymbolVar>(py::handle(handle))) {
auto var = py::handle(handle).cast<PySymbolVar*>();
mgb::DType type = var->m_node->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
continue;
}
PyArray_Descr* descr = scalar2dtype(handle);
if (descr) {
scalars.emplace_back(descr);
continue;
}
}
}
auto max_pri_scalars = max_priority(scalars);
auto max_pri_tensors = max_priority(tensors);
if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
throw py::value_error("invalid input, no dtype avaliable");
}
PyArray_Descr* res;
if (max_pri_scalars > max_pri_tensors) {
res = promote_types(scalars, max_pri_scalars);
} else {
res = promote_types(tensors, max_pri_tensors);
}
for (auto* p : tensors) {
Py_DECREF(p);
}
for (auto* p : scalars) {
Py_DECREF(p);
}
Py_XDECREF(tuple);
return res;
}
CompNode _get_device(PyObject* const* args, size_t nargs) {
bool is_tuple = false;
PyObject* tuple = nullptr;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}
bool valid = false;
CompNode cn;
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle);
bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
if (tw || is_symvar) {
if (!valid) {
cn = tw ? tw->m_tensor->comp_node()
: py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
valid = true;
} else {
CompNode cn1 = tw ? tw->m_tensor->comp_node()
: py::handle(handle)
.cast<PySymbolVar*>()
->m_node->comp_node();
if (cn1 != cn) {
throw py::value_error(ssprintf(
"ambiguous device: %s (from %s) vs %s (from %s)",
cn.to_string().c_str(), cn.to_string_logical().c_str(),
cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
}
}
}
}
if (!valid) {
return CompNode::load(get_default_device());
}
Py_XDECREF(tuple);
return cn;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
PyArray_Descr* res = _dtype_promotion(args, nargs);
return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
CompNode cn = _get_device(args, nargs);
return py::cast(cn).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
bool is_scalar(PyObject* tensor) {
if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
auto var = py::handle(tensor).cast<PySymbolVar*>();
......@@ -147,7 +364,6 @@ py::object _Const(
"dmap_callback");
if (dmap.ptr() != Py_None) {
device_obj = dmap(device);
py::print(device_obj);
} else {
device_obj = py::cast(CompNode::load(device.cast<std::string>()));
}
......@@ -1072,6 +1288,92 @@ py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
return ret[0];
}
mgb::DType _get_dtype(py::handle tensor) {
if (auto tw = TensorWrapper::try_cast(tensor.ptr())) {
return tw->m_tensor->dtype();
} else {
auto var = tensor.cast<PySymbolVar*>();
return var->m_node->dtype();
}
}
py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
PyArray_Descr* descr;
if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
throw py::value_error(ssprintf(
"can not convert to numpy.dtype from %s",
dtype_hdl.ptr()->ob_type->tp_name));
}
PyArray_Descr* cur = npy::dtype_mgb2np_descr(_get_dtype(tensor)).get();
if (!dtype_equal(cur, descr)) {
std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
py::object Op = py::cast(op);
std::vector<PyObject*> p;
p.resize(2);
p[0] = Op.ptr();
p[1] = tensor.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
} else {
return py::reinterpret_borrow<py::object>(tensor);
}
}
py::object _convert_single_value_cpp(
py::handle value, py::handle dtype, py::handle device) {
if (is_tensor_or_symbolvar(value)) {
if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
return _astype_cpp(value, dtype);
}
} else {
return _Const(value, dtype, device, py::none());
}
return py::reinterpret_borrow<py::object>(value);
}
py::object _convert_inputs_cpp(
PyObject* const* args, size_t nargs, py::object dtype, py::object device) {
ComputingGraph* graph = nullptr;
py::handle typeobj;
py::list lis;
for (size_t i = 0; i < nargs; ++i) {
py::handle h = py::handle(args[i]);
lis.append(h);
if (py::isinstance<PySymbolVar>(h)) {
auto var = h.cast<PySymbolVar*>();
auto g = var->m_node->owner_graph();
if (!graph) {
graph = g;
typeobj = h.get_type();
} else {
mgb_assert(graph == g);
}
}
}
if (graph) {
CompNode cn = device.cast<CompNode>();
for (size_t i = 0; i < nargs; ++i) {
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
if (py::isinstance<PySymbolVar>(lis[i])) {
lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
}
}
auto convert = [&](py::object value) {
if (value.ptr() == Py_None) {
return value;
}
return _convert_single_value_cpp(value, dtype, device);
};
for (size_t i = 0; i < lis.size(); ++i) {
lis[i] = convert(lis[i]);
}
return py::reinterpret_steal<py::tuple>(PyList_AsTuple(lis.ptr()));
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(py::handle(args[0])).release().ptr();
......@@ -1152,4 +1454,38 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _astype_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* convert_single_value_cpp(
PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _convert_single_value_cpp(
py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
.release()
.ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
py::object dtype = py::reinterpret_steal<py::object>(
dtype_promotion(self, args, nargs - 1));
py::object device;
if (args[nargs - 1] == Py_None) {
device = py::reinterpret_steal<py::object>(
get_device(self, args, nargs - 1));
} else {
device = py::reinterpret_borrow<py::object>(args[nargs - 1]);
}
return _convert_inputs_cpp(args, nargs - 1, dtype, device).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python
......@@ -2,6 +2,10 @@
namespace mgb::imperative::python {
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
......@@ -22,4 +26,10 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册