diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index 21ed0f9259be93b0f205a3aed6ba7ea7c3c7f5d4..e4a8344f20e6b20b3a0586c9d6462406499d0f81 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -40,10 +40,14 @@ SmallVector apply_on_physical_tensor( ptr[i] = shp.shape[i]; } }else{ - mgb_assert(op_def.axis < shp.ndim); + int32_t axis = op_def.axis; + if (axis < 0) { + axis += shp.ndim; + } + mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim); hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32()); auto* ptr = hv.ptr(); - ptr[0] = shp.shape[op_def.axis]; + ptr[0] = shp.shape[axis]; } return {Tensor::make(std::move(hv))}; } @@ -65,10 +69,14 @@ std::tuple, bool> infer_output_attrs_fallible( ptr[i] = desc.layout[i]; } }else{ - mgb_assert(op_def.axis < desc.layout.ndim); + int32_t axis = op_def.axis; + if (axis < 0) { + axis += desc.layout.ndim; + } + mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim); value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32()); auto* ptr = value.ptr(); - ptr[0] = desc.layout[op_def.axis]; + ptr[0] = desc.layout[axis]; } return {{{value.layout(), desc.comp_node, std::move(value)}}, true}; }