提交 1e2117f6 编写于 作者: M Megvii Engine Team

feat(mge): remove add_axis and remove_axis

GitOrigin-RevId: 59611d43f979ce2de3b8fdb078c3e8fca6e9bf49
上级 5b7ae268
......@@ -411,7 +411,7 @@ class ArrayMethodMixin(abc.ABC):
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`).
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.squeeze`).
Same for prod/mean/max/min.
......
......@@ -19,7 +19,7 @@ from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor import Tensor
from .elemwise import clip, exp, log, log1p
from .tensor import add_axis, remove_axis, reshape
from .tensor import reshape, squeeze
__all__ = [
"argmax",
......@@ -459,7 +459,7 @@ def argmin(
(inp,) = apply(op, inp)
if not keepdims:
inp = remove_axis(inp, ai)
inp = squeeze(inp, ai)
return inp
......@@ -471,7 +471,7 @@ def argmin(
op = builtin.Argmin(axis=axis)
(result,) = apply(op, inp)
if not keepdims:
result = remove_axis(result, axis)
result = squeeze(result, axis)
return result
......@@ -517,7 +517,7 @@ def argmax(
(inp,) = apply(op, inp)
if not keepdims:
inp = remove_axis(inp, ai)
inp = squeeze(inp, ai)
return inp
......@@ -529,7 +529,7 @@ def argmax(
op = builtin.Argmax(axis=axis)
(result,) = apply(op, inp)
if not keepdims:
result = remove_axis(result, axis)
result = squeeze(result, axis)
return result
......
......@@ -23,7 +23,7 @@ from .debug_param import get_conv_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
from .math import argsort, max, sum
from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshape, zeros
from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
from .types import _pair, _pair_nonzero
__all__ = [
......@@ -542,7 +542,7 @@ def logsumexp(
if keepdims:
return max_value + log(sum(exp(inp - max_value), axis, keepdims))
else:
return remove_axis(max_value, axis=None) + log(
return squeeze(max_value, axis=None) + log(
sum(exp(inp - max_value), axis, keepdims)
)
......@@ -640,7 +640,7 @@ def batch_norm2d(
def expand_or_full(x, value):
if x is None:
return full_value(value)
return add_axis(x, [0, 2, 3])
return expand_dims(x, [0, 2, 3])
def make_full_if_none(x, value):
if x is None:
......@@ -998,10 +998,10 @@ def matmul(
else:
if dim1 == 1:
shp = (inp2.shape[1],)
inp1 = add_axis(inp1, 0)
inp1 = expand_dims(inp1, 0)
if dim2 == 1:
shp = (inp1.shape[0],)
inp2 = add_axis(inp2, 1)
inp2 = expand_dims(inp2, 1)
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
......@@ -1135,7 +1135,7 @@ def interpolate(
align_corners = False
if mode == "LINEAR":
inp = add_axis(inp, 3)
inp = expand_dims(inp, 3)
if inp.ndim != 4:
raise ValueError("shape of input tensor must correspond to the operartion mode")
......@@ -1452,7 +1452,7 @@ def indexing_one_hot(
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:
result = remove_axis(result, axis)
result = squeeze(result, axis)
return result
......
......@@ -32,12 +32,12 @@ from ..tensor import Tensor
from .elemwise import ceil
__all__ = [
"add_axis",
"arange",
"broadcast",
"concat",
"cond_take",
"dimshuffle",
"expand_dims",
"eye",
"flatten",
"full",
......@@ -50,7 +50,6 @@ __all__ = [
"param_pack_concat",
"param_pack_split",
"reshape",
"remove_axis",
"split",
"squeeze",
"stack",
......@@ -160,7 +159,7 @@ def zeros_like(inp: Tensor) -> Tensor:
print(out.numpy())
Outputs:
.. testoutput::
[[0 0 0]
......@@ -320,7 +319,7 @@ def stack(inps, axis=0, device=None):
if len(shapes) != 1:
raise ValueError("All input tensors must have the same shape")
inps = [add_axis(inp, axis=axis) for inp in inps]
inps = [expand_dims(inp, axis=axis) for inp in inps]
return concat(inps, axis=axis, device=device)
......@@ -480,7 +479,7 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
r"""Writes all values from the tensor source into input tensor
r"""Writes all values from the tensor source into input tensor
at the indices specified in the index tensor.
For each value in source, its output index is specified by its index
......@@ -694,7 +693,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
Swaps shapes and strides according to given pattern.
:param inp: input tensor.
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:
* (``'x'``) -> make a 0d (scalar) into a 1d vector
......@@ -815,7 +814,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
return inp.reshape(*target_shape)
def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
r"""
Adds dimension before given axis.
......@@ -832,7 +831,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
import megengine.functional as F
x = tensor([1, 2])
out = F.add_axis(x, 0)
out = F.expand_dims(x, 0)
print(out.shape)
Outputs:
......@@ -861,12 +860,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
return result
expand_dims = add_axis
def remove_axis(
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None
) -> Tensor:
def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
r"""
Removes dimension of shape 1.
......@@ -883,7 +877,7 @@ def remove_axis(
import megengine.functional as F
x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
out = F.remove_axis(x, 3)
out = F.squeeze(x, 3)
print(out.shape)
Outputs:
......@@ -896,9 +890,6 @@ def remove_axis(
return _remove_axis(inp, axis)
squeeze = remove_axis
def linspace(
start: Union[int, float, Tensor],
stop: Union[int, float, Tensor],
......@@ -925,7 +916,7 @@ def linspace(
print(a.numpy())
Outputs:
.. testoutput::
[ 3. 4.75 6.5 8.25 10. ]
......@@ -967,7 +958,7 @@ def arange(
a = F.arange(5)
print(a.numpy())
Outputs:
Outputs:
......@@ -1018,9 +1009,9 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
print(b.numpy())
print(c.numpy())
Outputs:
.. testoutput::
[1]
......@@ -1059,9 +1050,9 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
offsets = tensor(offsets_val, np.int32)
c = F.param_pack_concat([a, b], offsets, offsets_val)
print(c.numpy())
Outputs:
.. testoutput::
[1 1 1 1 1 1 1 1 1 1]
......
......@@ -306,7 +306,7 @@ def test_AxisAddRemove():
x = TensorWrapper(x_np)
grad = Grad().wrt(x, callback=save_to(x))
y = F.remove_axis(F.add_axis(x, 2), 0)
y = F.squeeze(F.expand_dims(x, 2), 0)
grad(y, F.ones_like(y))
np.testing.assert_equal(
......
......@@ -100,7 +100,7 @@ def test_squeeze():
for axis in [None, 3, -4, (3, -4)]:
y = np.squeeze(x, axis)
yy = F.remove_axis(xx, axis)
yy = F.squeeze(xx, axis)
np.testing.assert_equal(y, yy.numpy())
......@@ -110,7 +110,7 @@ def test_expand_dims():
for axis in [2, -3, (3, -4), (1, -4)]:
y = np.expand_dims(x, axis)
yy = F.add_axis(xx, axis)
yy = F.expand_dims(xx, axis)
np.testing.assert_equal(y, yy.numpy())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册