提交 e860a083 编写于 作者: M Megvii Engine Team

refactor(mge/indexing): move indexing into c++

GitOrigin-RevId: 43fbdb22ddce876adcbddd3036a4fabe5e5a3fc9
上级 e6706be2
......@@ -6,287 +6,23 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable
import numpy as np
from .._imperative_rt.core2 import SymbolVar, Tensor, apply
from .._imperative_rt.core2 import (
getitem_cpp,
set_cpp_astensor1d,
set_cpp_use_symbolic_shape,
setitem_cpp,
)
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.special import Const
from .utils import astensor1d, isscalar, make_shape_tuple
def remove_ellipsis(tensor, tuple_val):
cur_sum = 0
pos = -1
has_unkown_ndim_bool_index = False
for i_idx, i in enumerate(tuple_val):
if i is Ellipsis:
for j in tuple_val[:i_idx:-1]:
if j is Ellipsis:
raise IndexError("only one ellipsis is allowed")
pos = i_idx
else:
try:
cur_sum += (
i.ndim
if hasattr(i, "dtype")
and i.dtype == np.bool_
and hasattr(i, "ndim")
else 1
)
except ValueError:
has_unkown_ndim_bool_index = True
if pos == -1:
return tuple_val
else:
if has_unkown_ndim_bool_index:
raise IndexError(
"Does not support bool index with unknown shape when using Ellipsis"
)
try:
ndim_sum = tensor.ndim
except ValueError:
raise IndexError("Does not support Ellipsis when tensor's ndim is unknown.")
return (
tuple_val[:pos]
+ (slice(None, None, None),) * (ndim_sum - cur_sum)
+ tuple_val[pos + 1 :]
)
# XXX: assume same results during trace
def check_bool_index(tensor, tuple_val):
try:
cur_shape = make_shape_tuple(tensor.shape)
except ValueError:
return tensor, tuple_val
new_tuple_val = []
offset = 0
tdim = 0
for idx, i in enumerate(tuple_val):
if hasattr(i, "dtype") and i.dtype == np.bool_:
if i.ndim > 1:
tot = i.ndim
ishape = make_shape_tuple(i.shape)
for j in range(i.ndim):
if cur_shape[tdim + j - offset] != ishape[j]:
raise IndexError(
"boolean index did not match tensor along dimension {}; dimension is {} but corresponding boolean dimension is {}".format(
tdim + j, cur_shape[tdim + j - offset], ishape[j]
)
)
i = i.reshape(-1)
if not use_symbolic_shape():
cur_shape = (
cur_shape[:idx]
+ (i.shape[0],)
+ cur_shape[tdim + tot - offset :]
)
else:
# XXX: use only for trace
new_shape = []
for ii in range(idx):
new_shape.append(tensor.shape[ii])
new_shape.append(i.shape[0])
for ii in range(tdim + tot - offset, len(cur_shape)):
new_shape.append(cur_shape[ii])
cur_shape = astensor1d(new_shape)
offset += 1
tensor = tensor.reshape(cur_shape)
tdim += tot
if use_symbolic_shape():
cur_shape = make_shape_tuple(cur_shape)
new_tuple_val.append(i)
else:
new_tuple_val.append(i)
tdim += 1
return tensor, new_tuple_val
def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
if not isinstance(tuple_val, tuple):
tuple_val = (tuple_val,)
ndim_indexed = 0
for i in tuple_val:
if not i is Ellipsis:
ndim_indexed += (
i.ndim
if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim")
else 1
)
else:
try:
if ndim_indexed > inp.ndim:
raise IndexError(
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
inp.ndim, len(tuple_val)
)
)
except ValueError:
# ignore
pass
tuple_val = remove_ellipsis(inp, tuple_val)
use_subtensor = True
if inp.shape is not None:
inp, tuple_val = check_bool_index(inp, tuple_val)
new_axes = []
tensors = []
items = []
cur_axis = -1
for i_idx, i in enumerate(tuple_val):
cur_axis += 1
if i is np.newaxis:
if cur_axis >= 0:
new_axes.append(cur_axis)
continue
if i is Ellipsis:
cur_axis = -1
for j in tuple_val[:i_idx:-1]:
if j is Ellipsis:
raise IndexError("only one ellipsis is allowed")
if j is np.newaxis:
new_axes.append(cur_axis)
cur_axis -= 1
continue
if (
not isscalar(i)
and not i is np.newaxis
and not i is Ellipsis
and not isinstance(i, slice)
):
use_subtensor = False
item = [
cur_axis,
]
def is_bool_list(x):
if not isinstance(x, list):
return False
if len(x) == 0:
return False
for i in x:
if not isinstance(i, bool):
return False
return True
def get_index(i):
if not isinstance(i, (Tensor, SymbolVar)):
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
else:
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
return i
assert isinstance(i, (Tensor, SymbolVar))
if i.dtype != np.bool_:
return i
_, ind = apply(builtin.CondTake(), i, i)
return ind
def push(v, item, tensors):
if v is None:
item.append(False)
else:
item.append(True)
v = get_index(v)
assert np.issubdtype(v.dtype, np.integer) or np.issubdtype(
v.dtype, np.bool_
), "var type in the subscript must be int or bool"
tensors.append(v)
if isinstance(i, slice):
if i.start is None and i.stop is None and i.step is None:
continue
push(i.start, item, tensors)
push(i.stop, item, tensors)
push(i.step, item, tensors)
item.append(False) # idx
else:
item += [False,] * 3 # begin, end, stop
push(i, item, tensors)
assert len(item) == 5
items.append(item)
if new_axes:
raise IndexError("newaxis is not allowed here")
return inp, tensors, items, use_subtensor
def try_condtake(tensor, index):
if not hasattr(index, "dtype") or not hasattr(index, "shape"):
return []
if index.dtype != np.bool_ or make_shape_tuple(index.shape) != make_shape_tuple(
tensor.shape
):
return []
if isinstance(index, np.ndarray):
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
assert isinstance(index, (Tensor, SymbolVar))
if not isinstance(tensor, (Tensor, SymbolVar)):
raise TypeError("input must be a tensor")
if tensor.device != index.device:
raise ValueError(
"ambiguous device: {} vs {}".format(tensor.device, index.device)
)
return apply(builtin.CondTake(), tensor, index)
from .utils import astensor1d
def getitem(tensor, index):
try_result = try_condtake(tensor, index)
if len(try_result) == 2:
return try_result[0]
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
if use_subtensor:
op = builtin.Subtensor(items=items)
else:
op = builtin.IndexingMultiAxisVec(items=items)
(result,) = apply(op, tensor, *tensors)
return result
return getitem_cpp(tensor, index)
def setitem(tensor, index, value):
org_shape = tensor.shape
try_result = try_condtake(tensor, index)
if len(try_result) == 2:
index = try_result[1]
tensor = tensor.reshape(-1)
if not isinstance(value, (Tensor, SymbolVar)):
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor)
tensor, tensors, items, use_subtensor = unpack_getitem(tensor, index)
if use_subtensor:
op = builtin.Subtensor(items=items)
else:
op = builtin.IndexingMultiAxisVec(items=items)
return setitem_cpp(tensor, index, value)
(tmp_result,) = apply(op, tensor, *tensors)
try:
value_shape = value._tuple_shape
tmp_result_shape = tmp_result._tuple_shape
except ValueError:
pass
else:
for i in range(min(len(value_shape), len(tmp_result_shape))):
if (value_shape[-i - 1] != 1) & (
value_shape[-i - 1] != tmp_result_shape[-i - 1]
):
raise ValueError(
"cannot copy tensor with shape {} to subtensor with shape {}".format(
value_shape, tmp_result_shape
)
)
value = value._broadcast(tmp_result.shape)
if use_subtensor:
op = builtin.SetSubtensor(items=items)
else:
op = builtin.IndexingSetMultiAxisVec(items=items)
(result,) = apply(op, tensor, value, *tensors)
result = result.reshape(org_shape)
return result
set_cpp_use_symbolic_shape(use_symbolic_shape)
set_cpp_astensor1d(astensor1d)
......@@ -12,7 +12,14 @@ from typing import Iterable, Union
import numpy as np
from .._imperative_rt import make_const
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
from .._imperative_rt.core2 import (
SymbolVar,
Tensor,
apply,
dtype_promotion,
get_device,
make_shape_tuple,
)
from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from .._wrap import as_device
from ..ops import builtin
......@@ -163,30 +170,6 @@ def astensor1d(x, *reference, dtype=None, device=None):
return x
def _expand_int(s, i):
if isinstance(i, (Tensor, SymbolVar)):
i_np = i.numpy()
if i_np.ndim == 0:
s.append(int(i_np))
else:
s += list(i_np)
return
if isinstance(i, Iterable):
for ii in i:
_expand_int(s, ii)
return
if np.issubdtype(type(i), np.integer):
s.append(i)
return
raise
def make_shape_tuple(shape):
s = []
_expand_int(s, shape)
return tuple(s)
def _normalize_axis(
ndim: int, axis: Union[int, Iterable], reverse=False
) -> Union[int, list]:
......
此差异已折叠。
......@@ -751,3 +751,40 @@ def test_subtensor_when_shape_invalid():
inp = rand.uniform(size=[1, 3, 512, 512])
net = cgtools.GraphInference(f.name)
net.run(inp_dict={"data": inp})
@pytest.mark.parametrize(
"test_varnode", [True, False],
)
def test_indexing_error(test_varnode):
if test_varnode:
network = Network()
else:
network = None
a = np.arange(9).reshape(3, 3).astype(np.float32)
b = np.array([1, 2])
aa = make_tensor(a, network)
bb = make_tensor(b, network)
with pytest.raises(IndexError):
aa[None] # newaxis is not allowed
with pytest.raises(IndexError):
aa[..., ...] # only one ellipsis is allowed
with pytest.raises(IndexError):
aa[bb, bb, bb] # too many indices
with pytest.raises(ValueError):
aa[:] = bb # shape mismatch
if test_varnode:
cc = aa[aa > 4]
with pytest.raises(IndexError):
cc[...] # does not support ellipsis when tensor's ndim is unknown
dd = aa > 4
with pytest.raises(IndexError):
cc[
..., dd[dd]
] # does not support bool index with unknown shape when using ellipsis
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册