diff --git a/imperative/python/megengine/core/tensor/indexing.py b/imperative/python/megengine/core/tensor/indexing.py index 4dc1da3bdb3e6eeced6424b11ad2919b5195922b..9906cf84592106ba39953189ee7cd1c151e90633 100644 --- a/imperative/python/megengine/core/tensor/indexing.py +++ b/imperative/python/megengine/core/tensor/indexing.py @@ -6,287 +6,23 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from typing import Iterable - -import numpy as np - -from .._imperative_rt.core2 import SymbolVar, Tensor, apply +from .._imperative_rt.core2 import ( + getitem_cpp, + set_cpp_astensor1d, + set_cpp_use_symbolic_shape, + setitem_cpp, +) from .._trace_option import use_symbolic_shape -from ..ops import builtin -from ..ops.special import Const -from .utils import astensor1d, isscalar, make_shape_tuple - - -def remove_ellipsis(tensor, tuple_val): - cur_sum = 0 - pos = -1 - has_unkown_ndim_bool_index = False - for i_idx, i in enumerate(tuple_val): - if i is Ellipsis: - for j in tuple_val[:i_idx:-1]: - if j is Ellipsis: - raise IndexError("only one ellipsis is allowed") - pos = i_idx - else: - try: - cur_sum += ( - i.ndim - if hasattr(i, "dtype") - and i.dtype == np.bool_ - and hasattr(i, "ndim") - else 1 - ) - except ValueError: - has_unkown_ndim_bool_index = True - - if pos == -1: - return tuple_val - else: - if has_unkown_ndim_bool_index: - raise IndexError( - "Does not support bool index with unknown shape when using Ellipsis" - ) - try: - ndim_sum = tensor.ndim - except ValueError: - raise IndexError("Does not support Ellipsis when tensor's ndim is unknown.") - return ( - tuple_val[:pos] - + (slice(None, None, None),) * (ndim_sum - cur_sum) - + tuple_val[pos + 1 :] - ) - - -# XXX: assume same results during trace -def check_bool_index(tensor, tuple_val): - try: - cur_shape = make_shape_tuple(tensor.shape) - except ValueError: - return tensor, tuple_val - - new_tuple_val = [] - offset = 0 - tdim = 0 - for idx, i in enumerate(tuple_val): - if hasattr(i, "dtype") and i.dtype == np.bool_: - if i.ndim > 1: - tot = i.ndim - ishape = make_shape_tuple(i.shape) - for j in range(i.ndim): - if cur_shape[tdim + j - offset] != ishape[j]: - raise IndexError( - "boolean index did not match tensor along dimension {}; dimension is {} but corresponding boolean dimension is {}".format( - tdim + j, cur_shape[tdim + j - offset], ishape[j] - ) - ) - i = i.reshape(-1) - if not use_symbolic_shape(): - cur_shape = ( - cur_shape[:idx] - + (i.shape[0],) - + cur_shape[tdim + tot - offset :] - ) - else: - # XXX: use only for trace - new_shape = [] - for ii in range(idx): - new_shape.append(tensor.shape[ii]) - new_shape.append(i.shape[0]) - for ii in range(tdim + tot - offset, len(cur_shape)): - new_shape.append(cur_shape[ii]) - cur_shape = astensor1d(new_shape) - offset += 1 - tensor = tensor.reshape(cur_shape) - tdim += tot - if use_symbolic_shape(): - cur_shape = make_shape_tuple(cur_shape) - new_tuple_val.append(i) - else: - new_tuple_val.append(i) - tdim += 1 - return tensor, new_tuple_val - - -def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): - if not isinstance(tuple_val, tuple): - tuple_val = (tuple_val,) - ndim_indexed = 0 - for i in tuple_val: - if not i is Ellipsis: - ndim_indexed += ( - i.ndim - if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim") - else 1 - ) - else: - try: - if ndim_indexed > inp.ndim: - raise IndexError( - "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( - inp.ndim, len(tuple_val) - ) - ) - except ValueError: - # ignore - pass - - tuple_val = remove_ellipsis(inp, tuple_val) - use_subtensor = True - if inp.shape is not None: - inp, tuple_val = check_bool_index(inp, tuple_val) - - new_axes = [] - tensors = [] - items = [] - cur_axis = -1 - for i_idx, i in enumerate(tuple_val): - cur_axis += 1 - if i is np.newaxis: - if cur_axis >= 0: - new_axes.append(cur_axis) - continue - - if i is Ellipsis: - cur_axis = -1 - for j in tuple_val[:i_idx:-1]: - if j is Ellipsis: - raise IndexError("only one ellipsis is allowed") - if j is np.newaxis: - new_axes.append(cur_axis) - cur_axis -= 1 - continue - - if ( - not isscalar(i) - and not i is np.newaxis - and not i is Ellipsis - and not isinstance(i, slice) - ): - use_subtensor = False - - item = [ - cur_axis, - ] - - def is_bool_list(x): - if not isinstance(x, list): - return False - if len(x) == 0: - return False - for i in x: - if not isinstance(i, bool): - return False - return True - - def get_index(i): - if not isinstance(i, (Tensor, SymbolVar)): - if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: - (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) - else: - (i,) = Const(i, dtype=np.int32, device=inp.device)(inp) - return i - assert isinstance(i, (Tensor, SymbolVar)) - if i.dtype != np.bool_: - return i - _, ind = apply(builtin.CondTake(), i, i) - return ind - - def push(v, item, tensors): - if v is None: - item.append(False) - else: - item.append(True) - v = get_index(v) - assert np.issubdtype(v.dtype, np.integer) or np.issubdtype( - v.dtype, np.bool_ - ), "var type in the subscript must be int or bool" - tensors.append(v) - - if isinstance(i, slice): - if i.start is None and i.stop is None and i.step is None: - continue - push(i.start, item, tensors) - push(i.stop, item, tensors) - push(i.step, item, tensors) - item.append(False) # idx - else: - item += [False,] * 3 # begin, end, stop - push(i, item, tensors) - assert len(item) == 5 - items.append(item) - if new_axes: - raise IndexError("newaxis is not allowed here") - return inp, tensors, items, use_subtensor - - -def try_condtake(tensor, index): - if not hasattr(index, "dtype") or not hasattr(index, "shape"): - return [] - if index.dtype != np.bool_ or make_shape_tuple(index.shape) != make_shape_tuple( - tensor.shape - ): - return [] - if isinstance(index, np.ndarray): - (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) - assert isinstance(index, (Tensor, SymbolVar)) - if not isinstance(tensor, (Tensor, SymbolVar)): - raise TypeError("input must be a tensor") - if tensor.device != index.device: - raise ValueError( - "ambiguous device: {} vs {}".format(tensor.device, index.device) - ) - return apply(builtin.CondTake(), tensor, index) +from .utils import astensor1d def getitem(tensor, index): - try_result = try_condtake(tensor, index) - if len(try_result) == 2: - return try_result[0] - tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) - if use_subtensor: - op = builtin.Subtensor(items=items) - else: - op = builtin.IndexingMultiAxisVec(items=items) - (result,) = apply(op, tensor, *tensors) - return result + return getitem_cpp(tensor, index) def setitem(tensor, index, value): - org_shape = tensor.shape - try_result = try_condtake(tensor, index) - if len(try_result) == 2: - index = try_result[1] - tensor = tensor.reshape(-1) - if not isinstance(value, (Tensor, SymbolVar)): - (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) - tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index) - if use_subtensor: - op = builtin.Subtensor(items=items) - else: - op = builtin.IndexingMultiAxisVec(items=items) + return setitem_cpp(tensor, index, value) - (tmp_result,) = apply(op, tensor, *tensors) - try: - value_shape = value._tuple_shape - tmp_result_shape = tmp_result._tuple_shape - except ValueError: - pass - else: - for i in range(min(len(value_shape), len(tmp_result_shape))): - if (value_shape[-i - 1] != 1) & ( - value_shape[-i - 1] != tmp_result_shape[-i - 1] - ): - raise ValueError( - "cannot copy tensor with shape {} to subtensor with shape {}".format( - value_shape, tmp_result_shape - ) - ) - value = value._broadcast(tmp_result.shape) - if use_subtensor: - op = builtin.SetSubtensor(items=items) - else: - op = builtin.IndexingSetMultiAxisVec(items=items) - (result,) = apply(op, tensor, value, *tensors) - result = result.reshape(org_shape) - return result +set_cpp_use_symbolic_shape(use_symbolic_shape) +set_cpp_astensor1d(astensor1d) diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index c9431dc91b542a8abeb2e2d3f44725a87469ef6c..45d934767390a57ed5d7e366a015bcf1409de7d2 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -12,7 +12,14 @@ from typing import Iterable, Union import numpy as np from .._imperative_rt import make_const -from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device +from .._imperative_rt.core2 import ( + SymbolVar, + Tensor, + apply, + dtype_promotion, + get_device, + make_shape_tuple, +) from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from .._wrap import as_device from ..ops import builtin @@ -163,30 +170,6 @@ def astensor1d(x, *reference, dtype=None, device=None): return x -def _expand_int(s, i): - if isinstance(i, (Tensor, SymbolVar)): - i_np = i.numpy() - if i_np.ndim == 0: - s.append(int(i_np)) - else: - s += list(i_np) - return - if isinstance(i, Iterable): - for ii in i: - _expand_int(s, ii) - return - if np.issubdtype(type(i), np.integer): - s.append(i) - return - raise - - -def make_shape_tuple(shape): - s = [] - _expand_int(s, shape) - return tuple(s) - - def _normalize_axis( ndim: int, axis: Union[int, Iterable], reverse=False ) -> Union[int, list]: diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index b99a77a99cc26d0abcfc6f647852ce5469dcd5d4..1f74c27ffc3aee8c6c52eb98c21ec12015f8f77e 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -56,6 +56,15 @@ WeakKeyMap module_trace_info_map; interpreter::Interpreter::Channel* interpreter_for_py = nullptr; PyTypeObject* py_tensor_type = nullptr; +PyObject *cpp_use_symbolic_shape, *cpp_astensor1d; + +#define REGISTE_APPLY_FUNC(mode) \ + void set_##mode(py::object pyf) { mode = pyf.ptr(); } + +REGISTE_APPLY_FUNC(cpp_use_symbolic_shape) +REGISTE_APPLY_FUNC(cpp_astensor1d) + +#undef REGISTE_APPLY_FUNC PyObject* py_apply( PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) { @@ -520,6 +529,557 @@ CompNode _get_device(PyObject* const* args, size_t nargs) { return cn; } +bool is_scalar(PyObject* tensor) { + if (py::isinstance(py::handle(tensor))) { + auto var = py::handle(tensor).cast(); + return var->is_scalar; + } + auto* tw = TensorWrapper::try_cast(tensor); + if (tw) { + return tw->m_tensor->is_scalar(); + } + return PyArray_CheckAnyScalar(tensor); +} + +bool is_bool_list(PyObject* arg) { + if (!PyList_Check(arg)) { + return false; + } + size_t sz = PyList_Size(arg); + if (!sz) { + return false; + } + for (size_t i = 0; i < sz; ++i) { + PyObject* handle = PyList_GetItem(arg, i); + if (!PyBool_Check(handle)) { + return false; + } + } + return true; +} + +bool is_bool_dtype(PyObject* args) { + if (!PyObject_HasAttrString(args, "dtype")) + return false; + PyObject* dobj = PyObject_GetAttrString(args, "dtype"); + PyArray_Descr* dtype; + PyArray_DescrConverter(dobj, &dtype); + bool ret = (dtype->kind == 'b'); + Py_XDECREF(dtype); + Py_XDECREF(dobj); + return ret; +} + +py::object _Const( + py::handle value, py::handle dtype, py::handle device, py::handle ref) { + py::object val = py::reinterpret_borrow(value); + if (PyArray_Check(value.ptr())) { + py::tuple strides = + py::reinterpret_borrow(getattr(value, "strides")); + bool need_squeeze = false; + for (size_t i = 0; i < strides.size(); ++i) { + if (strides[i].cast() == 0) { + need_squeeze = true; + } + } + if (need_squeeze) { + val = py::reinterpret_borrow(value); + val = val.attr("squeeze")(); + val = val.attr("reshape")(val.attr("shape")); + } + } + if (py::isinstance(ref)) { + auto ref_var = ref.cast(); + auto* graph = ref_var->m_node->owner_graph(); + auto cn = device.cast(); + OperatorNodeConfig config(cn); + auto hv = npy::np2tensor( + val.ptr(), npy::Meth::borrow(cn), dtype.cast()); + auto typeobj = ref.get_type(); + return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); + } + py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none()); + return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr); +} + +py::tuple _make_shape_tuple(py::handle shape) { + py::list orig; + py::list ret(0); + auto solve_one = [&](py::handle val) { + if (TensorWrapper::try_cast(val.ptr()) || py::isinstance(val)) { + py::object np = getattr(val, "numpy")(); + PyArrayObject* arr = (PyArrayObject*)np.ptr(); + PyObject* maybe_list = PyArray_ToList(arr); + if (PyList_Check(maybe_list)) { + py::list may = py::reinterpret_steal(maybe_list); + for (size_t i = 0; i < may.size(); ++i) { + ret.append(may[i]); + } + } else { + mgb_assert(PyLong_Check(maybe_list)); + ret.append(PyLong_AsLong(maybe_list)); + Py_XDECREF(maybe_list); + } + } else if (PyArray_Check(val.ptr())) { + ret.append(PyArray_PyIntAsInt(val.ptr())); + } else { + ret.append(PyLong_AsLong(val.ptr())); + } + }; + if (PyArray_Check(shape.ptr()) && !PyArray_CheckAnyScalar(shape.ptr())) { + orig = py::reinterpret_steal( + PyArray_ToList((PyArrayObject*)shape.ptr())); + for (size_t i = 0; i < orig.size(); ++i) { + solve_one(orig[i]); + } + } else if (PyList_Check(shape.ptr())) { + orig = py::reinterpret_borrow(shape); + for (size_t i = 0; i < orig.size(); ++i) { + solve_one(orig[i]); + } + } else if (PyTuple_Check(shape.ptr())) { + py::tuple tup = py::reinterpret_borrow(shape); + for (size_t i = 0; i < tup.size(); ++i) { + solve_one(tup[i]); + } + } else { + solve_one(shape); + } + return py::reinterpret_steal(PyList_AsTuple(ret.ptr())); +} + +py::object _get_index(py::object tensor, py::object src) { + if (!TensorWrapper::try_cast(tensor.ptr()) && + !py::isinstance(tensor)) { + auto get_const = [&](mgb::DType dtype) -> py::object { + return _Const(tensor, py::cast(dtype), src.attr("device"), src); + }; + if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) { + tensor = get_const(dtype::Bool()); + } else { + tensor = get_const(dtype::Int32()); + } + if (!is_bool_dtype(tensor.ptr())) { + return tensor; + } + } else { + if (!is_bool_dtype(tensor.ptr())) { + return tensor; + } + } + static std::shared_ptr op = CondTake::make(); + std::vector p; + p.resize(3); + py::object Op = py::cast(op); + p[0] = Op.ptr(); + p[1] = tensor.ptr(); + p[2] = tensor.ptr(); + py::tuple ret = + py::reinterpret_steal(py_apply(NULL, p.data(), p.size())); + return ret[1]; +} + +py::tuple _try_cond_take(py::handle tensor, py::handle index) { + if (!hasattr(index, "dtype") || !hasattr(index, "shape")) { + return py::tuple(); + } + if (!is_bool_dtype(index.ptr()) || + _make_shape_tuple(getattr(index, "shape")) + .not_equal(_make_shape_tuple(getattr(tensor, "shape")))) { + return py::tuple(); + } + py::object iobj; + if (PyArray_Check(index.ptr())) { + iobj = + _Const(index, py::cast((mgb::DType)dtype::Bool()), + getattr(tensor, "device"), tensor); + } else { + iobj = py::reinterpret_borrow(index); + } + static std::shared_ptr op = CondTake::make(); + std::vector p; + p.resize(3); + py::object Op = py::cast(op); + p[0] = Op.ptr(); + p[1] = tensor.ptr(); + p[2] = iobj.ptr(); + py::tuple ret = + py::reinterpret_steal(py_apply(NULL, p.data(), p.size())); + return ret; +} + +py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) { + size_t tuple_size = tuple_val.size(); + size_t ndim_sum = 0, cur_sum = 0; + int pos = -1; + bool has_unknown_ndim_bool_index = false; + for (size_t i = 0; i < tuple_size; ++i) { + py::object handle = tuple_val[i]; + if (handle.ptr() == Py_Ellipsis) { + pos = static_cast(i); + for (size_t j = 0; j < i; ++j) { + py::object t = tuple_val[j]; + if (t.ptr() == Py_Ellipsis) { + throw py::index_error("only one ellipsis is allowed."); + } + } + } else { + size_t ndim_incr = 1; + if (hasattr(handle, "dtype") && is_bool_dtype(handle.ptr()) && + hasattr(handle, "ndim")) { + py::object ndim = getattr(handle, "ndim"); + if (PyLong_Check(ndim.ptr())) { + ndim_incr = PyLong_AsLong(ndim.ptr()); + } else { + has_unknown_ndim_bool_index = true; + } + } + cur_sum += ndim_incr; + } + } + if (pos == -1) { + return tuple_val; + } else { + if (has_unknown_ndim_bool_index) { + throw py::index_error( + "does not support bool index with unknown shape when using " + "Ellipsis."); + } + try { + ndim_sum = getattr(tensor, "ndim").cast(); + } catch (py::error_already_set& err) { + throw py::index_error( + "does not support Ellipsis when tensor's ndim is unknown."); + } + py::tuple ret(ndim_sum - cur_sum + tuple_size - 1); + size_t idx = 0; + for (size_t i = 0; i < tuple_size; ++i) { + if (i == pos) { + for (size_t j = cur_sum; j < ndim_sum; ++j) { + ret[idx++] = PySlice_New(NULL, NULL, NULL); + } + } else { + ret[idx++] = tuple_val[i]; + } + } + return ret; + } +} + +py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) { + py::tuple cur_shape = _make_shape_tuple(py::handle(getattr(tensor, "shape"))); + py::list new_tuple_val(0); + + size_t offset = 0; + size_t tdim = 0; + for (size_t i = 0; i < tuple_val.size(); ++i) { + py::handle k = tuple_val[i]; + if (is_bool_dtype(k.ptr())) { + size_t ndim = getattr(k, "ndim").cast(); + if (ndim > 1) { + py::tuple ishape = _make_shape_tuple(py::handle(getattr(k, "shape"))); + for (size_t j = 0; j < ndim; ++j) { + if (cur_shape[tdim + j - offset].cast() != + ishape[j].cast()) { + std::string msg = + "boolean index did not match tensor along dimension " + + std::to_string(tdim + j) + "; dimension is " + + std::to_string( + cur_shape[tdim + j - offset].cast()) + + " but corresponding boolean dimension is " + + std::to_string(ishape[j].cast()); + throw py::index_error(msg.c_str()); + } + } + py::object new_k = getattr(k, "reshape")(-1); + py::object kshape = getattr(new_k, "shape"); + py::list new_shape(0); + PyObject* sym = PyObject_CallObject(cpp_use_symbolic_shape, nullptr); + bool is_sym = (sym == Py_True); + Py_XDECREF(sym); + if (is_sym) { + py::object tshape = getattr(tensor, "shape"); + for (size_t j = 0; j < i; ++j) { + new_shape.append(tshape[py::int_(j)]); + } + new_shape.append(kshape[py::int_(0)]); + for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) { + new_shape.append(cur_shape[j]); + } + py::tuple args = py::make_tuple(new_shape); + PyObject* shape_tensor = + PyObject_CallObject(cpp_astensor1d, args.ptr()); + py::object reshape_func = getattr(tensor, "reshape"); + Py_INCREF(shape_tensor); + PyObject* Args = PyTuple_New(1); + PyTuple_SetItem(Args, 0, shape_tensor); + PyObject* new_tensor = + PyObject_CallObject(reshape_func.ptr(), Args); + Py_XDECREF(Args); + tensor = py::reinterpret_steal(new_tensor); + cur_shape = _make_shape_tuple(py::handle(shape_tensor)); + Py_XDECREF(shape_tensor); + } else { + for (size_t j = 0; j < i; ++j) { + new_shape.append(cur_shape[j]); + } + new_shape.append(py::reinterpret_borrow(kshape)[0]); + for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) { + new_shape.append(cur_shape[j]); + } + cur_shape = new_shape; + tensor = getattr(tensor, "reshape")(cur_shape); + } + offset++; + tdim += ndim; + } + new_tuple_val.append(k); + } else { + new_tuple_val.append(k); + tdim++; + } + } + return py::make_tuple(tensor, py::reinterpret_borrow(new_tuple_val)); +} + +py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) { + py::object inp = py::reinterpret_borrow(inp_hdl); + py::tuple tuple_val; + if (py::isinstance(idx_hdl)) { + tuple_val = py::reinterpret_borrow(idx_hdl); + } else { + tuple_val = py::make_tuple(idx_hdl); + } + + bool use_subtensor = true; + bool need_remove_ellipsis = false; + bool need_expand_bool_dim = false; + size_t idx_ndim = 0; + for (size_t i = 0; i < tuple_val.size(); ++i) { + py::object k = tuple_val[i]; + if (k.ptr() == Py_None) { + throw py::index_error("newaxis is not allowed here"); + } else if (k.ptr() == Py_Ellipsis) { + need_remove_ellipsis = true; + } else { + if (is_bool_dtype(k.ptr()) && hasattr(k, "ndim")) { + size_t ndim = getattr(k, "ndim").cast(); + idx_ndim += ndim; + if (ndim > 1) { + need_expand_bool_dim = true; + } + } else { + idx_ndim++; + } + } + } + try { + size_t inp_ndim = getattr(inp, "ndim").cast(); + if (idx_ndim > inp_ndim) { + std::string msg = "too many indices for tensor: tensor is " + + std::to_string(inp_ndim) + "-dimensional, but " + + std::to_string(idx_ndim) + " were indexed"; + throw py::index_error(msg.c_str()); + } + } catch (py::error_already_set& err) { + ; // ignore + } + if (need_remove_ellipsis) { + tuple_val = _remove_ellipsis(inp, tuple_val); + } + + if (need_expand_bool_dim) { + py::object shape = getattr(inp, "shape"); + if (shape.ptr() != Py_None) { + py::tuple ret = _expand_bool_dim(inp, tuple_val); + inp = ret[0]; + tuple_val = ret[1]; + } + } + + py::list items; + py::list tensors; + int cur_axis = -1; + + for (size_t i = 0; i < tuple_val.size(); ++i) { + py::object handle = tuple_val[i]; + cur_axis++; + if (!is_scalar(handle.ptr()) && !PySlice_Check(handle.ptr())) { + use_subtensor = false; + } + py::list item; + item.append(cur_axis); + auto push = [&](PyObject* v) { + if (v == Py_None) { + item.append(false); + } else { + item.append(true); + tensors.append(_get_index(py::reinterpret_borrow(v), inp)); + } + }; + + if (PySlice_Check(handle.ptr())) { + PySliceObject* s = (PySliceObject*)handle.ptr(); + if (s->start == Py_None && s->stop == Py_None && s->step == Py_None) { + continue; + } + push(s->start); + push(s->stop); + push(s->step); + item.append(false); + } else { + for (size_t j = 0; j < 3; j++) + item.append(false); + push(handle.ptr()); + } + items.append(item); + } + + return py::make_tuple(inp, tensors, items, use_subtensor, need_expand_bool_dim); +} + +py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) { + py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl); + if (try_res.size() == 2) { + return try_res[0]; + } + py::tuple up = _unpack_indexes(inp_hdl, idx_hdl); + py::object tensor = py::reinterpret_borrow(up[0]); + py::list tensors = py::reinterpret_borrow(up[1]); + py::list py_items = py::reinterpret_borrow(up[2]); + std::vector> cpp_items; + for (size_t i = 0; i < py_items.size(); ++i) { + py::list item = py::reinterpret_borrow(py_items[i]); + cpp_items.push_back( + {item[0].cast(), item[1].cast(), item[2].cast(), + item[3].cast(), item[4].cast()}); + } + static std::shared_ptr op; + if (up[3].cast()) { + op = Subtensor::make(cpp_items); + } else { + op = IndexingMultiAxisVec::make(cpp_items); + } + std::vector p; + p.resize(tensors.size() + 2); + py::object Op = py::cast(op); + p[0] = Op.ptr(); + p[1] = tensor.ptr(); + for (size_t i = 0; i < tensors.size(); ++i) { + p[i + 2] = tensors[i].ptr(); + } + py::tuple ret = + py::reinterpret_steal(py_apply(NULL, p.data(), p.size())); + return ret[0]; +} + +py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) { + py::object org_shape = getattr(inp_hdl, "shape"); + py::object val = py::reinterpret_borrow(val_hdl); + if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance(val)) { + val = + _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"), + inp_hdl); + } + + py::tuple up = _unpack_indexes(inp_hdl, idx_hdl); + py::object tensor = py::reinterpret_borrow(up[0]); + py::list tensors = py::reinterpret_borrow(up[1]); + py::list py_items = py::reinterpret_borrow(up[2]); + std::vector> cpp_items; + for (size_t i = 0; i < py_items.size(); ++i) { + py::list item = py::reinterpret_borrow(py_items[i]); + cpp_items.push_back( + {item[0].cast(), item[1].cast(), item[2].cast(), + item[3].cast(), item[4].cast()}); + } + static std::shared_ptr op, set_op; + if (up[3].cast()) { + op = Subtensor::make(cpp_items); + } else { + op = IndexingMultiAxisVec::make(cpp_items); + } + std::vector p; + p.resize(tensors.size() + 2); + py::object Op = py::cast(op); + p[0] = Op.ptr(); + p[1] = tensor.ptr(); + for (size_t i = 0; i < tensors.size(); ++i) { + p[i + 2] = tensors[i].ptr(); + } + py::tuple ret = + py::reinterpret_steal(py_apply(NULL, p.data(), p.size())); + py::object tmp_result = ret[0]; + + try { + py::object value_tuple_shape = val.attr("_tuple_shape"); + py::object tmp_result_tuple_shape = tmp_result.attr("_tuple_shape"); + py::tuple value_shape = py::reinterpret_borrow(value_tuple_shape); + py::tuple tmp_result_shape = + py::reinterpret_borrow(tmp_result_tuple_shape); + for (size_t i = 0; i < value_shape.size() && i < tmp_result_shape.size(); ++i) { + size_t vs = value_shape[value_shape.size() - i - 1].cast(); + size_t ts = + tmp_result_shape[tmp_result_shape.size() - i - 1].cast(); + if (vs != 1 && vs != ts) { + std::string lhs = "", rhs = ""; + for (size_t j = 0; j < tmp_result_shape.size(); ++j) { + lhs += std::to_string(tmp_result_shape[j].cast()); + if (j) + lhs += ","; + } + for (size_t j = 0; j < value_shape.size(); ++j) { + rhs += std::to_string(value_shape[j].cast()); + if (j) + rhs += ","; + } + throw py::value_error( + "cannot copy tensor with shape (" + rhs + + ") to subtensor with shape (" + lhs + ")"); + } + } + } catch (py::error_already_set& err) { + ; + } + + py::object broadcast_func = getattr(val, "_broadcast"); + PyObject* Args = PyTuple_New(1); + PyTuple_SetItem(Args, 0, getattr(tmp_result, "shape").release().ptr()); + PyObject* new_val = PyObject_CallObject(broadcast_func.ptr(), Args); + Py_XDECREF(Args); + val = py::reinterpret_steal(new_val); + + if (up[3].cast()) { + set_op = SetSubtensor::make(cpp_items); + } else { + set_op = IndexingSetMultiAxisVec::make(cpp_items); + } + + std::vector q; + q.resize(tensors.size() + 3); + py::object Set_Op = py::cast(set_op); + q[0] = Set_Op.ptr(); + q[1] = tensor.ptr(); + q[2] = val.ptr(); + for (size_t i = 0; i < tensors.size(); ++i) { + q[i + 3] = tensors[i].ptr(); + } + py::tuple result = + py::reinterpret_steal(py_apply(NULL, q.data(), q.size())); + py::object res = result[0]; + + if (up[4].cast()) { + py::object reshape_func = getattr(res, "reshape"); + PyObject* Args = PyTuple_New(1); + PyTuple_SetItem(Args, 0, org_shape.release().ptr()); + PyObject* new_tensor = PyObject_CallObject(reshape_func.ptr(), Args); + Py_XDECREF(Args); + res = py::reinterpret_steal(new_tensor); + } + + return res; +} + // Returns the dtype that would result from performing an arithmetic // operation on the provided input tensors and scalars. PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) { @@ -546,6 +1106,30 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) { PYEXT17_TRANSLATE_EXC_RET(nullptr) } +PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _make_shape_tuple(py::handle(args[0])).release().ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + +PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _getitem_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + +PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _setitem_cpp( + py::handle(args[0]), py::handle(args[1]), py::handle(args[2])) + .release() + .ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + #ifdef METH_FASTCALL #define MGE_PY_INTERFACE(NAME, FUNC) \ { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr } @@ -559,6 +1143,9 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) { WRAP_FUNC_PY35(py_apply); WRAP_FUNC_PY35(dtype_promotion); WRAP_FUNC_PY35(get_device); +WRAP_FUNC_PY35(make_shape_tuple); +WRAP_FUNC_PY35(getitem_cpp); +WRAP_FUNC_PY35(setitem_cpp); #undef WRAP_FUNC_PY35 #define MGE_PY_INTERFACE(NAME, FUNC) \ { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } @@ -681,6 +1268,9 @@ void init_tensor(py::module m) { MGE_PY_INTERFACE(apply, py_apply), MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), MGE_PY_INTERFACE(get_device, get_device), + MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple), + MGE_PY_INTERFACE(getitem_cpp, getitem_cpp), + MGE_PY_INTERFACE(setitem_cpp, setitem_cpp), {nullptr, nullptr, 0, nullptr}}; for (auto&& def : method_defs) { if (def.ml_meth != nullptr) { @@ -1037,6 +1627,10 @@ void init_tensor(py::module m) { return module_trace_transformation; }; + m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape); + + m.def("set_cpp_astensor1d", &set_cpp_astensor1d); + m.def("set_module_tracing", [=] { get_module_trace()->enable(); }); m.def("unset_module_tracing", [=] { get_module_trace()->disable(); }); diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index 9adf12f6c2da31149c37d632869da87a34dec716..233ca822ff750c8d427f10a934f02a2a196f6178 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -751,3 +751,40 @@ def test_subtensor_when_shape_invalid(): inp = rand.uniform(size=[1, 3, 512, 512]) net = cgtools.GraphInference(f.name) net.run(inp_dict={"data": inp}) + + +@pytest.mark.parametrize( + "test_varnode", [True, False], +) +def test_indexing_error(test_varnode): + if test_varnode: + network = Network() + else: + network = None + a = np.arange(9).reshape(3, 3).astype(np.float32) + b = np.array([1, 2]) + aa = make_tensor(a, network) + bb = make_tensor(b, network) + + with pytest.raises(IndexError): + aa[None] # newaxis is not allowed + + with pytest.raises(IndexError): + aa[..., ...] # only one ellipsis is allowed + + with pytest.raises(IndexError): + aa[bb, bb, bb] # too many indices + + with pytest.raises(ValueError): + aa[:] = bb # shape mismatch + + if test_varnode: + cc = aa[aa > 4] + with pytest.raises(IndexError): + cc[...] # does not support ellipsis when tensor's ndim is unknown + + dd = aa > 4 + with pytest.raises(IndexError): + cc[ + ..., dd[dd] + ] # does not support bool index with unknown shape when using ellipsis