Commit 19652768 authored by Megvii Engine Team

fix(mge): fix bool.sum()

GitOrigin-RevId: 62c482db40405e2ff75cadfe547c355ae6a95967
Parent 1a24fb29
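For context (illustration only, not part of the patch): after this fix, reducing a boolean tensor behaves as sketched below, assuming a MegEngine build that contains this commit. sum() counts True elements by promoting them to int32, while min()/max() keep a boolean result.

    from megengine import tensor

    a = tensor([False, True, True, False])
    print(a.sum().numpy())  # [2] -- True elements are counted after an int32 cast
    print(a.max().numpy())  # boolean result; min/max restore the bool dtype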
@@ -158,6 +158,10 @@ def _reduce(mode):
     def f(self, axis=None, keepdims: bool = False):
         data = self
         (data,) = utils.convert_inputs(data)
+        if mode == "MEAN":
+            data = data.astype("float32")
+        elif self.dtype == np.bool_:
+            data = data.astype("int32")
         if axis is None:
             data = data.reshape(-1)
             assert not keepdims, "can not set axis=None and keepdims=True"
@@ -180,6 +184,9 @@ def _reduce(mode):
         if not keepdims:
             result = _remove_axis(result, axis)
+        if self.dtype == np.bool_:
+            if mode in ["MIN", "MAX"]:
+                result = result.astype("bool")
         return result
     return f
@@ -377,7 +384,38 @@ class ArrayMethodMixin(abc.ABC):
     def flatten(self):
         return self.reshape(-1)

-    sum = _reduce("SUM")
+    def sum(self, axis=None, keepdims: bool = False):
+        r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
+        If ``axis`` is a list of axes, reduce over all of them.
+
+        If ``keepdims`` is ``True``, the shape of the output tensor is the same as the input tensor,
+        except in the dimension(s) ``axis``, where it is of size 1.
+        Otherwise, ``axis`` is squeezed (see :meth:`~.functional.tensor.remove_axis`).
+        The same applies to prod/mean/max/min.
+
+        :param axis: the dimension or dimensions to reduce.
+        :param keepdims: whether the output tensor has ndim retained or not.
+        :return: output tensor.
+
+        Examples:
+
+        .. testcode::
+
+            from megengine import tensor
+
+            a = tensor([False, True, True, False])
+            b = tensor([1.0, 2.0, 3.0, 4.0])
+            print(a.sum().numpy())
+            print(b.sum().numpy())
+
+        Outputs:
+
+        .. testoutput::
+
+            [2]
+            [10.]
+
+        """
+        return _reduce("SUM")(self, axis, keepdims)
+
     prod = _reduce("PRODUCT")
     min = _reduce("MIN")
     max = _reduce("MAX")
......
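The dtype handling added to `_reduce` above can be mirrored in plain NumPy to make the rules explicit. The sketch below uses a hypothetical helper `reduce_like` (not part of MegEngine or this patch): MEAN always computes in float32, the other reductions promote boolean inputs to int32, and MIN/MAX cast the result back to bool.

    import numpy as np

    def reduce_like(mode, x, axis=None, keepdims=False):
        """Hypothetical NumPy mirror of the dtype rules added to _reduce above."""
        data = np.asarray(x)
        was_bool = data.dtype == np.bool_
        if mode == "MEAN":
            data = data.astype("float32")   # mean is always computed in float32
        elif was_bool:
            data = data.astype("int32")     # bool sum/prod are counted in int32
        fn = {"SUM": np.sum, "PRODUCT": np.prod, "MIN": np.min,
              "MAX": np.max, "MEAN": np.mean}[mode]
        result = fn(data, axis=axis, keepdims=keepdims)
        if was_bool and mode in ("MIN", "MAX"):
            result = result.astype("bool")  # min/max of a bool input stays bool
        return result

    print(reduce_like("SUM", [True, False, True]))  # 2: True elements are counted
    print(reduce_like("MAX", [True, False, True]))  # True: bool dtype restored

Keeping the casts inside `_reduce` means every reduction it generates (sum/prod/min/max/mean) gets consistent boolean handling without touching each wrapper method separately.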
@@ -35,11 +35,16 @@ def test_matmul():
 def test_reduce():
-    for m in ["sum", "prod", "min", "max", "mean"]:
-        x_np = np.random.rand(10).astype("float32")
-        x = TensorWrapper(x_np)
-        y = getattr(x, m)(axis=-1, keepdims=True)
-        np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
+    def test_x(x_np):
+        for m in ["sum", "prod", "min", "max", "mean"]:
+            x = TensorWrapper(x_np)
+            y = getattr(x, m)(axis=-1, keepdims=True)
+            np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
+
+    test_x((10 * np.random.rand(10) + 1).astype("int32"))
+    test_x(np.random.rand(10).astype("float32"))
+    test_x(np.array([True, True, True]))
+    test_x(np.array([True, False, True]))


 def test_set_subtensor():
......
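For reference, the NumPy results that the new boolean test inputs are compared against (via `getattr(x_np, m)(-1)`) look like this; the snippet is standalone and only illustrates the expected semantics:

    import numpy as np

    x_np = np.array([True, False, True])
    print(x_np.sum(-1))   # 2      -- True elements are counted
    print(x_np.prod(-1))  # 0      -- bools behave as 0/1
    print(x_np.min(-1))   # False  -- min/max stay boolean
    print(x_np.max(-1))   # True
    print(x_np.mean(-1))  # 0.6666666666666666 -- mean promotes to floating point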