提交 d168cea4 编写于 作者: M Megvii Engine Team

feat(opr): add param(axis) for GetVarShape

feat(mge/imperative): GetVarShape support negative axis

GitOrigin-RevId: 30ce0758e66285e984de0bb410759032653cf20a
上级 e9c036cc
......@@ -40,10 +40,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
ptr[i] = shp.shape[i];
}
}else{
mgb_assert(op_def.axis < shp.ndim);
int32_t axis = op_def.axis;
if (axis < 0) {
axis += shp.ndim;
}
mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = shp.shape[op_def.axis];
ptr[0] = shp.shape[axis];
}
return {Tensor::make(std::move(hv))};
}
......@@ -65,10 +69,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
ptr[i] = desc.layout[i];
}
}else{
mgb_assert(op_def.axis < desc.layout.ndim);
int32_t axis = op_def.axis;
if (axis < 0) {
axis += desc.layout.ndim;
}
mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim);
value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
auto* ptr = value.ptr<dt_int32>();
ptr[0] = desc.layout[op_def.axis];
ptr[0] = desc.layout[axis];
}
return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册