From eac447bae748ec5029dc4def9580f5344f7bcb1b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 10 Apr 2023 14:43:42 +0800 Subject: [PATCH] fix(imperative): fix the arguments of flatten GitOrigin-RevId: c67200826e6ae90a237e5959f95e5492cb13b6d5 --- .../python/megengine/core/tensor/array_method.py | 10 ++++++++-- imperative/python/megengine/functional/tensor.py | 7 +------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index cd8bf4583..7cf1ea633 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -405,9 +405,15 @@ class ArrayMethodMixin(abc.ABC): r"""See :func:`~.transpose`.""" return transpose_cpp(self, args) - def flatten(self): + def flatten(self, start_axis: int = 0, end_axis: int = -1): r"""See :func:`~.flatten`.""" - return reshape_cpp(self, (-1,)) + inp_shape = self.shape + if start_axis < 0: + start_axis += len(inp_shape) + target_shape = tuple(inp_shape[i] for i in range(start_axis)) + (-1,) + if end_axis != -1: + target_shape += (*inp_shape[end_axis + 1 :],) + return reshape_cpp(self, target_shape) def sum(self, axis=None, keepdims: bool = False): r"""See :func:`~.sum`.""" diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index a07f7f287..e639df073 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -941,12 +941,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: >>> out.numpy().shape (2, 2, 9) """ - if start_axis < 0: - start_axis += len(inp.shape) - target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,) - if end_axis != -1: - target_shape += (*inp.shape[end_axis + 1 :],) - return inp.reshape(*target_shape) + return inp.flatten(start_axis, end_axis) def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: -- GitLab