diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index cd8bf458374cb9e933297f441046755308b92429..7cf1ea633feccafe3c49074fe135ce39164f30d2 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -405,9 +405,15 @@ class ArrayMethodMixin(abc.ABC): r"""See :func:`~.transpose`.""" return transpose_cpp(self, args) - def flatten(self): + def flatten(self, start_axis: int = 0, end_axis: int = -1): r"""See :func:`~.flatten`.""" - return reshape_cpp(self, (-1,)) + inp_shape = self.shape + if start_axis < 0: + start_axis += len(inp_shape) + target_shape = tuple(inp_shape[i] for i in range(start_axis)) + (-1,) + if end_axis != -1: + target_shape += (*inp_shape[end_axis + 1 :],) + return reshape_cpp(self, target_shape) def sum(self, axis=None, keepdims: bool = False): r"""See :func:`~.sum`.""" diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index a07f7f287a329a70097ba177dfac072823342dd1..e639df07372cbe79d1fb74e2623448eda78acb1a 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -941,12 +941,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: >>> out.numpy().shape (2, 2, 9) """ - if start_axis < 0: - start_axis += len(inp.shape) - target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,) - if end_axis != -1: - target_shape += (*inp.shape[end_axis + 1 :],) - return inp.reshape(*target_shape) + return inp.flatten(start_axis, end_axis) def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: