提交 35d46dbb 编写于 作者: M Megvii Engine Team

fix(mge/functional): simplify the api of add/remove_axis

GitOrigin-RevId: 2482529704a04dfa7c480142c66395abda2788e1
上级 f7d8b516
......@@ -488,12 +488,12 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
@wrap_io_tensor
def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
def add_axis(inp: Tensor, axis: int) -> Tensor:
r"""
Add dimension(s) before given axis/axes
Add dimension before given axis.
:param inp: Input tensor
:param axis: Place(s) of new axes
:param axis: Place of new axes
:return: The output tensor
Examples:
......@@ -504,26 +504,28 @@ def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
from megengine import tensor
import megengine.functional as F
x = tensor([1, 2])
out = F.add_axis(x, (0, 2))
out = F.add_axis(x, 0)
print(out.shape)
Outputs:
.. testoutput::
(1, 2, 1)
(1, 2)
"""
if not isinstance(axis, int):
raise ValueError("axis must be int, but got type:{}".format(type(axis)))
return mgb.opr.add_axis(inp, axis)
@wrap_io_tensor
def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
def remove_axis(inp: Tensor, axis: int) -> Tensor:
r"""
Remove dimension(s) of shape 1
Remove dimension of shape 1.
:param inp: Input tensor
:param axis: Place(s) of axes to be removed
:param axis: Place of axis to be removed
:return: The output tensor
Examples:
......@@ -534,16 +536,18 @@ def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
from megengine import tensor
import megengine.functional as F
x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
out = F.remove_axis(x, (0, 0, 1))
out = F.remove_axis(x, 3)
print(out.shape)
Outputs:
.. testoutput::
(2,)
(1, 1, 2)
"""
if not isinstance(axis, int):
raise ValueError("axis must be int, but got type:{}".format(type(axis)))
return mgb.opr.remove_axis(inp, axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册