Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
07bdb3bf
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
07bdb3bf
编写于
6月 06, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): add swapaxes
GitOrigin-RevId: e84014a01169cb8e2dd5c68227531537585d34ce
上级
a0862865
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
39 addition
and
0 deletion
+39
-0
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+27
-0
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+12
-0
未找到文件。
imperative/python/megengine/functional/tensor.py
浏览文件 @
07bdb3bf
...
...
@@ -48,6 +48,7 @@ __all__ = [
"tile"
,
"copy"
,
"transpose"
,
"swapaxes"
,
"where"
,
"zeros"
,
"zeros_like"
,
...
...
@@ -715,6 +716,32 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
return
inp
.
transpose
(
pattern
)
def
swapaxes
(
inp
:
Tensor
,
axis1
:
int
,
axis2
:
int
)
->
Tensor
:
r
"""Interchange two axes of a tensor.
Args:
inp: input tensor to swapaxes.
axis1: first axis.
axis2: second axis.
Returns:
a tensor after swapping the two axes of 'inp'.
Examples:
>>> x = Tensor(np.array([[[0,1],[2,3]],[[4,5],[6,7]]], dtype=np.int32))
>>> F.swapaxes(x, 0, 2)
Tensor([[[0 4]
[2 6]]
[[1 5]
[3 7]]], dtype=int32, device=xpux:0)
"""
pattern
=
list
(
range
(
inp
.
ndim
))
tempAxis
=
pattern
[
axis1
]
pattern
[
axis1
]
=
pattern
[
axis2
]
pattern
[
axis2
]
=
tempAxis
return
inp
.
transpose
(
pattern
)
def
reshape
(
inp
:
Tensor
,
target_shape
:
Iterable
[
int
])
->
Tensor
:
r
"""Reshapes a tensor without changing its data.
...
...
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
07bdb3bf
...
...
@@ -214,6 +214,18 @@ def test_split(symbolic):
np
.
testing
.
assert_equal
(
ref_out
[
idx
],
out
[
idx
].
numpy
())
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_swapaxes
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
else
:
network
=
None
x
=
tensor
(
np
.
array
([[
1
,
2
,
3
]],
dtype
=
np
.
int32
))
y
=
F
.
swapaxes
(
x
,
0
,
1
)
np
.
testing
.
assert_equal
(
y
.
numpy
(),
np
.
array
([[
1
],
[
2
],
[
3
]]).
astype
(
np
.
int32
))
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_reshape
(
is_varnode
):
if
is_varnode
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录