diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py
index e6cd6e7e040f32b396ef86b3c9f870d251343a35..fd2e6bea08bad93ae362a6780cf314e00ee9b4e3 100644
--- a/imperative/python/megengine/core/tensor/tensor_wrapper.py
+++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py
@@ -134,15 +134,54 @@ def _logical_binary_elwise(mode, rev=False):
     return f
 
 
+def _remove_axis(inp: Tensor, axis) -> Tensor:
+    Param = builtin.AxisAddRemove.Param
+
+    def get_axes():
+        if axis is None:
+            return [i for i, s in enumerate(inp.shape) if s == 1]
+        try:
+            return [int(axis)]
+        except (TypeError, ValueError):
+            pass
+        return list(map(int, axis))
+
+    axis = get_axes()
+    axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
+    axis = [a - i for i, a in enumerate(axis)]
+
+    param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
+    op = builtin.AxisAddRemove(param=param)
+    (result,) = apply(op, inp)
+    return result
+
+
 def _reduce(mode):
-    def f(self, axis=None):
-        inp = self
+    def f(self, axis=None, keepdims: bool = False):
+        data = self
+        (data,) = utils.convert_inputs(data)
         if axis is None:
-            inp = self.flatten()
-            axis = 0
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = utils.convert_inputs(inp)
-        (result,) = apply(op, result)
+            data = data.reshape(-1)
+            assert not keepdims, "can not set axis=None and keepdims=True"
+
+            op = builtin.Reduce(mode=mode, axis=0)
+            (result,) = apply(op, data)
+        elif isinstance(axis, collections.abc.Iterable):
+            axis = list(axis)
+            axis.sort(reverse=True)
+
+            for ai in axis:
+                op = builtin.Reduce(mode=mode, axis=ai)
+                (data,) = apply(op, data)
+                if not keepdims:
+                    data = _remove_axis(data, ai)
+            result = data
+        else:
+            op = builtin.Reduce(mode=mode, axis=axis)
+            (result,) = apply(op, data)
+
+            if not keepdims:
+                result = _remove_axis(result, axis)
         return result
 
     return f
diff --git a/imperative/python/megengine/functional/loss.py b/imperative/python/megengine/functional/loss.py
index 2fbfd173376888c214e1682fedb7dc4756db6c45..b98fcfc7fa8defe810885187573e6470dfe353ad 100644
--- a/imperative/python/megengine/functional/loss.py
+++ b/imperative/python/megengine/functional/loss.py
@@ -176,15 +176,15 @@ def cross_entropy_with_softmax(
     num_classes = pred.shape[axis]
 
     # Denominator of the softmax
-    offset = pred.max(axis=axis).detach()
+    offset = pred.max(axis=axis, keepdims=True).detach()
     pred = pred - offset
-    down = exp(pred).sum(axis=axis)
+    down = exp(pred).sum(axis=axis, keepdims=True)
 
     up = indexing_one_hot(pred, label, axis)
 
     if label_smooth != 0:
         factor = label_smooth / num_classes
-        up = up * (1 - label_smooth) + pred.sum(axis=axis) * factor
+        up = up * (1 - label_smooth) + pred.sum(axis=axis, keepdims=True) * factor
 
     return (log(down) - up).mean()
 
diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py
index 3483ad816dc9e1193d4202e809172e3c3a2b494b..1bdf308069c3793a5154d7256a8e636240c5586b 100644
--- a/imperative/python/megengine/functional/math.py
+++ b/imperative/python/megengine/functional/math.py
@@ -117,40 +117,6 @@ def sign(inp: Tensor):
     raise NotImplementedError
 
 
-def _reduce(
-    data,
-    *,
-    mode,
-    axis: Optional[Union[int, Sequence[int]]] = None,
-    keepdims: bool = False
-):
-    (data,) = utils.convert_inputs(data)
-    if axis is None:
-        data = data.reshape(-1)
-        assert not keepdims, "can not set axis=None and keepdims=True"
-
-        op = builtin.Reduce(mode=mode, axis=0)
-        (result,) = apply(op, data)
-    elif isinstance(axis, collections.Iterable):
-        axis = list(axis)
-        axis.sort(reverse=True)
-
-        for ai in axis:
-            op = builtin.Reduce(mode=mode, axis=ai)
-            (data,) = apply(op, data)
-            if not keepdims:
-                data = remove_axis(data, ai)
-        result = data
-    else:
-        op = builtin.Reduce(mode=mode, axis=axis)
-        (result,) = apply(op, data)
-
-        if not keepdims:
-            result = remove_axis(result, axis)
-
-    return result
-
-
 def sum(
     inp: Tensor,
     axis: Optional[Union[int, Sequence[int]]] = None,
@@ -182,7 +148,7 @@ def sum(
         [21]
 
     """
-    return _reduce(inp, mode="SUM", axis=axis, keepdims=keepdims)
+    return inp.sum(axis=axis, keepdims=keepdims)
 
 
 def prod(
@@ -215,7 +181,7 @@ def prod(
         [720]
 
     """
-    return _reduce(inp, mode="PRODUCT", axis=axis, keepdims=keepdims)
+    return inp.prod(axis=axis, keepdims=keepdims)
 
 
 def mean(
@@ -248,7 +214,7 @@ def mean(
         [3.5]
 
     """
-    return _reduce(inp, mode="MEAN", axis=axis, keepdims=keepdims)
+    return inp.astype("float32").mean(axis=axis, keepdims=keepdims)
 
 
 def median(
@@ -362,7 +328,7 @@ def min(
         [1]
 
     """
-    return _reduce(inp, mode="MIN", axis=axis, keepdims=keepdims)
+    return inp.min(axis=axis, keepdims=keepdims)
 
 
 def max(
@@ -394,7 +360,7 @@ def max(
         [6]
 
     """
-    return _reduce(inp, mode="MAX", axis=axis, keepdims=keepdims)
+    return inp.max(axis=axis, keepdims=keepdims)
 
 
 def norm(
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index e2b3e2551ba0c8e4376c65a9f7a6be067b35b678..f7163cd466326b7bcb7754f944ae2c1c5b41bf55 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -580,7 +580,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     """
     if axis is None:
         axis = _get_softmax_axis(len(inp.shape))
-    offset = inp.max(axis=axis).detach()
+    offset = inp.max(axis=axis, keepdims=True).detach()
     cached = exp(inp - offset)
     down = sum(cached, axis=axis, keepdims=True)
     return cached / down
diff --git a/imperative/python/test/unit/test_tensor_wrapper.py b/imperative/python/test/unit/test_tensor_wrapper.py
index 92dc1c255fd967b9391b86dc5ab0689b244370ad..c2f8def6610c4d62d12d5fde87d5949e3fdad4a3 100644
--- a/imperative/python/test/unit/test_tensor_wrapper.py
+++ b/imperative/python/test/unit/test_tensor_wrapper.py
@@ -38,7 +38,7 @@ def test_reduce():
     for m in ["sum", "prod", "min", "max", "mean"]:
         x_np = np.random.rand(10).astype("float32")
         x = TensorWrapper(x_np)
-        y = getattr(x, m)(-1)
+        y = getattr(x, m)(axis=-1, keepdims=True)
         np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)
 
 