Commit bc9aa47a authored by Megvii Engine Team

feat(mge/indexing): support newaxis

GitOrigin-RevId: 8338c4b47542671f07cee9d68bbe8c35c25c5f16
Parent 9779bc7f
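This change teaches MegEngine's C++ indexing path to accept `None` (newaxis): `None` entries in an index tuple are realized up front with an `AddAxis` op and then skipped while the remaining indices are unpacked. Along the way, `ptr() == Py_None` checks become `is_none()`, `static` qualifiers on per-call `OpDef`s are dropped, and heap-allocated `std::vector<PyObject*>` argument packs for `py_apply` are replaced by stack arrays. A minimal numpy sketch of the semantics the new tests assert:

import numpy as np

# numpy-compatible rule: None inserts a length-1 axis into the result
# without consuming an input dimension.
x = np.arange(27).reshape(3, 3, 3)
assert x[None].shape == (1, 3, 3, 3)
assert x[:, None, 1, None].shape == (3, 1, 1, 3)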
@@ -459,12 +459,8 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
if (!dtype_equal(cur, descr)) {
std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
py::object Op = py::cast(op);
-std::vector<PyObject*> p;
-p.resize(2);
-p[0] = Op.ptr();
-p[1] = tensor.ptr();
-py::tuple ret =
-        py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+PyObject* p[2] = {Op.ptr(), tensor.ptr()};
+py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return ret[0];
} else {
return py::reinterpret_borrow<py::object>(tensor);
@@ -514,7 +510,7 @@ py::object _convert_inputs_cpp(
}
}
auto convert = [&](py::object value) {
-if (value.ptr() == Py_None) {
+if (value.is_none()) {
return value;
}
return _convert_single_value_cpp(value, dtype, device);
@@ -545,12 +541,9 @@ py::object _astensor1d_cpp(
if (device.ptr() != Py_None) {
std::shared_ptr<OpDef> op = Copy::make(device_obj.cast<CompNode>());
py::object Op = py::cast(op);
-std::vector<PyObject*> p;
-p.resize(2);
-p[0] = Op.ptr();
-p[1] = ret.ptr();
-py::tuple copy_ret = py::reinterpret_steal<py::object>(
-        py_apply(NULL, p.data(), p.size()));
+PyObject* p[2] = {Op.ptr(), ret.ptr()};
+py::tuple copy_ret =
+        py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return copy_ret[0];
}
return ret;
@@ -590,7 +583,7 @@ py::object _astensor1d_cpp(
c_args[lis.size()] = Py_None;
py::tuple inp_tup = py::reinterpret_steal<py::tuple>(
convert_inputs_cpp(NULL, c_args.data(), c_args.size()));
-if (device_obj.ptr() == Py_None) {
+if (device_obj.is_none()) {
std::vector<PyObject*> inp(inp_tup.size());
for (size_t i = 0; i < inp_tup.size(); ++i) {
inp[i] = inp_tup[i].ptr();
@@ -637,15 +630,10 @@ py::object _get_index(py::object tensor, py::object src) {
return tensor;
}
}
-static std::shared_ptr<OpDef> op = CondTake::make();
-std::vector<PyObject*> p;
-p.resize(3);
+std::shared_ptr<OpDef> op = CondTake::make();
py::object Op = py::cast(op);
-p[0] = Op.ptr();
-p[1] = tensor.ptr();
-p[2] = tensor.ptr();
-py::tuple ret =
-        py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+PyObject* p[3] = {Op.ptr(), tensor.ptr(), tensor.ptr()};
+py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
return ret[1];
}
@@ -666,15 +654,10 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) {
} else {
iobj = py::reinterpret_borrow<py::object>(index);
}
-static std::shared_ptr<OpDef> op = CondTake::make();
-std::vector<PyObject*> p;
-p.resize(3);
+std::shared_ptr<OpDef> op = CondTake::make();
py::object Op = py::cast(op);
-p[0] = Op.ptr();
-p[1] = tensor.ptr();
-p[2] = iobj.ptr();
-py::tuple ret =
-        py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+PyObject* p[3] = {Op.ptr(), tensor.ptr(), iobj.ptr()};
+py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
return ret;
}
@@ -685,7 +668,9 @@ py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
bool has_unknown_ndim_bool_index = false;
for (size_t i = 0; i < tuple_size; ++i) {
py::object handle = tuple_val[i];
-if (handle.ptr() == Py_Ellipsis) {
+if (handle.is_none()) {
+continue;
+} else if (handle.ptr() == Py_Ellipsis) {
pos = static_cast<int>(i);
for (size_t j = 0; j < i; ++j) {
py::object t = tuple_val[j];
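`_remove_ellipsis` now skips `None` before counting how many dimensions the `...` spans, since a newaxis consumes no input dimension. A small numpy illustration of that counting rule:

import numpy as np

# '...' expands based on the real indices only: 1 and 0 consume two of
# the three dims, so '...' stands for a single ':' despite the None.
x = np.arange(27).reshape(3, 3, 3)
np.testing.assert_equal(x[1, None, ..., 0], x[1, None, :, 0])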
@@ -749,8 +734,14 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
size_t offset = 0;
size_t tdim = 0;
+size_t nonedim = 0;
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::handle k = tuple_val[i];
+if (k.ptr() == Py_None) {
+nonedim++;
+new_tuple_val.append(k);
+continue;
+}
if (is_bool_dtype(k.ptr())) {
size_t ndim = getattr(k, "ndim").cast<size_t>();
if (ndim > 1) {
@@ -777,7 +768,7 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
Py_XDECREF(sym);
if (is_sym) {
py::object tshape = getattr(tensor, "shape");
-for (size_t j = 0; j < i; ++j) {
+for (size_t j = 0; j < i - nonedim; ++j) {
new_shape.append(tshape[py::int_(j)]);
}
new_shape.append(kshape[py::int_(0)]);
@@ -789,7 +780,7 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
tensor = _reshape_cpp(tensor, shape_tensor);
cur_shape = _make_shape_tuple(shape_tensor);
} else {
-for (size_t j = 0; j < i; ++j) {
+for (size_t j = 0; j < i - nonedim; ++j) {
new_shape.append(cur_shape[j]);
}
new_shape.append(py::reinterpret_borrow<py::tuple>(kshape)[0]);
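`_expand_bool_dim` rebuilds the shape around the tuple position `i` of a multi-dimensional bool mask; a `None` earlier in the tuple matches no tensor axis, so the loop bound becomes `i - nonedim`. A numpy sketch of the behaviour this preserves:

import numpy as np

# A (2, 3) bool mask flattens the first two axes of `a`; a leading None
# only adds an output axis and must not shift the flattening point.
a = np.arange(24).reshape(2, 3, 4)
b = np.random.rand(2, 3) > 0.5
np.testing.assert_equal(a[None, b], a.reshape(6, 4)[b.reshape(6)][None])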
@@ -838,8 +829,8 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
size_t idx_ndim = 0;
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::object k = tuple_val[i];
-if (k.ptr() == Py_None) {
-throw py::index_error("newaxis is not allowed here");
+if (k.is_none()) {
+continue;
} else if (k.ptr() == Py_Ellipsis) {
need_remove_ellipsis = true;
} else {
@@ -878,6 +869,20 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
}
}
+std::vector<int32_t> axis;
+for (size_t i = 0; i < tuple_val.size(); ++i) {
+if (tuple_val[i].is_none()) {
+axis.push_back(i);
+}
+}
+if (axis.size()) {
+std::shared_ptr<OpDef> op = AddAxis::make(axis);
+py::object Op = py::cast(op);
+PyObject* p[2] = {Op.ptr(), inp.ptr()};
+py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
+inp = ret[0];
+}
py::list items;
py::list tensors;
int cur_axis = -1;
@@ -885,6 +890,9 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::object handle = tuple_val[i];
cur_axis++;
+if (handle.is_none()) {
+continue;
+}
if (!is_scalar(handle.ptr()) && !PySlice_Check(handle.ptr())) {
use_subtensor = false;
}
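This is the heart of the feature: `_unpack_indexes` collects the positions of all `None` entries, applies a single `AddAxis` to the input, and afterwards walks the tuple treating `None` as a no-op. A hypothetical pure-Python rendering of that strategy (names are illustrative, not MegEngine API):

import numpy as np

def index_with_newaxis(x, index):
    # 1) Materialize each None as an inserted axis (the AddAxis step).
    #    Tuple positions already account for Nones inserted before them,
    #    so applying them in ascending order is correct.
    for pos in (i for i, item in enumerate(index) if item is None):
        x = np.expand_dims(x, pos)
    # 2) Index the expanded array with every None turned into a full slice.
    return x[tuple(slice(None) if item is None else item for item in index)]

x = np.arange(27).reshape(3, 3, 3)
np.testing.assert_equal(index_with_newaxis(x, (1, None, slice(None), 1)),
                        x[1, None, :, 1])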
@@ -970,11 +978,11 @@ py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
for (size_t i = 0; i < lis.size(); ++i) {
-if (lis[i].ptr() == Py_None) {
+if (lis[i].is_none()) {
auto_infer = true;
size_t right = lis.size() - i;
py::object tshp = getattr(inp_hdl, "_tuple_shape");
-if (tshp.ptr() == Py_None) {
+if (tshp.is_none()) {
throw py::index_error("does not support `None` with unknown shape");
}
py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
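For context, `_broadcast_cpp` treats `None` in a target shape as "infer this dimension from the input", with `right = lis.size() - i` suggesting alignment from the right. A hypothetical pure-Python sketch of that inference, assuming right-alignment against the input shape:

def infer_broadcast_shape(inp_shape, target):
    # Replace each None with the input dim sitting `right` positions from
    # the end, mirroring `right = lis.size() - i` in the C++ above.
    out = list(target)
    for i, dim in enumerate(out):
        if dim is None:
            right = len(out) - i
            out[i] = inp_shape[len(inp_shape) - right]
    return tuple(out)

assert infer_broadcast_shape((2, 3), (5, None, 3)) == (5, 2, 3)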
@@ -1116,7 +1124,7 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
{item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
item[3].cast<bool>(), item[4].cast<bool>()});
}
-static std::shared_ptr<OpDef> op;
+std::shared_ptr<OpDef> op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
} else {
@@ -1155,7 +1163,7 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
{item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
item[3].cast<bool>(), item[4].cast<bool>()});
}
-static std::shared_ptr<OpDef> op, set_op;
+std::shared_ptr<OpDef> op, set_op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
} else {
@@ -1340,13 +1348,9 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
}
std::sort(axis.begin(), axis.end());
std::shared_ptr<OpDef> op = AddAxis::make(axis = axis);
-std::vector<PyObject*> p;
-p.resize(2);
py::object Op = py::cast(op);
-p[0] = Op.ptr();
-p[1] = inp_hdl.ptr();
-py::tuple ret =
-        py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+PyObject* p[2] = {Op.ptr(), inp_hdl.ptr()};
+py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return ret[0];
}
@@ -1390,13 +1394,9 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
axis[i] -= static_cast<int32_t>(i);
}
std::shared_ptr<OpDef> op = RemoveAxis::make(axis = axis);
-std::vector<PyObject*> p;
-p.resize(2);
py::object Op = py::cast(op);
-p[0] = Op.ptr();
-p[1] = inp_hdl.ptr();
-py::tuple ret =
-        py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+PyObject* p[2] = {Op.ptr(), inp_hdl.ptr()};
+py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return ret[0];
}
py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
@@ -1437,13 +1437,9 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
}
}
std::shared_ptr<OpDef> op = Dimshuffle::make(pattern);
-std::vector<PyObject*> p;
-p.resize(2);
py::object Op = py::cast(op);
-p[0] = Op.ptr();
-p[1] = inp_hdl.ptr();
-py::tuple ret =
-        py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+PyObject* p[2] = {Op.ptr(), inp_hdl.ptr()};
+py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return ret[0];
}
@@ -436,6 +436,8 @@ def test_advance_indexing_high_level(test_varnode):
x = np.arange(27).reshape(3, 3, 3).astype("int32")
xx = make_tensor(x, network)
y = np.array([0, 2], dtype=np.int32)
+z = np.array([[0, 1], [1, 2]], dtype=np.int32)
np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :]))
np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1]))
@@ -444,6 +446,21 @@ def test_advance_indexing_high_level(test_varnode):
np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1]))
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1]))
np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2]))
+np.testing.assert_equal(x[:2, y, [0, 1]], get_value(xx[:2, y, [0, 1]]))
+np.testing.assert_equal(x[None, None], get_value(xx[None, None]))
+np.testing.assert_equal(x[:, None, ...], get_value(xx[:, None, ...]))
+np.testing.assert_equal(x[1, None, :, 1], get_value(xx[1, None, :, 1]))
+np.testing.assert_equal(x[:, None, 1, None], get_value(xx[:, None, 1, None]))
+np.testing.assert_equal(x[:2, y, None, [0, 1]], get_value(xx[:2, y, None, [0, 1]]))
+np.testing.assert_equal(
+x[None, :, None, [0, 2], None, [1, 2]],
+get_value(xx[None, :, None, [0, 2], None, [1, 2]]),
+)
+np.testing.assert_equal(x[z], get_value(xx[z]))
+np.testing.assert_equal(x[z, None], get_value(xx[z, None]))
+np.testing.assert_equal(x[None, z], get_value(xx[None, z]))
+np.testing.assert_equal(x[z, None, z], get_value(xx[z, None, z]))
+np.testing.assert_equal(x[None, z, None], get_value(xx[None, z, None]))
x_ = x.copy()
x_[1, 1, 1] = -1
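Several of the added assertions mix integer-array (advanced) indices with `None`; they rely on the numpy rule that advanced indices separated by a slice, `...`, or newaxis contribute their broadcast dimensions at the front of the result. A shape check of that rule:

import numpy as np

x = np.arange(27).reshape(3, 3, 3)
z = np.array([[0, 1], [1, 2]])
# z appears on both sides of a None, so the (2, 2) broadcast dims lead,
# followed by the inserted axis and the untouched last axis.
assert x[z, None, z].shape == (2, 2, 1, 3)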
@@ -592,16 +609,24 @@ def test_advance_indexing_with_bool(test_varnode):
b = (np.random.sample((2, 3, 4)) > 0.5).astype("bool")
bb = make_tensor(b, network)
np.testing.assert_equal(a[b, :, 0:4:2], get_value(aa[bb, :, 0:4:2]))
+np.testing.assert_equal(a[None, b, :, 0:4:2], get_value(aa[None, bb, :, 0:4:2]))
b = (np.random.sample((4, 3, 4)) > 0.5).astype("bool")
bb = make_tensor(b, network)
np.testing.assert_equal(a[..., b, 0:2], get_value(aa[..., bb, 0:2]))
+np.testing.assert_equal(
+a[None, ..., b, None, 0:2], get_value(aa[None, ..., bb, None, 0:2])
+)
b = (np.random.sample((3, 4, 3)) > 0.5).astype("bool")
bb = make_tensor(b, network)
np.testing.assert_equal(
a[:, b, 0:2, [True, False]], get_value(aa[:, bb, 0:2, [True, False]])
)
+np.testing.assert_equal(
+a[:, b, None, 0:2, [True, False]],
+get_value(aa[:, bb, None, 0:2, [True, False]]),
+)
@pytest.mark.parametrize("symbolic", [True, False, None])
@@ -781,9 +806,6 @@ def test_indexing_error(test_varnode):
aa = make_tensor(a, network)
bb = make_tensor(b, network)
-with pytest.raises(IndexError):
-aa[None] # newaxis is not allowed
with pytest.raises(IndexError):
aa[..., ...] # only one ellipsis is allowed
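With newaxis supported, the `aa[None]` error test is removed; only the single-ellipsis restriction remains. The corresponding numpy behaviour, for reference:

import numpy as np

x = np.arange(6).reshape(2, 3)
assert x[None].shape == (1, 2, 3)  # newaxis is no longer an error
try:
    x[..., ...]  # still rejected: an index may contain only one ellipsis
except IndexError:
    pass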