提交 19652768 编写于 作者: M Megvii Engine Team

fix(mge): fix bool.sum()

GitOrigin-RevId: 62c482db40405e2ff75cadfe547c355ae6a95967
上级 1a24fb29
......@@ -158,6 +158,10 @@ def _reduce(mode):
def f(self, axis=None, keepdims: bool = False):
data = self
(data,) = utils.convert_inputs(data)
if mode == "MEAN":
data = data.astype("float32")
elif self.dtype == np.bool_:
data = data.astype("int32")
if axis is None:
data = data.reshape(-1)
assert not keepdims, "can not set axis=None and keepdims=True"
......@@ -180,6 +184,9 @@ def _reduce(mode):
if not keepdims:
result = _remove_axis(result, axis)
if self.dtype == np.bool_:
if mode in ["MIN", "MAX"]:
result = result.astype("bool")
return result
return f
......@@ -377,7 +384,38 @@ class ArrayMethodMixin(abc.ABC):
def flatten(self):
return self.reshape(-1)
sum = _reduce("SUM")
def sum(self, axis=None, keepdims: bool = False):
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`).
Same for prod/mean/max/min.
:param axis: the dimension or dimensions to reduce.
:param keepdim: whether the output tensor has ndim retained or not.
:return: output tensor.
Examples:
.. testcode::
from megengine import tensor
a = tensor([False, True, True, False])
b = tensor([1.0, 2.0, 3.0, 4.0])
print(a.sum().numpy())
print(b.sum().numpy())
Outputs:
.. testoutput::
[2]
[10.]
"""
return _reduce("SUM")(self, axis, keepdims)
prod = _reduce("PRODUCT")
min = _reduce("MIN")
max = _reduce("MAX")
......
......@@ -35,12 +35,17 @@ def test_matmul():
def test_reduce():
def test_x(x_np):
for m in ["sum", "prod", "min", "max", "mean"]:
x_np = np.random.rand(10).astype("float32")
x = TensorWrapper(x_np)
y = getattr(x, m)(axis=-1, keepdims=True)
np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
test_x((10 * np.random.rand(10) + 1).astype("int32"))
test_x(np.random.rand(10).astype("float32"))
test_x(np.array([True, True, True]))
test_x(np.array([True, False, True]))
def test_set_subtensor():
x = TensorWrapper([1, 2, 3])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册