From 35d46dbb41dcc7a9b994d56d89d25962e4f37cc2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 21 May 2020 17:11:47 +0800 Subject: [PATCH] fix(mge/functional): simplify the api of add/remove_axis GitOrigin-RevId: 2482529704a04dfa7c480142c66395abda2788e1 --- python_module/megengine/functional/tensor.py | 24 ++++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/python_module/megengine/functional/tensor.py b/python_module/megengine/functional/tensor.py index 3f1c032da..1dc0b6b70 100644 --- a/python_module/megengine/functional/tensor.py +++ b/python_module/megengine/functional/tensor.py @@ -488,12 +488,12 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: @wrap_io_tensor -def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: +def add_axis(inp: Tensor, axis: int) -> Tensor: r""" - Add dimension(s) before given axis/axes + Add dimension before given axis. :param inp: Input tensor - :param axis: Place(s) of new axes + :param axis: Place of new axes :return: The output tensor Examples: @@ -504,26 +504,28 @@ def add_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: from megengine import tensor import megengine.functional as F x = tensor([1, 2]) - out = F.add_axis(x, (0, 2)) + out = F.add_axis(x, 0) print(out.shape) Outputs: .. testoutput:: - (1, 2, 1) + (1, 2) """ + if not isinstance(axis, int): + raise ValueError("axis must be int, but got type:{}".format(type(axis))) return mgb.opr.add_axis(inp, axis) @wrap_io_tensor -def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: +def remove_axis(inp: Tensor, axis: int) -> Tensor: r""" - Remove dimension(s) of shape 1 + Remove dimension of shape 1. :param inp: Input tensor - :param axis: Place(s) of axes to be removed + :param axis: Place of axis to be removed :return: The output tensor Examples: @@ -534,16 +536,18 @@ def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: from megengine import tensor import megengine.functional as F x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) - out = F.remove_axis(x, (0, 0, 1)) + out = F.remove_axis(x, 3) print(out.shape) Outputs: .. testoutput:: - (2,) + (1, 1, 2) """ + if not isinstance(axis, int): + raise ValueError("axis must be int, but got type:{}".format(type(axis))) return mgb.opr.remove_axis(inp, axis) -- GitLab