提交 54eef558 编写于 作者: M Megvii Engine Team 提交者: “wenjuan”

fix(trace): assume result is not scalar when shape is valid

GitOrigin-RevId: beee2d0f28620cc3410d5c4172e0413e012114fd
上级 84d99d1c
...@@ -119,12 +119,16 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): ...@@ -119,12 +119,16 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
else 1 else 1
) )
else: else:
if ndim_indexed > inp.ndim: try:
raise IndexError( if ndim_indexed > inp.ndim:
"too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format( raise IndexError(
inp.ndim, len(tuple_val) "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
inp.ndim, len(tuple_val)
)
) )
) except ValueError:
# ignore
pass
tuple_val = remove_ellipsis(inp, tuple_val) tuple_val = remove_ellipsis(inp, tuple_val)
use_subtensor = True use_subtensor = True
......
...@@ -272,16 +272,12 @@ PyObject* TensorWrapper::device() { ...@@ -272,16 +272,12 @@ PyObject* TensorWrapper::device() {
PyObject* TensorWrapper::numpy() { PyObject* TensorWrapper::numpy() {
auto hv = m_tensor->numpy(); auto hv = m_tensor->numpy();
// if (!hv) { if (!hv) {
// PyErr_SetString(PyExc_ValueError, "tensor invalid");
// return nullptr;
// }
auto arr = py::reinterpret_steal<py::array>(
npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
if (!arr) {
PyErr_SetString(PyExc_ValueError, "tensor invalid"); PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr; return nullptr;
} }
auto arr = py::reinterpret_steal<py::array>(
npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
if (hv->shape().is_scalar()) { if (hv->shape().is_scalar()) {
mgb_assert(PyArray_Check(arr.ptr())); mgb_assert(PyArray_Check(arr.ptr()));
return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr())); return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
......
...@@ -51,6 +51,7 @@ bool is_scalar_shape(ValueRef shape) { ...@@ -51,6 +51,7 @@ bool is_scalar_shape(ValueRef shape) {
if (shape.is<ScalarValue>()) { if (shape.is<ScalarValue>()) {
return false; return false;
} }
// may have performance issue
auto shape_of_shape = shape.shape(); auto shape_of_shape = shape.shape();
if (!shape_of_shape) { if (!shape_of_shape) {
// assume not scalar // assume not scalar
...@@ -211,14 +212,21 @@ std::vector<ValueRef> subtensor_rule( ...@@ -211,14 +212,21 @@ std::vector<ValueRef> subtensor_rule(
const Subtensor& subtensor, Span<ValueRef> inputs) { const Subtensor& subtensor, Span<ValueRef> inputs) {
mgb_assert(inputs.size() >= 1); mgb_assert(inputs.size() >= 1);
auto input = inputs[0]; auto input = inputs[0];
size_t ndim = input.is<ScalarValue>() ? 0 : input.shape()->ndim; bool is_scalar;
for (auto&& [axis, begin, end, step, idx] : subtensor.items) { mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input");
if (idx) { if (auto shape = input.shape()) {
ndim--; size_t ndim = input.shape()->ndim;
for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
if (idx) {
ndim--;
}
} }
is_scalar = ndim == 0;
} else {
is_scalar = false;
} }
auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0]; auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
if (!ndim) { if (is_scalar) {
return {ScalarValue::make(output)}; return {ScalarValue::make(output)};
} else { } else {
return {output}; return {output};
...@@ -261,8 +269,7 @@ std::vector<ValueRef> fastpath_copy_rule( ...@@ -261,8 +269,7 @@ std::vector<ValueRef> fastpath_copy_rule(
std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) { std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
bool is_scalar = bool is_scalar = is_scalar_shape(inputs[1]);
(!inputs[1].is<ScalarValue>()) && *inputs[1].shape() == ValueShape{0};
auto unwrapped_input = inputs[0].is<ScalarValue>() auto unwrapped_input = inputs[0].is<ScalarValue>()
? inputs[0].cast<ScalarValue>().value() ? inputs[0].cast<ScalarValue>().value()
: inputs[0]; : inputs[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册