Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
ab309eb5
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
ab309eb5
编写于
8月 20, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/opr): let Split support empty IO
GitOrigin-RevId: aad6dc06bfe9b95889b924e0a26f3ea33c52319a
上级
a8292704
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
74 addition
and
32 deletion
+74
-32
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+16
-24
imperative/python/test/unit/autodiff/test_grad_manger.py
imperative/python/test/unit/autodiff/test_grad_manger.py
+15
-0
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+36
-3
src/opr/impl/tensor_manip.cpp
src/opr/impl/tensor_manip.cpp
+7
-5
未找到文件。
imperative/python/megengine/functional/tensor.py
浏览文件 @
ab309eb5
...
...
@@ -20,7 +20,7 @@ from ..core.tensor.array_method import _broadcast, _remove_axis
from
..core.tensor.utils
import
astensor1d
,
convert_inputs
,
get_device
from
..device
import
get_default_device
from
..tensor
import
Tensor
from
.elemwise
import
ceil
,
floor_div
from
.elemwise
import
ceil
__all__
=
[
"arange"
,
...
...
@@ -442,10 +442,10 @@ def split(inp, nsplits_or_sections, axis=0):
Ntotal
=
inp
.
shape
[
axis
]
try
:
if
isinstance
(
nsplits_or_sections
,
Sequence
)
:
Nsections
=
len
(
nsplits_or_sections
)
+
1
is_array
=
True
e
xcept
TypeError
:
e
lse
:
Nsections
=
int
(
nsplits_or_sections
)
is_array
=
False
...
...
@@ -465,27 +465,19 @@ def split(inp, nsplits_or_sections, axis=0):
Ntotal
,
axis
,
Nsections
)
)
func
=
(
floor_div
if
isinstance
(
Nsections
,
(
SymbolVar
,
Tensor
))
else
lambda
x
,
y
:
x
//
y
)
div_points
=
[
0
]
+
[
func
(
Ntotal
+
Nsections
-
i
-
1
,
Nsections
)
for
i
in
range
(
Nsections
)
]
for
i
in
range
(
2
,
Nsections
+
1
):
div_points
[
i
]
=
div_points
[
i
-
1
]
+
div_points
[
i
]
sub_tensors
=
[]
for
i
in
range
(
Nsections
):
l
=
div_points
[
i
]
r
=
div_points
[
i
+
1
]
slices
=
tuple
(
[
slice
(
None
)]
*
axis
+
[
slice
(
l
,
r
)]
+
[
slice
(
None
)]
*
(
ndim
-
axis
-
1
)
)
sub_tensors
.
append
(
inp
[
slices
])
return
sub_tensors
partitions
=
[]
for
i
in
range
(
Nsections
):
section_size
=
(
Ntotal
+
Nsections
-
i
-
1
)
//
Nsections
partitions
.
append
(
section_size
)
partitions
=
[
part
if
isinstance
(
part
,
(
SymbolVar
,
Tensor
))
else
Const
(
part
,
dtype
=
"int32"
,
device
=
inp
.
device
)(
inp
)[
0
]
for
part
in
partitions
]
op
=
builtin
.
Split
(
axis
=
axis
)
return
apply
(
op
,
inp
,
*
partitions
)
def
_get_idx
(
index
,
axis
):
...
...
imperative/python/test/unit/autodiff/test_grad_manger.py
浏览文件 @
ab309eb5
...
...
@@ -178,6 +178,21 @@ def test_regression_1762():
gm
.
backward
(
loss
)
def
test_empty_grad_in_backward
():
x
=
mge
.
Parameter
(
F
.
full
(
100
,
0.5
))
y
=
mge
.
Parameter
(
F
.
ones
(
100
))
gm
=
GradManager
()
gm
.
attach
([
x
,
y
])
with
gm
:
z
=
F
.
where
(
x
>
0.7
,
x
,
y
)
loss
=
z
.
sum
()
gm
.
backward
(
loss
)
assert
np
.
all
(
x
.
grad
.
numpy
()
==
0
)
assert
np
.
all
(
y
.
grad
.
numpy
()
==
1
)
@
pytest
.
mark
.
require_ngpu
(
2
)
@
pytest
.
mark
.
isolated_distributed
@
pytest
.
mark
.
parametrize
(
...
...
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
ab309eb5
...
...
@@ -119,7 +119,7 @@ def test_stack(is_varnode):
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_split
(
is_varnode
):
def
test_split
_basic
(
is_varnode
):
if
is_varnode
:
network
=
Network
()
saved_symbolic_shape
=
set_symbolic_shape
(
False
)
...
...
@@ -150,15 +150,48 @@ def test_split(is_varnode):
pass
try
:
F
.
split
(
inp
,
[
3
,
3
,
5
],
axis
=
3
)
F
.
split
(
inp
,
[
3
,
2
,
5
],
axis
=
3
)
assert
False
except
ValueError
as
e
:
assert
str
(
e
)
==
"Invalid nsplits_or_secions: [3,
3
, 5]"
assert
str
(
e
)
==
"Invalid nsplits_or_secions: [3,
2
, 5]"
if
is_varnode
:
set_symbolic_shape
(
saved_symbolic_shape
)
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
None
,
False
,
True
])
def
test_split
(
symbolic
):
inp1
=
np
.
random
.
random
((
3
,
4
,
5
,
6
)).
astype
(
np
.
float32
)
inp2
=
np
.
random
.
random
((
0
,
4
,
5
,
6
)).
astype
(
np
.
float32
)
def
ref
(
inp
,
nsplits_or_sections
,
axis
):
return
np
.
split
(
inp
,
nsplits_or_sections
,
axis
)
def
func
(
inp
,
nsplits_or_sections
,
axis
):
return
F
.
split
(
inp
,
nsplits_or_sections
,
axis
)
cases
=
[
(
inp1
,
2
,
3
),
(
inp1
,
[
3
],
3
),
(
inp1
,
[
3
,
3
,
5
],
3
),
(
inp2
,
2
,
3
),
(
inp2
,
[
3
],
3
),
(
inp2
,
[
3
,
3
,
5
],
3
),
]
for
case
in
cases
:
if
symbolic
is
None
:
fn
=
func
else
:
fn
=
trace
(
symbolic
=
symbolic
)(
func
)
for
i
in
range
(
3
if
symbolic
is
not
None
else
1
):
ref_out
=
ref
(
*
case
)
out
=
fn
(
tensor
(
case
[
0
]),
case
[
1
],
case
[
2
])
assert
len
(
ref_out
)
==
len
(
out
)
for
idx
in
range
(
len
(
ref_out
)):
np
.
testing
.
assert_equal
(
ref_out
[
idx
],
out
[
idx
].
numpy
())
@
pytest
.
mark
.
parametrize
(
"is_varnode"
,
[
True
,
False
])
def
test_reshape
(
is_varnode
):
if
is_varnode
:
...
...
src/opr/impl/tensor_manip.cpp
浏览文件 @
ab309eb5
...
...
@@ -987,7 +987,8 @@ Split::Split(VarNode *inp, const Options &opt, const OperatorNodeConfig &config)
}
for
(
size_t
i
=
0
;
i
<
m_opt
.
nr_part
;
++
i
)
add_output
(
ssprintf
(
"o%zd"
,
i
))
->
dtype
(
inp
->
dtype
());
add_output
(
ssprintf
(
"o%zd"
,
i
))
->
dtype
(
inp
->
dtype
())
.
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
m_output_spec
.
resize
(
m_opt
.
nr_part
);
}
...
...
@@ -1060,10 +1061,6 @@ bool Split::infer_shape(size_t out_idx, TensorShape &dest,
size_t
size
=
0
;
for
(
size_t
i
=
0
;
i
<
m_opt
.
nr_part
;
++
i
)
{
auto
p
=
partition
[
i
];
mgb_assert
(
p
,
"got zero partition size at part %zu, tot_size=%zu"
,
i
,
ishp
.
shape
[
axis
]);
size
+=
p
;
auto
&&
cur
=
m_output_spec
[
i
].
shape
;
...
...
@@ -1126,6 +1123,7 @@ cg::OperatorNodeBase::NodeProp* Split::do_make_node_prop() const {
auto
rst
=
OperatorNodeBase
::
do_make_node_prop
();
rst
->
add_flag
(
NodeProp
::
Flag
::
CROSS_COMP_NODE_MEMORY
);
outshape_by_symvar_reset_node_dep_type
(
rst
);
rst
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
return
rst
;
}
...
...
@@ -1141,6 +1139,10 @@ void Split::do_execute(ExecEnv &env) {
auto
&&
in
=
input
(
0
)
->
dev_tensor
();
auto
&&
out
=
output
(
idx
)
->
dev_tensor
();
auto
&&
spec
=
m_output_spec
.
at
(
idx
);
if
(
out
.
layout
().
is_empty
())
{
mgb_assert
(
spec
.
subspec
.
layout
().
is_empty
());
return
;
}
owner_graph
()
->
event
().
signal_inplace
<
cg
::
event
::
BeforeKernel
>
(
this
,
out
.
comp_node
());
if
(
spec
.
mem_fwd_success
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录