Commit 2f3bc2db authored by Megvii Engine Team

perf(mge/utils): move astensor1d into C++

GitOrigin-RevId: e7c6659020d9db5a4b17f2a35a40ea2a99f7330d
Parent fa62f6c0
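The port replaces the pure-Python astensor1d helper with a C++ implementation (astensor1d_cpp) while keeping the Python signature. A rough sketch of the user-visible contract the C++ version must preserve — a hypothetical REPL session, assuming a build where the internal helper is importable at the path shown in this diff:

# Hedged sketch; astensor1d is an internal helper, import path per this diff.
import numpy as np
from megengine import Tensor
from megengine.core.tensor.utils import astensor1d

ref = Tensor([0.0])                             # reference tensor supplies default dtype/device
astensor1d((2, 3), ref, dtype="int32")          # tuple of ints -> new 1-d tensor
astensor1d([Tensor(2), 3], ref, dtype="int32")  # list containing tensors -> concat path
astensor1d(Tensor([2, 3]), ref)                 # tensor input is returned as is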
......@@ -9,6 +9,8 @@
import os
from ._imperative_rt.core2 import set_cpp_use_symbolic_shape
_use_symbolic_shape = False
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
_use_symbolic_shape = True
......@@ -25,3 +27,6 @@ def set_symbolic_shape(option: bool):
_org = _use_symbolic_shape
_use_symbolic_shape = option
return _org
set_cpp_use_symbolic_shape(use_symbolic_shape)
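The hunk above mirrors the Python-side symbolic-shape flag into C++ at import time. For reference, the flag is toggled through set_symbolic_shape, which returns the previous value — a hedged sketch, module path inferred from this diff:

# Hedged usage sketch of the flag mirrored into C++ above.
from megengine.core._trace_option import set_symbolic_shape, use_symbolic_shape

prev = set_symbolic_shape(True)   # returns the previous setting
assert use_symbolic_shape() is True
set_symbolic_shape(prev)          # restore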
......@@ -22,12 +22,12 @@ from .._imperative_rt.core2 import (
astype_cpp,
broadcast_cpp,
dtype_promotion,
getitem_cpp,
)
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp
from ..ops import builtin
from . import amp
from .indexing import getitem, setitem
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph
_ElwMod = builtin.Elemwise.Mode
......@@ -544,11 +544,11 @@ class ArrayMethodMixin(abc.ABC):
yield self[i]
def __getitem__(self, index):
return getitem(self, index)
return getitem_cpp(self, index)
def __setitem__(self, index, value):
if index is not Ellipsis:
value = setitem(self, index, value)
value = setitem_cpp(self, index, value)
self._reset(value)
__contains__ = _todo
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .._imperative_rt.core2 import (
getitem_cpp,
set_cpp_astensor1d,
set_cpp_use_symbolic_shape,
setitem_cpp,
)
from .._trace_option import use_symbolic_shape
from .utils import astensor1d
def getitem(tensor, index):
return getitem_cpp(tensor, index)
def setitem(tensor, index, value):
return setitem_cpp(tensor, index, value)
set_cpp_use_symbolic_shape(use_symbolic_shape)
set_cpp_astensor1d(astensor1d)
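indexing.py is now a thin shim: getitem/setitem keep their old names for existing call sites but delegate to the C++ implementations, and the astensor1d callback is registered for the C++ side. Both spellings below should therefore be equivalent (hedged sketch, import path per this diff):

# Hedged equivalence sketch for the shim above.
from megengine import Tensor
from megengine.core.tensor.indexing import getitem, setitem

x = Tensor([[1, 2], [3, 4]])
assert getitem(x, (0, 1)).item() == x[0, 1].item() == 2
y = setitem(x, (0, 1), 9)   # returns the updated tensor; __setitem__ then rebinds via _reset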
......@@ -20,6 +20,7 @@ from .._imperative_rt.core2 import (
_get_convert_inputs,
_set_convert_inputs,
apply,
astensor1d_cpp,
astype_cpp,
convert_inputs_cpp,
convert_single_value_cpp,
......@@ -50,14 +51,6 @@ def set_convert_inputs(flag):
return _set_convert_inputs(flag)
def concatenate(inputs, axis=0, *, device=None):
inputs = convert_inputs(*inputs)
if device is None:
device = get_device(inputs)
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs)
return result
def convert_single_value(v, *, dtype=None, device=None):
return convert_single_value_cpp(v, dtype, device)
......@@ -104,34 +97,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
* numpy array
* tensor (returned as is, regardless of dtype and device)
"""
try:
ndim = x.ndim
except AttributeError:
pass
except ValueError:
if dtype is not None and dtype != x.dtype:
x = astype_cpp(x, dtype)
if device is not None:
cn = as_device(device).to_c()
(x,) = apply(builtin.Copy(comp_node=cn), x)
return x
else:
if ndim != 0 and ndim != 1:
raise ValueError("ndim != 1 or 0, get : %d" % ndim)
if not isinstance(x, (Tensor, SymbolVar)):
x = Const(x, dtype, device, reference)
return x
if not isinstance(x, collections.abc.Sequence):
raise TypeError
if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
x = concatenate(x, device=device) if len(x) > 1 else x[0]
if dtype is not None:
x = astype_cpp(x, dtype)
return x
x = Const(x, dtype, device, reference)
return x
return astensor1d_cpp(x, dtype, device, reference)
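The removed Python body above doubles as the specification for astensor1d_cpp: inputs with an ndim attribute must be 0-d or 1-d, sequences containing tensors go through the concat path, and everything else must be a sequence. A hedged sketch of the edge cases, continuing the earlier snippet (ref, np, astensor1d as defined there):

# Hedged edge-case sketch derived from the removed body above.
astensor1d(np.array(5), ref)        # 0-d ndarray: wrapped as a constant tensor
astensor1d(np.ones((2, 2)), ref)    # ndim == 2 -> ValueError("ndim != 1 or 0 ...")
astensor1d(object(), ref)           # neither ndim nor a sequence -> TypeError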
def _normalize_axis(
......
......@@ -104,13 +104,12 @@ struct SymbolVarContext {
interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
PyObject *cpp_use_symbolic_shape, *cpp_astensor1d;
PyObject* cpp_use_symbolic_shape;
#define REGISTE_APPLY_FUNC(mode) \
void set_##mode(py::object pyf) { mode = pyf.ptr(); }
REGISTE_APPLY_FUNC(cpp_use_symbolic_shape)
REGISTE_APPLY_FUNC(cpp_astensor1d)
#undef REGISTE_APPLY_FUNC
......@@ -426,6 +425,7 @@ WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
......@@ -568,6 +568,7 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(astype_cpp, astype_cpp),
MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
{nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) {
......@@ -957,8 +958,6 @@ void init_tensor(py::module m) {
m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);
m.def("set_cpp_astensor1d", &set_cpp_astensor1d);
m.def("set_module_tracing", [=] { get_module_trace()->enable(); });
m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
......
......@@ -310,6 +310,27 @@ bool is_bool_dtype(PyObject* args) {
return ret;
}
py::object device2obj(py::handle device, bool mapping = false) {
if (device.ptr() == Py_None) {
return py::cast(CompNode::load(get_default_device()));
} else if (py::isinstance<py::str>(device)) {
if (mapping) {
py::object dmap = getattr(
py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type),
"dmap_callback");
if (dmap.ptr() != Py_None) {
return device2obj(dmap(device), false);
}
}
return py::cast(CompNode::load(device.cast<std::string>()));
} else if (py::isinstance<CompNode>(device)) {
return py::reinterpret_borrow<py::object>(device);
} else {
return getattr(device, "_cn");
}
}
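device2obj centralizes the device-argument handling that _Const previously inlined: None falls back to the default device, a string is loaded as a CompNode (optionally remapped through the tensor type's dmap_callback when mapping is true), a CompNode passes through, and any other object must expose a _cn attribute. A hedged Python-side illustration of the accepted forms, reusing names from the earlier snippets:

# Hedged illustration of the device forms device2obj accepts.
astensor1d([1, 2], ref, device=None)        # None -> get_default_device()
astensor1d([1, 2], ref, device="xpux")      # string, may pass through dmap_callback
astensor1d([1, 2], ref, device=ref.device)  # CompNode (or an object exposing _cn)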
py::object _Const(
py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) {
py::object val = py::reinterpret_borrow<py::object>(value);
......@@ -347,7 +368,7 @@ py::object _Const(
if (device.ptr() == Py_None) {
cn = ref_var->m_node->comp_node();
} else {
cn = device.cast<CompNode>();
cn = device2obj(device).cast<CompNode>();
}
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
......@@ -355,23 +376,7 @@ py::object _Const(
auto typeobj = ref.get_type();
return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
py::object device_obj;
if (device.ptr() == Py_None) {
device_obj = py::cast(CompNode::load(get_default_device()));
} else if (py::isinstance<py::str>(device)) {
py::object dmap =
getattr(py::reinterpret_borrow<py::object>((PyObject*)py_tensor_type),
"dmap_callback");
if (dmap.ptr() != Py_None) {
device_obj = dmap(device);
} else {
device_obj = py::cast(CompNode::load(device.cast<std::string>()));
}
} else if (py::isinstance<CompNode>(device)) {
device_obj = py::reinterpret_borrow<py::object>(device);
} else {
device_obj = getattr(device, "_cn");
}
py::object device_obj = device2obj(device, true);
py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
......@@ -422,6 +427,197 @@ py::tuple _make_shape_tuple(py::handle shape) {
return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr()));
}
bool is_tensor_or_symbolvar(py::handle arg) {
return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg);
}
bool is_py_sequence(py::handle arg) {
if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) ||
py::isinstance<PySymbolVar>(arg)) {
return false;
}
return PySequence_Check(arg.ptr());
}
mgb::DType _get_dtype(py::handle tensor) {
if (auto tw = TensorWrapper::try_cast(tensor.ptr())) {
return tw->m_tensor->dtype();
} else {
auto var = tensor.cast<PySymbolVar*>();
return var->m_node->dtype();
}
}
py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
PyArray_Descr* descr;
if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
throw py::value_error(ssprintf(
"can not convert to numpy.dtype from %s",
dtype_hdl.ptr()->ob_type->tp_name));
}
PyArray_Descr* cur = npy::dtype_mgb2np_descr(_get_dtype(tensor)).get();
if (!dtype_equal(cur, descr)) {
std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
py::object Op = py::cast(op);
std::vector<PyObject*> p;
p.resize(2);
p[0] = Op.ptr();
p[1] = tensor.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
} else {
return py::reinterpret_borrow<py::object>(tensor);
}
}
py::object _convert_single_value_cpp(
py::handle value, py::handle dtype, py::handle device) {
if (is_tensor_or_symbolvar(value)) {
if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
return _astype_cpp(value, dtype);
}
} else {
return _Const(value, dtype, device, py::none());
}
return py::reinterpret_borrow<py::object>(value);
}
py::object _convert_inputs_cpp(
PyObject* const* args, size_t nargs, py::object dtype, py::object device) {
ComputingGraph* graph = nullptr;
py::handle typeobj;
py::list lis;
for (size_t i = 0; i < nargs; ++i) {
py::handle h = py::handle(args[i]);
lis.append(h);
if (py::isinstance<PySymbolVar>(h)) {
auto var = h.cast<PySymbolVar*>();
auto g = var->m_node->owner_graph();
if (!graph) {
graph = g;
typeobj = h.get_type();
} else {
mgb_assert(graph == g);
}
}
}
if (graph) {
CompNode cn = device2obj(device).cast<CompNode>();
for (size_t i = 0; i < nargs; ++i) {
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
if (!py::isinstance<PySymbolVar>(lis[i])) {
lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
}
}
auto convert = [&](py::object value) {
if (value.ptr() == Py_None) {
return value;
}
return _convert_single_value_cpp(value, dtype, device);
};
for (size_t i = 0; i < lis.size(); ++i) {
lis[i] = convert(lis[i]);
}
return py::reinterpret_steal<py::tuple>(PyList_AsTuple(lis.ptr()));
}
py::object _astensor1d_cpp(
py::handle value, py::handle dtype, py::handle device, py::handle ref) {
py::object ret;
py::object device_obj = py::none();
py::object ndim_obj = py::none();
if (device.ptr() != Py_None) {
device_obj = device2obj(device);
}
if (py::isinstance<PySymbolVar>(value)) {
try {
getattr(value, "ndim");
} catch (py::error_already_set& err) {
if (dtype.ptr() != Py_None) {
ret = _astype_cpp(value, dtype);
} else {
ret = py::reinterpret_borrow<py::object>(value);
}
if (device.ptr() != Py_None) {
std::shared_ptr<OpDef> op = Copy::make(device_obj.cast<CompNode>());
py::object Op = py::cast(op);
std::vector<PyObject*> p;
p.resize(2);
p[0] = Op.ptr();
p[1] = ret.ptr();
py::tuple copy_ret = py::reinterpret_steal<py::object>(
py_apply(NULL, p.data(), p.size()));
return copy_ret[0];
}
return ret;
}
}
size_t ndim = 999;
if (hasattr(value, "ndim")) {
ndim = getattr(value, "ndim").cast<size_t>();
if (ndim != 0 && ndim != 1) {
throw py::value_error("ndim != 1 or 0, get : " + std::to_string(ndim));
}
if (!is_tensor_or_symbolvar(value)) {
return _Const(value, dtype, device, ref);
} else {
return py::reinterpret_borrow<py::object>(value);
}
}
if (!is_py_sequence(value)) {
throw py::type_error();
}
py::list lis = py::reinterpret_steal<py::list>(PySequence_List(value.ptr()));
bool need_concat = false;
for (size_t i = 0; i < lis.size(); ++i) {
if (is_tensor_or_symbolvar(lis[i])) {
need_concat = true;
break;
}
}
if (!need_concat) {
return _Const(value, dtype, device, ref);
}
if (lis.size() > 1) {
std::vector<PyObject*> c_args(lis.size() + 1);
for (size_t i = 0; i < lis.size(); ++i) {
c_args[i] = lis[i].ptr();
}
c_args[lis.size()] = Py_None;
py::tuple inp_tup = py::reinterpret_steal<py::tuple>(
convert_inputs_cpp(NULL, c_args.data(), c_args.size()));
if (device_obj.ptr() == Py_None) {
std::vector<PyObject*> inp(inp_tup.size());
for (size_t i = 0; i < inp_tup.size(); ++i) {
inp[i] = inp_tup[i].ptr();
}
device_obj = py::cast(_get_device(inp.data(), inp.size()));
}
std::shared_ptr<OpDef> op = Concat::make(0, device_obj.cast<CompNode>());
py::object Op = py::cast(op);
std::vector<PyObject*> p;
p.resize(inp_tup.size() + 1);
p[0] = Op.ptr();
for (size_t i = 0; i < inp_tup.size(); ++i) {
p[i + 1] = inp_tup[i].ptr();
}
py::tuple concat_ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
ret = concat_ret[0];
} else {
ret = lis[0];
}
if (dtype.ptr() != Py_None) {
return _astype_cpp(ret, dtype);
} else {
return ret;
}
}
py::object _get_index(py::object tensor, py::object src) {
if (!TensorWrapper::try_cast(tensor.ptr()) &&
!py::isinstance<PySymbolVar>(tensor)) {
......@@ -501,7 +697,12 @@ py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
size_t ndim_incr = 1;
if (hasattr(handle, "dtype") && is_bool_dtype(handle.ptr()) &&
hasattr(handle, "ndim")) {
py::object ndim = getattr(handle, "ndim");
py::object ndim;
try {
ndim = getattr(handle, "ndim");
} catch (py::error_already_set& err) {
has_unknown_ndim_bool_index = true;
}
if (PyLong_Check(ndim.ptr())) {
ndim_incr = PyLong_AsLong(ndim.ptr());
} else {
......@@ -540,6 +741,8 @@ py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
}
}
py::object _reshape_cpp(py::handle inp_hdl, py::handle args);
py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
py::tuple cur_shape = _make_shape_tuple(py::handle(getattr(tensor, "shape")));
py::list new_tuple_val(0);
......@@ -556,7 +759,8 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
if (cur_shape[tdim + j - offset].cast<size_t>() !=
ishape[j].cast<size_t>()) {
std::string msg =
"boolean index did not match tensor along dimension " +
"boolean index did not match tensor along "
"dimension " +
std::to_string(tdim + j) + "; dimension is " +
std::to_string(
cur_shape[tdim + j - offset].cast<size_t>()) +
......@@ -580,19 +784,10 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
new_shape.append(cur_shape[j]);
}
py::tuple args = py::make_tuple(new_shape);
PyObject* shape_tensor =
PyObject_CallObject(cpp_astensor1d, args.ptr());
py::object reshape_func = getattr(tensor, "reshape");
Py_INCREF(shape_tensor);
PyObject* Args = PyTuple_New(1);
PyTuple_SetItem(Args, 0, shape_tensor);
PyObject* new_tensor =
PyObject_CallObject(reshape_func.ptr(), Args);
Py_XDECREF(Args);
tensor = py::reinterpret_steal<py::object>(new_tensor);
cur_shape = _make_shape_tuple(py::handle(shape_tensor));
Py_XDECREF(shape_tensor);
py::object shape_tensor = _astensor1d_cpp(
new_shape, py::none(), py::none(), py::none());
tensor = _reshape_cpp(tensor, shape_tensor);
cur_shape = _make_shape_tuple(shape_tensor);
} else {
for (size_t j = 0; j < i; ++j) {
new_shape.append(cur_shape[j]);
......@@ -602,7 +797,7 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
new_shape.append(cur_shape[j]);
}
cur_shape = new_shape;
tensor = getattr(tensor, "reshape")(cur_shape);
tensor = _reshape_cpp(tensor, cur_shape);
}
offset++;
tdim += ndim;
......@@ -616,6 +811,18 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
return py::make_tuple(tensor, py::reinterpret_borrow<py::tuple>(new_tuple_val));
}
std::pair<size_t, bool> get_ndim_safe(py::handle tensor) {
if (auto p = TensorWrapper::try_cast(tensor.ptr())) {
return {p->m_tensor->shape()->ndim, true};
}
try {
return {getattr(tensor, "ndim").cast<size_t>(), true};
} catch (py::error_already_set& err) {
return {0, false};
}
}
py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
py::tuple tuple_val;
......@@ -637,7 +844,7 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
need_remove_ellipsis = true;
} else {
if (is_bool_dtype(k.ptr()) && hasattr(k, "ndim")) {
size_t ndim = getattr(k, "ndim").cast<size_t>();
size_t ndim = get_ndim_safe(k).first;
idx_ndim += ndim;
if (ndim > 1) {
need_expand_bool_dim = true;
......@@ -712,87 +919,266 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
return py::make_tuple(inp, tensors, items, use_subtensor, need_expand_bool_dim);
}
py::object _expand_args(py::handle args) {
    if (!PyTuple_Check(args.ptr())) {
        return py::reinterpret_borrow<py::object>(args);
    }
    py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
    if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
                                 is_tensor_or_symbolvar(args_tup[0].ptr()))) {
        return py::reinterpret_borrow<py::object>(args_tup[0]);
    } else {
        return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
    }
}

std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
    std::vector<int32_t> shp;
    if (!PyTuple_Check(shape.ptr())) {
        return {shp, false};
    }
    py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
    for (size_t i = 0; i < tup.size(); ++i) {
        if (!PyLong_Check(tup[i].ptr())) {
            return {shp, false};
        } else {
            shp.push_back(tup[i].cast<int32_t>());
        }
    }
    return {shp, true};
}

bool enable_fastpath(py::handle inp) {
    if (!TensorWrapper::try_cast(inp.ptr()) ||
        TransformationManager::get_instance()
                        .segments[TransformationManager::Segment::Trace]
                        .size() > 0 ||
        TransformationManager::get_instance()
                        .segments[TransformationManager::Segment::ModuleTrace]
                        .size() > 0) {
        return false;
    }
    return true;
}
py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
py::object shape_hdl = _expand_args(args);
bool auto_infer = false;
py::list lis;
py::list new_shape;
if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
for (size_t i = 0; i < lis.size(); ++i) {
if (lis[i].ptr() == Py_None) {
auto_infer = true;
size_t right = lis.size() - i;
py::object tshp = getattr(inp_hdl, "_tuple_shape");
if (tshp.ptr() == Py_None) {
throw py::index_error("does not support `None` with unknown shape");
}
py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
if (inp_shape.size() >= right) {
if (enable_fastpath(inp_hdl)) {
lis[i] = inp_shape[inp_shape.size() - right];
}
new_shape.append(inp_shape[inp_shape.size() - right]);
} else {
throw py::value_error("invalid broadcast shape");
}
} else {
new_shape.append(lis[i]);
if (PyLong_Check(lis[i].ptr())) {
int32_t s = lis[i].cast<int32_t>();
if (s < 0) {
throw py::value_error(
"expect shape[" + std::to_string(i) +
"] >= 0 or use `None` to auto infer, got " +
std::to_string(s));
}
}
}
}
}
if (auto_infer) {
if (enable_fastpath(inp_hdl)) {
shape_hdl = py::reinterpret_borrow<py::tuple>(lis);
} else {
shape_hdl = _astensor1d_cpp(
new_shape, py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device"), inp_hdl);
}
}
py::object shape_tuple;
try {
shape_tuple = _make_shape_tuple(shape_hdl);
} catch (py::error_already_set& err) {
shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
}
auto [shape, fastpath] = tuple2vector(shape_tuple);
fastpath &= enable_fastpath(inp_hdl);
std::shared_ptr<OpDef> op;
std::vector<PyObject*> p;
py::object shape_tensor;
if (fastpath) {
op = Broadcast::make(shape);
p.resize(2);
} else {
op = Broadcast::make();
shape_tensor = _astensor1d_cpp(
shape_hdl, py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device"), inp_hdl);
p.resize(3);
p[2] = shape_tensor.ptr();
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
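_broadcast_cpp keeps the two paths of the old Python helper: a fast path when the target shape is a static tuple of ints (and no trace/module-trace transformation is active), and a tensor-shape path that now builds the shape via _astensor1d_cpp instead of calling back into Python. None entries are auto-inferred from the input's trailing dimensions. A hedged user-level sketch:

# Hedged user-level sketch of the None auto-infer handled above.
import numpy as np
import megengine.functional as F
from megengine import Tensor

x = Tensor(np.ones((1, 3), dtype="float32"))
y = F.broadcast_to(x, (2, None))      # None -> inferred as 3 from the input
assert y.numpy().shape == (2, 3)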
py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
py::object shape_hdl = _expand_args(args);
py::object shape_tuple;
try {
shape_tuple = _make_shape_tuple(shape_hdl);
} catch (py::error_already_set& err) {
shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
}
int32_t unspec_axis = -1;
if (PyTuple_Check(shape_tuple.ptr())) {
py::tuple tup = py::reinterpret_borrow<py::tuple>(shape_tuple);
for (size_t i = 0; i < tup.size(); ++i) {
py::object obj = py::reinterpret_borrow<py::object>(tup[i]);
if (obj < py::int_(0)) {
if (obj.not_equal(py::int_(-1))) {
throw py::value_error(
"expect shape [" + std::to_string(i) + "] >= -1, got " +
repr(obj).cast<std::string>());
}
if (unspec_axis >= 0) {
throw py::value_error(
"multiple -1 in shape: " + std::to_string(unspec_axis) +
" & " + std::to_string(i));
}
unspec_axis = i;
}
}
}
auto [shape, fastpath] = tuple2vector(shape_tuple);
fastpath &= enable_fastpath(inp_hdl);
std::shared_ptr<OpDef> op;
std::vector<PyObject*> p;
py::object shape_tensor;
if (fastpath) {
if (unspec_axis >= 0) {
op = Reshape::make(unspec_axis, shape);
} else {
op = Reshape::make(::megdnn::param::OptionalAxisV1::INVALID_AXIS, shape);
}
p.resize(2);
} else {
shape.clear();
if (unspec_axis >= 0) {
op = Reshape::make(unspec_axis, shape);
} else {
op = Reshape::make();
}
shape_tensor = _astensor1d_cpp(
shape_hdl, py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device"), inp_hdl);
p.resize(3);
p[2] = shape_tensor.ptr();
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
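_reshape_cpp similarly validates the target shape up front: at most one -1 (recorded as the unspec axis), all other entries non-negative, with the same fast/slow path split as broadcast. A hedged sketch:

# Hedged sketch of the -1 (unspec axis) handling implemented above.
import numpy as np
from megengine import Tensor

x = Tensor(np.arange(6, dtype="float32"))
assert x.reshape(2, -1).numpy().shape == (2, 3)   # -1 resolved to 3
# x.reshape(-1, -1) would raise: "multiple -1 in shape: 0 & 1"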
py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
if (try_res.size() == 2) {
return try_res[0];
}
py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
for (size_t i = 0; i < py_items.size(); ++i) {
py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
cpp_items.push_back(
{item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
item[3].cast<bool>(), item[4].cast<bool>()});
}
static std::shared_ptr<OpDef> op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
} else {
op = IndexingMultiAxisVec::make(cpp_items);
}
std::vector<PyObject*> p;
p.resize(tensors.size() + 2);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
for (size_t i = 0; i < tensors.size(); ++i) {
p[i + 2] = tensors[i].ptr();
}
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
py::object org_shape = getattr(inp_hdl, "shape");
py::object val = py::reinterpret_borrow<py::object>(val_hdl);
if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) {
val =
_Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"),
inp_hdl);
}
py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
for (size_t i = 0; i < py_items.size(); ++i) {
py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
cpp_items.push_back(
{item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
item[3].cast<bool>(), item[4].cast<bool>()});
}
static std::shared_ptr<OpDef> op, set_op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
} else {
op = IndexingMultiAxisVec::make(cpp_items);
}
std::vector<PyObject*> p;
p.resize(tensors.size() + 2);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
for (size_t i = 0; i < tensors.size(); ++i) {
p[i + 2] = tensors[i].ptr();
}
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
py::object tmp_result = ret[0];
try {
py::tuple value_shape =
py::reinterpret_borrow<py::tuple>(val.attr("_tuple_shape"));
py::tuple tmp_result_shape =
py::reinterpret_borrow<py::tuple>(tmp_result.attr("_tuple_shape"));
for (size_t i = 0; i < value_shape.size() && i < tmp_result_shape.size(); ++i) {
size_t vs = value_shape[value_shape.size() - i - 1].cast<size_t>();
size_t ts =
tmp_result_shape[tmp_result_shape.size() - i - 1].cast<size_t>();
if (vs != 1 && vs != ts) {
......@@ -815,14 +1201,7 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
} catch (py::error_already_set& err) {
;
}
py::object broadcast_func = getattr(val, "_broadcast");
PyObject* Args = PyTuple_New(1);
PyTuple_SetItem(Args, 0, getattr(tmp_result, "shape").release().ptr());
PyObject* new_val = PyObject_CallObject(broadcast_func.ptr(), Args);
Py_XDECREF(Args);
val = py::reinterpret_steal<py::object>(new_val);
val = _broadcast_cpp(val, getattr(tmp_result, "shape"));
if (up[3].cast<bool>()) {
set_op = SetSubtensor::make(cpp_items);
} else {
......@@ -843,29 +1222,12 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
py::object res = result[0];
if (up[4].cast<bool>()) {
py::object reshape_func = getattr(res, "reshape");
PyObject* Args = PyTuple_New(1);
PyTuple_SetItem(Args, 0, org_shape.release().ptr());
PyObject* new_tensor = PyObject_CallObject(reshape_func.ptr(), Args);
Py_XDECREF(Args);
res = py::reinterpret_steal<py::object>(new_tensor);
res = _reshape_cpp(res, org_shape);
}
return res;
}
bool is_tensor_or_symbolvar(py::handle arg) {
return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg);
}
bool is_py_sequence(py::handle arg) {
if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) ||
py::isinstance<PySymbolVar>(arg)) {
return false;
}
return PySequence_Check(arg.ptr());
}
py::object _split_cpp(
py::handle inp_hdl, py::handle nsplits_or_sections_hdl, py::handle axis_hdl) {
py::object shape_obj = getattr(inp_hdl, "shape");
......@@ -936,7 +1298,7 @@ py::object _split_cpp(
std::vector<int32_t> list2vector(py::handle li) {
std::vector<int32_t> axis;
if (is_py_sequence(li.ptr())) {
if (is_py_sequence(li)) {
py::list tmp_list = py::reinterpret_steal<py::list>(PySequence_List(li.ptr()));
for (size_t i = 0; i < tmp_list.size(); ++i) {
axis.push_back(tmp_list[i].attr("__int__")().cast<int32_t>());
......@@ -958,13 +1320,9 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
ndim += shape->ndim;
}
} else {
auto&& var = inp_hdl.cast<PySymbolVar*>();
auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
auto&& shape = mgr.infer_shape_fallible(var->m_node);
if (shape) {
unknown_ndim = false;
ndim += shape->ndim;
}
auto&& inp_ndim = get_ndim_safe(inp_hdl);
ndim += inp_ndim.first;
unknown_ndim &= ~inp_ndim.second;
}
for (size_t i = 0; i < axis.size(); ++i) {
if (axis[i] < 0) {
......@@ -1010,20 +1368,17 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
}
}
} else {
auto&& var = inp_hdl.cast<PySymbolVar*>();
auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
auto&& shape = mgr.infer_shape_fallible(var->m_node);
if (shape) {
ndim = shape->ndim;
py::tuple shape =
py::reinterpret_borrow<py::tuple>(getattr(inp_hdl, "_tuple_shape"));
ndim = shape.size();
if (axis_hdl.ptr() == Py_None) {
for (size_t i = 0; i < shape->ndim; ++i) {
if (shape->shape[i] == 1) {
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i].cast<size_t>() == 1) {
axis.push_back(i);
}
}
}
}
}
for (size_t i = 0; i < axis.size(); ++i) {
if (axis[i] < 0) {
axis[i] += static_cast<int32_t>(ndim);
......@@ -1043,27 +1398,6 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
size_t fast_ndim(py::handle tensor) {
if (auto p = TensorWrapper::try_cast(tensor.ptr())) {
return p->m_tensor->shape()->ndim;
}
return getattr(tensor, "ndim").cast<size_t>();
}
py::object _expand_args(py::handle args) {
if (!PyTuple_Check(args.ptr())) {
return py::reinterpret_borrow<py::object>(args);
}
py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
is_tensor_or_symbolvar(args_tup[0].ptr()))) {
return py::reinterpret_borrow<py::object>(args_tup[0]);
} else {
return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
}
}
py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py::object obj = _expand_args(args);
py::list lis;
......@@ -1077,7 +1411,7 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
lis = py::reinterpret_steal<py::list>(maybe_list);
}
}
if (fast_ndim(inp_hdl) == 0) {
if (get_ndim_safe(inp_hdl).first == 0) {
if (lis.size() != 0) {
throw py::index_error(
"transpose for scalar does not accept additional args");
......@@ -1112,351 +1446,79 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
return ret[0];
}
std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
std::vector<int32_t> shp;
if (!PyTuple_Check(shape.ptr())) {
return {shp, false};
}
py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
for (size_t i = 0; i < tup.size(); ++i) {
if (!PyLong_Check(tup[i].ptr())) {
return {shp, false};
} else {
shp.push_back(tup[i].cast<int32_t>());
}
}
return {shp, true};
}
bool enable_fastpath(py::handle inp) {
if (!TensorWrapper::try_cast(inp.ptr()) ||
TransformationManager::get_instance()
.segments[TransformationManager::Segment::Trace]
.size() > 0 ||
TransformationManager::get_instance()
.segments[TransformationManager::Segment::ModuleTrace]
.size() > 0) {
return false;
}
return true;
}
py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
py::object shape_hdl = _expand_args(args);
bool auto_infer = false;
py::list lis;
py::list new_shape;
if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
for (size_t i = 0; i < lis.size(); ++i) {
if (lis[i].ptr() == Py_None) {
auto_infer = true;
size_t right = lis.size() - i;
py::object tshp = getattr(inp_hdl, "_tuple_shape");
if (tshp.ptr() == Py_None) {
throw py::index_error("does not support `None` with unknown shape");
}
py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
if (inp_shape.size() >= right) {
if (enable_fastpath(inp_hdl)) {
lis[i] = inp_shape[inp_shape.size() - right];
}
new_shape.append(inp_shape[inp_shape.size() - right]);
} else {
throw py::value_error("invalid broadcast shape");
}
} else {
new_shape.append(lis[i]);
if (PyLong_Check(lis[i].ptr())) {
int32_t s = lis[i].cast<int32_t>();
if (s < 0) {
throw py::value_error(
"expect shape[" + std::to_string(i) +
"] >= 0 or use `None` to auto infer, got " +
std::to_string(s));
}
}
}
}
}
if (auto_infer) {
if (enable_fastpath(inp_hdl)) {
shape_hdl = py::reinterpret_borrow<py::tuple>(lis);
} else {
py::tuple args = py::make_tuple(new_shape, inp_hdl);
py::dict kwargs;
kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
kwargs["device"] = getattr(inp_hdl, "device");
shape_hdl = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
}
}
py::object shape_tuple;
try {
shape_tuple = _make_shape_tuple(shape_hdl);
} catch (py::error_already_set& err) {
shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
}
auto [shape, fastpath] = tuple2vector(shape_tuple);
fastpath &= enable_fastpath(inp_hdl);
std::shared_ptr<OpDef> op;
std::vector<PyObject*> p;
py::object shape_tensor;
if (fastpath) {
op = Broadcast::make(shape);
p.resize(2);
} else {
op = Broadcast::make();
py::tuple args = py::make_tuple(shape_hdl, inp_hdl);
py::dict kwargs;
kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
kwargs["device"] = getattr(inp_hdl, "device");
shape_tensor = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
p.resize(3);
p[2] = shape_tensor.ptr();
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
py::object shape_hdl = _expand_args(args);
py::object shape_tuple;
try {
shape_tuple = _make_shape_tuple(shape_hdl);
} catch (py::error_already_set& err) {
shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
}
int32_t unspec_axis = -1;
if (PyTuple_Check(shape_tuple.ptr())) {
py::tuple tup = py::reinterpret_borrow<py::tuple>(shape_tuple);
for (size_t i = 0; i < tup.size(); ++i) {
py::object obj = py::reinterpret_borrow<py::object>(tup[i]);
if (obj < py::int_(0)) {
if (obj.not_equal(py::int_(-1))) {
throw py::value_error(
"expect shape [" + std::to_string(i) + "] >= -1, got " +
repr(obj).cast<std::string>());
}
if (unspec_axis >= 0) {
throw py::value_error(
"multiple -1 in shape: " + std::to_string(unspec_axis) +
" & " + std::to_string(i));
}
unspec_axis = i;
}
}
}
auto [shape, fastpath] = tuple2vector(shape_tuple);
fastpath &= enable_fastpath(inp_hdl);
std::shared_ptr<OpDef> op;
std::vector<PyObject*> p;
py::object shape_tensor;
if (fastpath) {
if (unspec_axis >= 0) {
op = Reshape::make(unspec_axis, shape);
} else {
op = Reshape::make(::megdnn::param::OptionalAxisV1::INVALID_AXIS, shape);
}
p.resize(2);
} else {
shape.clear();
if (unspec_axis >= 0) {
op = Reshape::make(unspec_axis, shape);
} else {
op = Reshape::make();
}
py::tuple args = py::make_tuple(shape_hdl, inp_hdl);
py::dict kwargs;
kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
kwargs["device"] = getattr(inp_hdl, "device");
shape_tensor = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
p.resize(3);
p[2] = shape_tensor.ptr();
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
mgb::DType _get_dtype(py::handle tensor) {
if (auto tw = TensorWrapper::try_cast(tensor.ptr())) {
return tw->m_tensor->dtype();
} else {
auto var = tensor.cast<PySymbolVar*>();
return var->m_node->dtype();
}
}
py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
PyArray_Descr* descr;
if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
throw py::value_error(ssprintf(
"can not convert to numpy.dtype from %s",
dtype_hdl.ptr()->ob_type->tp_name));
}
PyArray_Descr* cur = npy::dtype_mgb2np_descr(_get_dtype(tensor)).get();
if (!dtype_equal(cur, descr)) {
std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
py::object Op = py::cast(op);
std::vector<PyObject*> p;
p.resize(2);
p[0] = Op.ptr();
p[1] = tensor.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
} else {
return py::reinterpret_borrow<py::object>(tensor);
}
}
py::object _convert_single_value_cpp(
py::handle value, py::handle dtype, py::handle device) {
if (is_tensor_or_symbolvar(value)) {
if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
return _astype_cpp(value, dtype);
}
} else {
return _Const(value, dtype, device, py::none());
}
return py::reinterpret_borrow<py::object>(value);
}
py::object _convert_inputs_cpp(
PyObject* const* args, size_t nargs, py::object dtype, py::object device) {
ComputingGraph* graph = nullptr;
py::handle typeobj;
py::list lis;
for (size_t i = 0; i < nargs; ++i) {
py::handle h = py::handle(args[i]);
lis.append(h);
if (py::isinstance<PySymbolVar>(h)) {
auto var = h.cast<PySymbolVar*>();
auto g = var->m_node->owner_graph();
if (!graph) {
graph = g;
typeobj = h.get_type();
} else {
mgb_assert(graph == g);
}
}
}
if (graph) {
CompNode cn = device.cast<CompNode>();
for (size_t i = 0; i < nargs; ++i) {
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
if (py::isinstance<PySymbolVar>(lis[i])) {
lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
}
}
auto convert = [&](py::object value) {
if (value.ptr() == Py_None) {
return value;
}
return _convert_single_value_cpp(value, dtype, device);
};
for (size_t i = 0; i < lis.size(); ++i) {
lis[i] = convert(lis[i]);
}
return py::reinterpret_steal<py::tuple>(PyList_AsTuple(lis.ptr()));
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(py::handle(args[0])).release().ptr();
return _make_shape_tuple(args[0]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _getitem_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
return _getitem_cpp(args[0], args[1]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _setitem_cpp(
py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
.release()
.ptr();
return _setitem_cpp(args[0], args[1], args[2]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _split_cpp(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
.release()
.ptr();
return _split_cpp(args[0], args[1], args[2]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _expand_dims_cpp(py::handle(args[0]), py::handle(args[1]))
.release()
.ptr();
return _expand_dims_cpp(args[0], args[1]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _squeeze_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
return _squeeze_cpp(args[0], args[1]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _transpose_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
return _transpose_cpp(args[0], args[1]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _broadcast_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
return _broadcast_cpp(args[0], args[1]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _reshape_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
return _reshape_cpp(args[0], args[1]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _Const(py::handle(args[0]), py::handle(args[1]), py::handle(args[2]),
py::handle(args[3]))
.release()
.ptr();
return _Const(args[0], args[1], args[2], args[3]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _astype_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
return _astype_cpp(args[0], args[1]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
......@@ -1464,10 +1526,7 @@ PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PyObject* convert_single_value_cpp(
PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _convert_single_value_cpp(
py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
.release()
.ptr();
return _convert_single_value_cpp(args[0], args[1], args[2]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
......@@ -1488,4 +1547,11 @@ PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _astensor1d_cpp(args[0], args[1], args[2], args[3]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python
......@@ -32,4 +32,6 @@ PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file
......@@ -511,6 +511,20 @@ def test_advance_indexing_with_bool(test_varnode):
network = Network()
else:
network = None
a = np.array([[True, False], [False, True]])
b = np.array([1])
aa = make_tensor(a, network)
bb = make_tensor(b, network)
np.testing.assert_equal(a[b], get_value(aa[bb]))
b = np.array([[True, True], [False, True]])
bb = make_tensor(b, network)
np.testing.assert_equal(a[b], get_value(aa[bb]))
if not test_varnode:
a[b] = False
aa[bb] = False
np.testing.assert_equal(a, get_value(aa))
a = np.arange(9).reshape(3, 3).astype(np.float32)
b = np.array([1, 2, 3])
c = np.array([1, 2, 3])
......@@ -525,67 +539,68 @@ def test_advance_indexing_with_bool(test_varnode):
a = np.arange(9).reshape(3, 3).astype(np.float32)
b = np.array([False, True, True])
c = np.array([2, 0]).astype(np.int32)
aa = Tensor(a)
bb = Tensor(b)
cc = Tensor(c)
np.testing.assert_equal(a[b, c], aa[bb, cc].numpy())
aa = make_tensor(a, network)
bb = make_tensor(b, network)
cc = make_tensor(c, network)
np.testing.assert_equal(a[b, c], get_value(aa[bb, cc]))
a[b, c] = -1.0
aa[bb, cc] = -1.0
np.testing.assert_equal(a, aa.numpy())
np.testing.assert_equal(a, get_value(aa))
d = np.array([-1, -2], dtype=np.float32)
dd = Tensor(d)
dd = make_tensor(d, network)
a[b, c] = d
aa[bb, cc] = dd
np.testing.assert_equal(a, aa.numpy())
np.testing.assert_equal(a, get_value(aa))
a = np.ones((2, 2))
b = np.array([[True, False], [False, True]])
aa = Tensor(a)
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy())
aa = make_tensor(a, network)
bb = make_tensor(b, network)
np.testing.assert_equal(a[b], get_value(aa[bb]))
b[:] = True
bb[:] = True
np.testing.assert_equal(a[b], aa[bb].numpy())
np.testing.assert_equal(a[:, [True, False]], aa[:, [True, False]].numpy())
np.testing.assert_equal(a[b], get_value(aa[bb]))
np.testing.assert_equal(a[:, [True, False]], get_value(aa[:, [True, False]]))
a = np.array([[True, False], [False, True]])
b = np.array([1])
aa = Tensor(a)
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy())
aa = make_tensor(a, network)
bb = make_tensor(b, network)
np.testing.assert_equal(a[b], get_value(aa[bb]))
b = np.array([[True, True], [False, True]])
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy())
bb = make_tensor(b, network)
np.testing.assert_equal(a[b], get_value(aa[bb]))
if not test_varnode:
a[b] = False
aa[bb] = False
np.testing.assert_equal(a, aa.numpy())
np.testing.assert_equal(a, get_value(aa))
a = np.ones((2, 2), dtype=np.int32)
b = np.array([[False, False], [False, False]])
aa = Tensor(a)
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[b].numpy())
np.testing.assert_equal(a[b], aa[bb].numpy())
aa = make_tensor(a, network)
bb = make_tensor(b, network)
np.testing.assert_equal(a[b], get_value(aa[b]))
np.testing.assert_equal(a[b], get_value(aa[bb]))
b = np.array([False, False])
bb = Tensor(b)
np.testing.assert_equal(a[b], aa[bb].numpy().reshape(a[b].shape))
bb = make_tensor(b, network)
np.testing.assert_equal(a[b], get_value(aa[bb]).reshape(a[b].shape))
a = np.arange(576).reshape(2, 3, 4, 3, 4, 2).astype("int32")
aa = Tensor(a)
aa = make_tensor(a, network)
b = (np.random.sample((2, 3, 4)) > 0.5).astype("bool")
bb = Tensor(b)
np.testing.assert_equal(a[b, :, 0:4:2], aa[bb, :, 0:4:2].numpy())
bb = make_tensor(b, network)
np.testing.assert_equal(a[b, :, 0:4:2], get_value(aa[bb, :, 0:4:2]))
b = (np.random.sample((4, 3, 4)) > 0.5).astype("bool")
bb = Tensor(b)
np.testing.assert_equal(a[..., b, 0:2], aa[..., bb, 0:2].numpy())
bb = make_tensor(b, network)
np.testing.assert_equal(a[..., b, 0:2], get_value(aa[..., bb, 0:2]))
b = (np.random.sample((3, 4, 3)) > 0.5).astype("bool")
bb = Tensor(b)
bb = make_tensor(b, network)
np.testing.assert_equal(
a[:, b, 0:2, [True, False]], aa[:, bb, 0:2, [True, False]].numpy()
a[:, b, 0:2, [True, False]], get_value(aa[:, bb, 0:2, [True, False]])
)
......