Commit 7e8f7209 authored by Megvii Engine Team

refactor(mge/tensor): tensor reduce supports keepdims

GitOrigin-RevId: 8ed95e0fb8de213481a24c7b8be3bee71a13e469
Parent 59dcd3b7
@@ -134,15 +134,54 @@ def _logical_binary_elwise(mode, rev=False):
     return f
 
 
+def _remove_axis(inp: Tensor, axis) -> Tensor:
+    Param = builtin.AxisAddRemove.Param
+
+    def get_axes():
+        if axis is None:
+            return [i for i, s in enumerate(inp.shape) if s == 1]
+        try:
+            return [int(axis)]
+        except (TypeError, ValueError):
+            pass
+        return list(map(int, axis))
+
+    axis = get_axes()
+    axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
+    axis = [a - i for i, a in enumerate(axis)]
+
+    param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
+    op = builtin.AxisAddRemove(param=param)
+    (result,) = apply(op, inp)
+    return result
+
+
 def _reduce(mode):
-    def f(self, axis=None):
-        inp = self
+    def f(self, axis=None, keepdims: bool = False):
+        data = self
+        (data,) = utils.convert_inputs(data)
         if axis is None:
-            inp = self.flatten()
-            axis = 0
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = utils.convert_inputs(inp)
-        (result,) = apply(op, result)
+            data = data.reshape(-1)
+            assert not keepdims, "can not set axis=None and keepdims=True"
+            op = builtin.Reduce(mode=mode, axis=0)
+            (result,) = apply(op, data)
+        elif isinstance(axis, collections.Iterable):
+            axis = list(axis)
+            axis.sort(reverse=True)
+            for ai in axis:
+                op = builtin.Reduce(mode=mode, axis=ai)
+                (data,) = apply(op, data)
+                if not keepdims:
+                    data = _remove_axis(data, ai)
+            result = data
+        else:
+            op = builtin.Reduce(mode=mode, axis=axis)
+            (result,) = apply(op, data)
+            if not keepdims:
+                result = _remove_axis(result, axis)
         return result
 
     return f
......
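For readers skimming the diff: the new `keepdims` path reduces one axis at a time, highest index first, so that squeezing an already-reduced axis cannot shift the indices of the axes still to be reduced, and `_remove_axis` drops the reduced axes only when `keepdims=False`. A minimal NumPy sketch of the same semantics (the helper name `reduce_keepdims` is made up for illustration):

```python
import numpy as np

def reduce_keepdims(data, mode="sum", axis=None, keepdims=False):
    # Mirror of the wrapper logic above: reduce one axis at a time,
    # highest axis first, so remaining axis indices stay valid.
    fn = getattr(np, mode)
    if axis is None:
        assert not keepdims, "can not set axis=None and keepdims=True"
        return fn(data.reshape(-1), axis=0)
    if isinstance(axis, (list, tuple)):
        for ai in sorted(axis, reverse=True):
            data = fn(data, axis=ai, keepdims=keepdims)
        return data
    return fn(data, axis=axis, keepdims=keepdims)

x = np.arange(24, dtype="float32").reshape(2, 3, 4)
print(reduce_keepdims(x, "sum", axis=(0, 2)).shape)                 # (3,)
print(reduce_keepdims(x, "sum", axis=(0, 2), keepdims=True).shape)  # (1, 3, 1)
```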
@@ -176,15 +176,15 @@ def cross_entropy_with_softmax(
     num_classes = pred.shape[axis]
 
     # Denominator of the softmax
-    offset = pred.max(axis=axis).detach()
+    offset = pred.max(axis=axis, keepdims=True).detach()
     pred = pred - offset
-    down = exp(pred).sum(axis=axis)
+    down = exp(pred).sum(axis=axis, keepdims=True)
 
     up = indexing_one_hot(pred, label, axis)
 
     if label_smooth != 0:
         factor = label_smooth / num_classes
-        up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor
+        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
 
     return (log(down) - up).mean()
......
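The `keepdims=True` additions here are what keep the softmax offset and denominator broadcastable against `pred`: without them, `max` and `sum` drop the class axis and the elementwise arithmetic no longer lines up. A NumPy sketch of the shape bookkeeping (the batch size 4 and the 10 classes are made-up values):

```python
import numpy as np

pred = np.random.randn(4, 10).astype("float32")  # (batch, classes), axis=1

# Without keepdims the reduced max has shape (4,), which does not
# broadcast against (4, 10); keepdims=True yields (4, 1), so the
# subtraction subtracts each row's own max.
offset = pred.max(axis=1, keepdims=True)          # (4, 1)
stable = pred - offset                            # (4, 10), row max is now 0
down = np.exp(stable).sum(axis=1, keepdims=True)  # (4, 1)
```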
@@ -117,40 +117,6 @@ def sign(inp: Tensor):
     raise NotImplementedError
 
 
-def _reduce(
-    data,
-    *,
-    mode,
-    axis: Optional[Union[int, Sequence[int]]] = None,
-    keepdims: bool = False
-):
-    (data,) = utils.convert_inputs(data)
-    if axis is None:
-        data = data.reshape(-1)
-        assert not keepdims, "can not set axis=None and keepdims=True"
-        op = builtin.Reduce(mode=mode, axis=0)
-        (result,) = apply(op, data)
-    elif isinstance(axis, collections.Iterable):
-        axis = list(axis)
-        axis.sort(reverse=True)
-        for ai in axis:
-            op = builtin.Reduce(mode=mode, axis=ai)
-            (data,) = apply(op, data)
-            if not keepdims:
-                data = remove_axis(data, ai)
-        result = data
-    else:
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = apply(op, data)
-        if not keepdims:
-            result = remove_axis(result, axis)
-    return result
-
-
 def sum(
     inp: Tensor,
     axis: Optional[Union[int, Sequence[int]]] = None,
@@ -182,7 +148,7 @@ def sum(
         [21]
 
     """
-    return _reduce(inp, mode="SUM", axis=axis, keepdims=keepdims)
+    return inp.sum(axis=axis, keepdims=keepdims)
 
 
 def prod(
@@ -215,7 +181,7 @@ def prod(
         [720]
 
     """
-    return _reduce(inp, mode="PRODUCT", axis=axis, keepdims=keepdims)
+    return inp.prod(axis=axis, keepdims=keepdims)
 
 
 def mean(
@@ -248,7 +214,7 @@ def mean(
         [3.5]
 
     """
-    return _reduce(inp, mode="MEAN", axis=axis, keepdims=keepdims)
+    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)
 
 
 def median(
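Note the extra `astype("float32")` in the `mean` wrapper. A plausible reading is that the underlying `MEAN` reduce computes in the input dtype, where an integer mean would truncate; casting first keeps the fractional part. A tiny NumPy illustration of the intended dtype flow (NumPy's own `mean` already promotes to float, so this only mirrors the cast):

```python
import numpy as np

x = np.array([1, 2], dtype="int32")
# Cast before reducing so 1.5 is representable; a mean computed
# entirely in int32 would truncate it to 1.
print(x.astype("float32").mean())  # 1.5
```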
@@ -362,7 +328,7 @@ def min(
         [1]
 
     """
-    return _reduce(inp, mode="MIN", axis=axis, keepdims=keepdims)
+    return inp.min(axis=axis, keepdims=keepdims)
 
 
 def max(
@@ -394,7 +360,7 @@ def max(
         [6]
 
     """
-    return _reduce(inp, mode="MAX", axis=axis, keepdims=keepdims)
+    return inp.max(axis=axis, keepdims=keepdims)
 
 
 def norm(
......
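After this change the functional reductions are thin delegations to the tensor methods patched in above, so both spellings share a single implementation. A hypothetical usage sketch, assuming this branch exposes `megengine.functional` and a `tensor` constructor (import paths on the experimental branch may differ):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(6, dtype="float32").reshape(2, 3))
# Functional and method forms now route through the same Reduce op.
print(F.sum(x, axis=1, keepdims=True).numpy())  # [[ 3.], [12.]]
print(x.sum(axis=1, keepdims=True).numpy())     # same values, same shape
```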
@@ -580,7 +580,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     """
     if axis is None:
         axis = _get_softmax_axis(len(inp.shape))
-    offset = inp.max(axis=axis).detach()
+    offset = inp.max(axis=axis, keepdims=True).detach()
     cached = exp(inp - offset)
     down = sum(cached, axis=axis, keepdims=True)
     return cached / down
......
@@ -38,7 +38,7 @@ def test_reduce():
     for m in ["sum", "prod", "min", "max", "mean"]:
         x_np = np.random.rand(10).astype("float32")
         x = TensorWrapper(x_np)
-        y = getattr(x, m)(-1)
+        y = getattr(x, m)(axis=-1, keepdims=True)
         np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
......
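One subtlety in the updated test: with `keepdims=True` the wrapper returns shape `(1,)` while the NumPy reference `getattr(x_np, m)(-1)` stays a scalar; the assertion still passes because `assert_almost_equal` broadcasts. A standalone restatement of what is being checked:

```python
import numpy as np

x_np = np.random.rand(10).astype("float32")
y = x_np.sum(-1, keepdims=True)  # shape (1,), like the wrapper's result
ref = x_np.sum(-1)               # scalar reference used by the test
# assert_almost_equal broadcasts (1,) against the scalar reference.
np.testing.assert_almost_equal(y, ref, decimal=6)
```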