提交 eac447ba 编写于 作者: M Megvii Engine Team

fix(imperative): fix the arguments of flatten

GitOrigin-RevId: c67200826e6ae90a237e5959f95e5492cb13b6d5
上级 79328fa0
...@@ -405,9 +405,15 @@ class ArrayMethodMixin(abc.ABC): ...@@ -405,9 +405,15 @@ class ArrayMethodMixin(abc.ABC):
r"""See :func:`~.transpose`.""" r"""See :func:`~.transpose`."""
return transpose_cpp(self, args) return transpose_cpp(self, args)
def flatten(self): def flatten(self, start_axis: int = 0, end_axis: int = -1):
r"""See :func:`~.flatten`.""" r"""See :func:`~.flatten`."""
return reshape_cpp(self, (-1,)) inp_shape = self.shape
if start_axis < 0:
start_axis += len(inp_shape)
target_shape = tuple(inp_shape[i] for i in range(start_axis)) + (-1,)
if end_axis != -1:
target_shape += (*inp_shape[end_axis + 1 :],)
return reshape_cpp(self, target_shape)
def sum(self, axis=None, keepdims: bool = False): def sum(self, axis=None, keepdims: bool = False):
r"""See :func:`~.sum`.""" r"""See :func:`~.sum`."""
......
...@@ -941,12 +941,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: ...@@ -941,12 +941,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
>>> out.numpy().shape >>> out.numpy().shape
(2, 2, 9) (2, 2, 9)
""" """
if start_axis < 0: return inp.flatten(start_axis, end_axis)
start_axis += len(inp.shape)
target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,)
if end_axis != -1:
target_shape += (*inp.shape[end_axis + 1 :],)
return inp.reshape(*target_shape)
def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册