Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5c7d48cd
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5c7d48cd
编写于
1月 04, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/functional): fix tensor split
GitOrigin-RevId: 0a112ab0bdaa82202c50f7f7b9fe05248b22e415
上级
a240d558
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
78 addition
and
37 deletion
+78
-37
imperative/python/megengine/functional/elemwise.py
imperative/python/megengine/functional/elemwise.py
+1
-1
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+54
-33
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+23
-3
未找到文件。
imperative/python/megengine/functional/elemwise.py
浏览文件 @
5c7d48cd
...
...
@@ -158,7 +158,7 @@ def div(x, y):
def
floor_div
(
x
,
y
):
"""Element-wise `floor(x / y)`."""
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
FLOOR_DIV
IDE
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
FLOOR_DIV
)
def
neg
(
x
):
...
...
imperative/python/megengine/functional/tensor.py
浏览文件 @
5c7d48cd
...
...
@@ -28,7 +28,7 @@ from ..core.tensor.utils import (
)
from
..device
import
get_default_device
from
..tensor
import
Tensor
from
.elemwise
import
ceil
from
.elemwise
import
ceil
,
floor_div
__all__
=
[
"arange"
,
...
...
@@ -324,52 +324,73 @@ def split(inp, nsplits_or_sections, axis=0):
.. testcode::
import os
import numpy as np
from megengine import tensor
import megengine.functional as F
x = tensor(np.random.random((2,3,4,5)), dtype=np.float32)
out = F.split(x, 2, axis=3)
print(out[0].numpy().shape, out[1].numpy().shape)
x = tensor(np.random.random((10, 20)), dtype=np.float32)
y = F.split(x, 3)
z = F.split(x, [6, 17], axis=1)
if os.environ.get("MEGENGINE_USE_SYMBOLIC_SHAPE"):
print([tuple(i.shape.numpy().tolist()) for i in y])
print([tuple(i.shape.numpy().tolist()) for i in z])
else:
print([i.shape for i in y])
print([i.shape for i in z])
Outputs:
.. testoutput::
(2, 3, 4, 3) (2, 3, 4, 2)
[(4, 20), (3, 20), (3, 20)]
[(10, 6), (10, 11), (10, 3)]
"""
sub_tensors
=
[]
sections
=
[]
def
swapaxis
(
inp
,
src
,
dst
):
if
src
==
dst
:
return
inp
shape
=
[
i
for
i
in
range
(
inp
.
ndim
)]
shape
[
src
]
=
dst
shape
[
dst
]
=
src
return
inp
.
transpose
(
shape
)
inp
=
swapaxis
(
inp
,
0
,
axis
)
if
isinstance
(
nsplits_or_sections
,
int
):
incr_step
=
ceil
(
inp
.
shape
[
0
]
/
nsplits_or_sections
)
nsplits
=
nsplits_or_sections
while
nsplits
>
0
:
nsplits
-=
1
sections
.
append
(
incr_step
.
astype
(
"int32"
))
incr_step
+=
nsplits_or_sections
else
:
sections
=
nsplits_or_sections
ndim
=
len
(
inp
.
shape
)
if
axis
>=
ndim
:
raise
ValueError
(
"Invalid axis {}"
.
format
(
axis
))
st
=
0
for
se
in
sections
:
sub_tensors
.
append
(
swapaxis
(
inp
[
st
:
se
],
axis
,
0
))
st
=
se
Ntotal
=
inp
.
shape
[
axis
]
if
st
<
inp
.
shape
[
0
]:
sub_tensors
.
append
(
swapaxis
(
inp
[
st
:],
axis
,
0
))
try
:
Nsections
=
len
(
nsplits_or_sections
)
+
1
is_array
=
True
except
TypeError
:
Nsections
=
int
(
nsplits_or_sections
)
is_array
=
False
if
is_array
:
div_points
=
[
0
]
+
list
(
nsplits_or_sections
)
+
[
Ntotal
]
for
i
in
range
(
1
,
len
(
div_points
)):
if
div_points
[
i
-
1
]
>=
div_points
[
i
]:
raise
ValueError
(
"Invalid nsplits_or_secions: {}"
.
format
(
nsplits_or_sections
)
)
else
:
# scalar
if
Nsections
<=
0
:
raise
ValueError
(
"Number sections must be larger than 0"
)
if
Nsections
>
Ntotal
:
raise
ValueError
(
"The size {} at dim {} cannot be split into {} sections"
.
format
(
Ntotal
,
axis
,
Nsections
)
)
div_points
=
[
0
]
+
[
floor_div
(
Ntotal
+
Nsections
-
i
-
1
,
Nsections
)
for
i
in
range
(
Nsections
)
]
for
i
in
range
(
2
,
Nsections
+
1
):
div_points
[
i
]
=
div_points
[
i
-
1
]
+
div_points
[
i
]
sub_tensors
=
[]
for
i
in
range
(
Nsections
):
l
=
div_points
[
i
]
r
=
div_points
[
i
+
1
]
slices
=
tuple
(
[
slice
(
None
)]
*
axis
+
[
slice
(
l
,
r
)]
+
[
slice
(
None
)]
*
(
ndim
-
axis
-
1
)
)
sub_tensors
.
append
(
inp
[
slices
])
return
sub_tensors
...
...
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
5c7d48cd
...
...
@@ -77,14 +77,34 @@ def test_stack():
def
test_split
():
data
=
np
.
random
.
random
((
2
,
3
,
4
,
5
)).
astype
(
np
.
float32
)
mge_out1
=
F
.
split
(
tensor
(
data
),
2
,
axis
=
3
)
mge_out2
=
F
.
split
(
tensor
(
data
),
[
3
,
5
],
axis
=
3
)
inp
=
tensor
(
data
)
mge_out0
=
F
.
split
(
inp
,
2
,
axis
=
3
)
mge_out1
=
F
.
split
(
inp
,
[
3
],
axis
=
3
)
np_out
=
np
.
split
(
data
,
[
3
,
5
],
axis
=
3
)
np
.
testing
.
assert_equal
(
mge_out1
[
0
].
numpy
(),
mge_out2
[
0
].
numpy
())
assert
len
(
mge_out0
)
==
2
assert
len
(
mge_out1
)
==
2
np
.
testing
.
assert_equal
(
mge_out0
[
0
].
numpy
(),
np_out
[
0
])
np
.
testing
.
assert_equal
(
mge_out1
[
0
].
numpy
(),
np_out
[
0
])
np
.
testing
.
assert_equal
(
mge_out0
[
1
].
numpy
(),
np_out
[
1
])
np
.
testing
.
assert_equal
(
mge_out1
[
1
].
numpy
(),
np_out
[
1
])
try
:
F
.
split
(
inp
,
4
)
assert
False
except
ValueError
as
e
:
pass
try
:
F
.
split
(
inp
,
[
3
,
3
,
5
],
axis
=
3
)
assert
False
except
ValueError
as
e
:
assert
str
(
e
)
==
"Invalid nsplits_or_secions: [3, 3, 5]"
def
test_reshape
():
x
=
np
.
arange
(
6
,
dtype
=
"float32"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录