assertnotkeepdims,"can not set axis=None and keepdims=True"
...
...
@@ -180,6 +184,9 @@ def _reduce(mode):
ifnotkeepdims:
result=_remove_axis(result,axis)
ifself.dtype==np.bool_:
ifmodein["MIN","MAX"]:
result=result.astype("bool")
returnresult
returnf
...
...
@@ -377,7 +384,38 @@ class ArrayMethodMixin(abc.ABC):
defflatten(self):
returnself.reshape(-1)
sum=_reduce("SUM")
defsum(self,axis=None,keepdims:bool=False):
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`).
Same for prod/mean/max/min.
:param axis: the dimension or dimensions to reduce.
:param keepdim: whether the output tensor has ndim retained or not.