提交 07bdb3bf 编写于 作者: M Megvii Engine Team

feat(imperative): add swapaxes

GitOrigin-RevId: e84014a01169cb8e2dd5c68227531537585d34ce
上级 a0862865
......@@ -48,6 +48,7 @@ __all__ = [
"tile",
"copy",
"transpose",
"swapaxes",
"where",
"zeros",
"zeros_like",
......@@ -715,6 +716,32 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
return inp.transpose(pattern)
def swapaxes(inp: Tensor, axis1: int, axis2: int) -> Tensor:
r"""Interchange two axes of a tensor.
Args:
inp: input tensor to swapaxes.
axis1: first axis.
axis2: second axis.
Returns:
a tensor after swapping the two axes of 'inp'.
Examples:
>>> x = Tensor(np.array([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=np.int32))
>>> F.swapaxes(x, 0, 2)
Tensor([[[0 4]
[2 6]]
[[1 5]
[3 7]]], dtype=int32, device=xpux:0)
"""
pattern = list(range(inp.ndim))
tempAxis = pattern[axis1]
pattern[axis1] = pattern[axis2]
pattern[axis2] = tempAxis
return inp.transpose(pattern)
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
r"""Reshapes a tensor without changing its data.
......
......@@ -214,6 +214,18 @@ def test_split(symbolic):
np.testing.assert_equal(ref_out[idx], out[idx].numpy())
@pytest.mark.parametrize("is_varnode", [True, False])
def test_swapaxes(is_varnode):
if is_varnode:
network = Network()
else:
network = None
x = tensor(np.array([[1, 2, 3]], dtype=np.int32))
y = F.swapaxes(x, 0, 1)
np.testing.assert_equal(y.numpy(), np.array([[1], [2], [3]]).astype(np.int32))
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape(is_varnode):
if is_varnode:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册