提交 ffe2bb2e 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge): fix some errors caused by unknown shapes when using symbolic trace or building a graph

GitOrigin-RevId: 70ddc06eee6df5640ea6cf17bacf61c929eb6c66
上级 2d42455f
......@@ -86,16 +86,22 @@ def _broadcast(inp, shape):
def _reshape(x, shape):
shape_tuple = _make_shape_tuple(shape)
unspec_axis = None
# XXX: assume unspec_axis is not changed in trace
for i, s in enumerate(shape_tuple):
if s < 0:
if s != -1:
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
if unspec_axis is not None:
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
unspec_axis = i
try:
shape_tuple = _make_shape_tuple(shape)
except ValueError:
pass
else:
# XXX: assume unspec_axis is not changed in trace
for i, s in enumerate(shape_tuple):
if s < 0:
if s != -1:
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
if unspec_axis is not None:
raise ValueError(
"multiple -1 in shape: {} & {}".format(unspec_axis, i)
)
unspec_axis = i
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)
if unspec_axis is None:
op = builtin.Reshape()
......
......@@ -18,9 +18,9 @@ from .utils import astensor1d, isscalar, make_shape_tuple
def remove_ellipsis(tensor, tuple_val):
ndim_sum = tensor.ndim
cur_sum = 0
pos = -1
has_unkown_ndim_bool_index = False
for i_idx, i in enumerate(tuple_val):
if i is Ellipsis:
for j in tuple_val[:i_idx:-1]:
......@@ -28,10 +28,28 @@ def remove_ellipsis(tensor, tuple_val):
raise IndexError("only one ellipsis is allowed")
pos = i_idx
else:
cur_sum += i.ndim if hasattr(i, "ndim") else 1
try:
cur_sum += (
i.ndim
if hasattr(i, "dtype")
and i.dtype == np.bool_
and hasattr(i, "ndim")
else 1
)
except ValueError:
has_unkown_ndim_bool_index = True
if pos == -1:
return tuple_val
else:
if has_unkown_ndim_bool_index:
raise IndexError(
"Does not support bool index with unknown shape when using Ellipsis"
)
try:
ndim_sum = tensor.ndim
except ValueError:
raise IndexError("Does not support Ellipsis when tensor's ndim is unknown.")
return (
tuple_val[:pos]
+ (slice(None, None, None),) * (ndim_sum - cur_sum)
......@@ -41,7 +59,11 @@ def remove_ellipsis(tensor, tuple_val):
# XXX: assume same results during trace
def check_bool_index(tensor, tuple_val):
cur_shape = make_shape_tuple(tensor.shape)
try:
cur_shape = make_shape_tuple(tensor.shape)
except ValueError:
return tensor, tuple_val
new_tuple_val = []
offset = 0
tdim = 0
......@@ -92,20 +114,31 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
ndim_indexed_scalar = 0
for i in tuple_val:
if not i is Ellipsis:
ndim_indexed += 1 if not hasattr(i, "ndim") else i.ndim
ndim_indexed += (
i.ndim
if hasattr(i, "dtype") and i.dtype == np.bool_ and hasattr(i, "ndim")
else 1
)
if isscalar(i):
ndim_indexed_scalar += 1
if ndim_indexed > inp.ndim:
raise IndexError(
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
inp.ndim, ndim_indexed
ret_scalar = False
try:
ret_scalar = ndim_indexed_scalar == inp.ndim
except ValueError:
# inp.ndim is unknown
pass
else:
if ndim_indexed > inp.ndim:
raise IndexError(
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
inp.ndim, len(tuple_val)
)
)
)
tuple_val = remove_ellipsis(inp, tuple_val)
use_subtensor = True
inp, tuple_val = check_bool_index(inp, tuple_val)
if inp.shape is not None:
inp, tuple_val = check_bool_index(inp, tuple_val)
new_axes = []
tensors = []
......@@ -186,7 +219,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
items.append(item)
if new_axes:
raise IndexError("newaxis is not allowed here")
return inp, tensors, items, use_subtensor, ndim_indexed_scalar == inp.ndim
return inp, tensors, items, use_subtensor, ret_scalar
def try_condtake(tensor, index):
......@@ -249,16 +282,21 @@ def setitem(tensor, index, value):
op = builtin.IndexingMultiAxisVec(items=items)
(tmp_result,) = apply(op, tensor, *tensors)
for i in range(min(len(value.shape), len(tmp_result.shape))):
if (value.shape[-i - 1] != 1) & (
value.shape[-i - 1] != tmp_result.shape[-i - 1]
):
raise ValueError(
"cannot copy tensor with shape {} to subtensor with shape {}".format(
value.shape, tmp_result.shape
try:
value_shape = value._tuple_shape
tmp_result_shape = tmp_result._tuple_shape
except ValueError:
pass
else:
for i in range(min(len(value_shape), len(tmp_result_shape))):
if (value_shape[-i - 1] != 1) & (
value_shape[-i - 1] != tmp_result_shape[-i - 1]
):
raise ValueError(
"cannot copy tensor with shape {} to subtensor with shape {}".format(
value_shape, tmp_result_shape
)
)
)
value = value._broadcast(tmp_result.shape)
if use_subtensor:
......
......@@ -137,6 +137,13 @@ def astensor1d(x, *reference, dtype=None, device=None):
ndim = x.ndim
except AttributeError:
pass
except ValueError:
if dtype is not None and dtype != x.dtype:
x = astype(x, dtype)
if device is not None:
cn = as_device(device).to_c()
(x,) = apply(builtin.Copy(comp_node=cn), x)
return x
else:
if ndim != 0 and ndim != 1:
raise ValueError("ndim != 1 or 0, get : %d" % ndim)
......@@ -148,7 +155,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
raise TypeError
if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
x = concatenate(x, device=device)
x = concatenate(x, device=device) if len(x) > 1 else x[0]
if dtype is not None:
x = astype(x, dtype)
return x
......
......@@ -849,8 +849,15 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
return list(map(int, axis))
axis = get_axes()
ndim = inp.ndim + len(axis)
axis = sorted(i + ndim if i < 0 else i for i in axis)
try:
ndim = inp.ndim + len(axis)
axis = sorted(i + ndim if i < 0 else i for i in axis)
except ValueError:
if any([ind < 0 for ind in axis]):
raise IndexError(
"Does not support negative index when tensor's ndim is unknown"
)
axis = sorted(axis)
assert axis, "axis could not be empty"
if inp._isscalar():
assert axis[0] == 0, "invalid axis {} for ndim 0".format(axis[0])
......
......@@ -384,6 +384,11 @@ PyObject* TensorWrapper::shape() {
TensorShape shape;
if (m_tensor->m_var) { // get shape from m_var
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(m_tensor->m_var);
using InferType = cg::static_infer::InferType;
if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
Py_RETURN_NONE;
}
auto *tshp = mgr.infer_shape_fallible(m_tensor->m_var);
if (!tshp) {
Py_RETURN_NONE;
......@@ -878,6 +883,24 @@ void init_tensor(py::module m) {
->static_infer_manager();
return mgr.infer_shape_fallible(v->m_node);
})
.def("numpy", [](PySymbolVar* v){
auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
auto&& type = mgr.get_infer_type(v->m_node);
using InferType = cg::static_infer::InferType;
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
throw py::value_error("value invalid!");
}
auto* val = mgr.infer_value_fallible(v->m_node);
if (!val) {
throw py::value_error("value invalid!");
}
auto np_val = py::cast(*val).attr("numpy")();
if (v->is_scalar) {
return py::object(py::array(np_val).squeeze());
}
return np_val;
})
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
.def("_setscalar",
[](PySymbolVar* v) { return v->is_scalar = true; })
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册