提交 c34a75d0 编写于 作者: M Megvii Engine Team

fix(trace): assume result is not scalar when shape is valid

GitOrigin-RevId: beee2d0f28620cc3410d5c4172e0413e012114fd
上级 bebb2cf4
......@@ -119,12 +119,16 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
else 1
)
else:
try:
if ndim_indexed > inp.ndim:
raise IndexError(
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
inp.ndim, len(tuple_val)
)
)
except ValueError:
# ignore
pass
tuple_val = remove_ellipsis(inp, tuple_val)
use_subtensor = True
......
......@@ -272,16 +272,12 @@ PyObject* TensorWrapper::device() {
PyObject* TensorWrapper::numpy() {
auto hv = m_tensor->numpy();
// if (!hv) {
// PyErr_SetString(PyExc_ValueError, "tensor invalid");
// return nullptr;
// }
auto arr = py::reinterpret_steal<py::array>(
npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
if (!arr) {
if (!hv) {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr;
}
auto arr = py::reinterpret_steal<py::array>(
npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
if (hv->shape().is_scalar()) {
mgb_assert(PyArray_Check(arr.ptr()));
return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
......
......@@ -51,6 +51,7 @@ bool is_scalar_shape(ValueRef shape) {
if (shape.is<ScalarValue>()) {
return false;
}
// may have performance issue
auto shape_of_shape = shape.shape();
if (!shape_of_shape) {
// assume not scalar
......@@ -211,14 +212,21 @@ std::vector<ValueRef> subtensor_rule(
const Subtensor& subtensor, Span<ValueRef> inputs) {
mgb_assert(inputs.size() >= 1);
auto input = inputs[0];
size_t ndim = input.is<ScalarValue>() ? 0 : input.shape()->ndim;
bool is_scalar;
mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input");
if (auto shape = input.shape()) {
size_t ndim = input.shape()->ndim;
for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
if (idx) {
ndim--;
}
}
is_scalar = ndim == 0;
} else {
is_scalar = false;
}
auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
if (!ndim) {
if (is_scalar) {
return {ScalarValue::make(output)};
} else {
return {output};
......@@ -261,8 +269,7 @@ std::vector<ValueRef> fastpath_copy_rule(
std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
mgb_assert(inputs.size() == 2);
bool is_scalar =
(!inputs[1].is<ScalarValue>()) && *inputs[1].shape() == ValueShape{0};
bool is_scalar = is_scalar_shape(inputs[1]);
auto unwrapped_input = inputs[0].is<ScalarValue>()
? inputs[0].cast<ScalarValue>().value()
: inputs[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册