提交 1e2117f6 编写于 作者: M Megvii Engine Team

feat(mge): remove add_axis and remove_axis

GitOrigin-RevId: 59611d43f979ce2de3b8fdb078c3e8fca6e9bf49
上级 5b7ae268
...@@ -411,7 +411,7 @@ class ArrayMethodMixin(abc.ABC): ...@@ -411,7 +411,7 @@ class ArrayMethodMixin(abc.ABC):
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``. r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axes, reduce over all of them. If ``axis`` is a list of axes, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed (see :meth:`~.functional.tensor.remove_axis`). If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed (see :meth:`~.functional.tensor.squeeze`).
Same for prod/mean/max/min. Same for prod/mean/max/min.
......
...@@ -19,7 +19,7 @@ from ..core.tensor import utils ...@@ -19,7 +19,7 @@ from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import clip, exp, log, log1p from .elemwise import clip, exp, log, log1p
from .tensor import add_axis, remove_axis, reshape from .tensor import reshape, squeeze
__all__ = [ __all__ = [
"argmax", "argmax",
...@@ -459,7 +459,7 @@ def argmin( ...@@ -459,7 +459,7 @@ def argmin(
(inp,) = apply(op, inp) (inp,) = apply(op, inp)
if not keepdims: if not keepdims:
inp = remove_axis(inp, ai) inp = squeeze(inp, ai)
return inp return inp
...@@ -471,7 +471,7 @@ def argmin( ...@@ -471,7 +471,7 @@ def argmin(
op = builtin.Argmin(axis=axis) op = builtin.Argmin(axis=axis)
(result,) = apply(op, inp) (result,) = apply(op, inp)
if not keepdims: if not keepdims:
result = remove_axis(result, axis) result = squeeze(result, axis)
return result return result
...@@ -517,7 +517,7 @@ def argmax( ...@@ -517,7 +517,7 @@ def argmax(
(inp,) = apply(op, inp) (inp,) = apply(op, inp)
if not keepdims: if not keepdims:
inp = remove_axis(inp, ai) inp = squeeze(inp, ai)
return inp return inp
...@@ -529,7 +529,7 @@ def argmax( ...@@ -529,7 +529,7 @@ def argmax(
op = builtin.Argmax(axis=axis) op = builtin.Argmax(axis=axis)
(result,) = apply(op, inp) (result,) = apply(op, inp)
if not keepdims: if not keepdims:
result = remove_axis(result, axis) result = squeeze(result, axis)
return result return result
......
...@@ -23,7 +23,7 @@ from .debug_param import get_conv_execution_strategy ...@@ -23,7 +23,7 @@ from .debug_param import get_conv_execution_strategy
from .distributed import all_reduce_sum from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
from .math import argsort, max, sum from .math import argsort, max, sum
from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshape, zeros from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
from .types import _pair, _pair_nonzero from .types import _pair, _pair_nonzero
__all__ = [ __all__ = [
...@@ -542,7 +542,7 @@ def logsumexp( ...@@ -542,7 +542,7 @@ def logsumexp(
if keepdims: if keepdims:
return max_value + log(sum(exp(inp - max_value), axis, keepdims)) return max_value + log(sum(exp(inp - max_value), axis, keepdims))
else: else:
return remove_axis(max_value, axis=None) + log( return squeeze(max_value, axis=None) + log(
sum(exp(inp - max_value), axis, keepdims) sum(exp(inp - max_value), axis, keepdims)
) )
...@@ -640,7 +640,7 @@ def batch_norm2d( ...@@ -640,7 +640,7 @@ def batch_norm2d(
def expand_or_full(x, value): def expand_or_full(x, value):
if x is None: if x is None:
return full_value(value) return full_value(value)
return add_axis(x, [0, 2, 3]) return expand_dims(x, [0, 2, 3])
def make_full_if_none(x, value): def make_full_if_none(x, value):
if x is None: if x is None:
...@@ -998,10 +998,10 @@ def matmul( ...@@ -998,10 +998,10 @@ def matmul(
else: else:
if dim1 == 1: if dim1 == 1:
shp = (inp2.shape[1],) shp = (inp2.shape[1],)
inp1 = add_axis(inp1, 0) inp1 = expand_dims(inp1, 0)
if dim2 == 1: if dim2 == 1:
shp = (inp1.shape[0],) shp = (inp1.shape[0],)
inp2 = add_axis(inp2, 1) inp2 = expand_dims(inp2, 1)
op = builtin.MatrixMul( op = builtin.MatrixMul(
transposeA=transpose_a, transposeA=transpose_a,
transposeB=transpose_b, transposeB=transpose_b,
...@@ -1135,7 +1135,7 @@ def interpolate( ...@@ -1135,7 +1135,7 @@ def interpolate(
align_corners = False align_corners = False
if mode == "LINEAR": if mode == "LINEAR":
inp = add_axis(inp, 3) inp = expand_dims(inp, 3)
if inp.ndim != 4: if inp.ndim != 4:
raise ValueError("shape of input tensor must correspond to the operartion mode") raise ValueError("shape of input tensor must correspond to the operartion mode")
...@@ -1452,7 +1452,7 @@ def indexing_one_hot( ...@@ -1452,7 +1452,7 @@ def indexing_one_hot(
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
(result,) = apply(op, src, index) (result,) = apply(op, src, index)
if not keepdims: if not keepdims:
result = remove_axis(result, axis) result = squeeze(result, axis)
return result return result
......
...@@ -32,12 +32,12 @@ from ..tensor import Tensor ...@@ -32,12 +32,12 @@ from ..tensor import Tensor
from .elemwise import ceil from .elemwise import ceil
__all__ = [ __all__ = [
"add_axis",
"arange", "arange",
"broadcast", "broadcast",
"concat", "concat",
"cond_take", "cond_take",
"dimshuffle", "dimshuffle",
"expand_dims",
"eye", "eye",
"flatten", "flatten",
"full", "full",
...@@ -50,7 +50,6 @@ __all__ = [ ...@@ -50,7 +50,6 @@ __all__ = [
"param_pack_concat", "param_pack_concat",
"param_pack_split", "param_pack_split",
"reshape", "reshape",
"remove_axis",
"split", "split",
"squeeze", "squeeze",
"stack", "stack",
...@@ -160,7 +159,7 @@ def zeros_like(inp: Tensor) -> Tensor: ...@@ -160,7 +159,7 @@ def zeros_like(inp: Tensor) -> Tensor:
print(out.numpy()) print(out.numpy())
Outputs: Outputs:
.. testoutput:: .. testoutput::
[[0 0 0] [[0 0 0]
...@@ -320,7 +319,7 @@ def stack(inps, axis=0, device=None): ...@@ -320,7 +319,7 @@ def stack(inps, axis=0, device=None):
if len(shapes) != 1: if len(shapes) != 1:
raise ValueError("All input tensors must have the same shape") raise ValueError("All input tensors must have the same shape")
inps = [add_axis(inp, axis=axis) for inp in inps] inps = [expand_dims(inp, axis=axis) for inp in inps]
return concat(inps, axis=axis, device=device) return concat(inps, axis=axis, device=device)
...@@ -480,7 +479,7 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: ...@@ -480,7 +479,7 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
r"""Writes all values from the tensor source into input tensor r"""Writes all values from the tensor source into input tensor
at the indices specified in the index tensor. at the indices specified in the index tensor.
For each value in source, its output index is specified by its index For each value in source, its output index is specified by its index
...@@ -694,7 +693,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: ...@@ -694,7 +693,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
Swaps shapes and strides according to given pattern. Swaps shapes and strides according to given pattern.
:param inp: input tensor. :param inp: input tensor.
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1, :param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For example: and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For example:
* (``'x'``) -> make a 0d (scalar) into a 1d vector * (``'x'``) -> make a 0d (scalar) into a 1d vector
...@@ -815,7 +814,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: ...@@ -815,7 +814,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
return inp.reshape(*target_shape) return inp.reshape(*target_shape)
def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
r""" r"""
Adds dimension before given axis. Adds dimension before given axis.
...@@ -832,7 +831,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: ...@@ -832,7 +831,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
import megengine.functional as F import megengine.functional as F
x = tensor([1, 2]) x = tensor([1, 2])
out = F.add_axis(x, 0) out = F.expand_dims(x, 0)
print(out.shape) print(out.shape)
Outputs: Outputs:
...@@ -861,12 +860,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: ...@@ -861,12 +860,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
return result return result
expand_dims = add_axis def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
def remove_axis(
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None
) -> Tensor:
r""" r"""
Removes dimension of shape 1. Removes dimension of shape 1.
...@@ -883,7 +877,7 @@ def remove_axis( ...@@ -883,7 +877,7 @@ def remove_axis(
import megengine.functional as F import megengine.functional as F
x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
out = F.remove_axis(x, 3) out = F.squeeze(x, 3)
print(out.shape) print(out.shape)
Outputs: Outputs:
...@@ -896,9 +890,6 @@ def remove_axis( ...@@ -896,9 +890,6 @@ def remove_axis(
return _remove_axis(inp, axis) return _remove_axis(inp, axis)
squeeze = remove_axis
def linspace( def linspace(
start: Union[int, float, Tensor], start: Union[int, float, Tensor],
stop: Union[int, float, Tensor], stop: Union[int, float, Tensor],
...@@ -925,7 +916,7 @@ def linspace( ...@@ -925,7 +916,7 @@ def linspace(
print(a.numpy()) print(a.numpy())
Outputs: Outputs:
.. testoutput:: .. testoutput::
[ 3. 4.75 6.5 8.25 10. ] [ 3. 4.75 6.5 8.25 10. ]
...@@ -967,7 +958,7 @@ def arange( ...@@ -967,7 +958,7 @@ def arange(
a = F.arange(5) a = F.arange(5)
print(a.numpy()) print(a.numpy())
Outputs: Outputs:
Outputs: Outputs:
...@@ -1018,9 +1009,9 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor: ...@@ -1018,9 +1009,9 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
print(b.numpy()) print(b.numpy())
print(c.numpy()) print(c.numpy())
Outputs: Outputs:
.. testoutput:: .. testoutput::
[1] [1]
...@@ -1059,9 +1050,9 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor: ...@@ -1059,9 +1050,9 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
offsets = tensor(offsets_val, np.int32) offsets = tensor(offsets_val, np.int32)
c = F.param_pack_concat([a, b], offsets, offsets_val) c = F.param_pack_concat([a, b], offsets, offsets_val)
print(c.numpy()) print(c.numpy())
Outputs: Outputs:
.. testoutput:: .. testoutput::
[1 1 1 1 1 1 1 1 1 1] [1 1 1 1 1 1 1 1 1 1]
......
...@@ -306,7 +306,7 @@ def test_AxisAddRemove(): ...@@ -306,7 +306,7 @@ def test_AxisAddRemove():
x = TensorWrapper(x_np) x = TensorWrapper(x_np)
grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = F.remove_axis(F.add_axis(x, 2), 0) y = F.squeeze(F.expand_dims(x, 2), 0)
grad(y, F.ones_like(y)) grad(y, F.ones_like(y))
np.testing.assert_equal( np.testing.assert_equal(
......
...@@ -100,7 +100,7 @@ def test_squeeze(): ...@@ -100,7 +100,7 @@ def test_squeeze():
for axis in [None, 3, -4, (3, -4)]: for axis in [None, 3, -4, (3, -4)]:
y = np.squeeze(x, axis) y = np.squeeze(x, axis)
yy = F.remove_axis(xx, axis) yy = F.squeeze(xx, axis)
np.testing.assert_equal(y, yy.numpy()) np.testing.assert_equal(y, yy.numpy())
...@@ -110,7 +110,7 @@ def test_expand_dims(): ...@@ -110,7 +110,7 @@ def test_expand_dims():
for axis in [2, -3, (3, -4), (1, -4)]: for axis in [2, -3, (3, -4), (1, -4)]:
y = np.expand_dims(x, axis) y = np.expand_dims(x, axis)
yy = F.add_axis(xx, axis) yy = F.expand_dims(xx, axis)
np.testing.assert_equal(y, yy.numpy()) np.testing.assert_equal(y, yy.numpy())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册