Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5431929e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5431929e
编写于
8月 18, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(functional): let advance indexing support empty tensor and add more tests
GitOrigin-RevId: 49e1492934813caf4e491a901610b95439bac236
上级
703b783c
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
59 addition
and
15 deletion
+59
-15
imperative/python/megengine/core/tensor/indexing.py
imperative/python/megengine/core/tensor/indexing.py
+2
-11
imperative/python/test/unit/core/test_indexing_op.py
imperative/python/test/unit/core/test_indexing_op.py
+55
-0
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+2
-2
src/opr/impl/indexing.cpp
src/opr/impl/indexing.cpp
+0
-2
未找到文件。
imperative/python/megengine/core/tensor/indexing.py
浏览文件 @
5431929e
...
@@ -176,6 +176,8 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
...
@@ -176,6 +176,8 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
def
is_bool_list
(
x
):
def
is_bool_list
(
x
):
if
not
isinstance
(
x
,
list
):
if
not
isinstance
(
x
,
list
):
return
False
return
False
if
len
(
x
)
==
0
:
return
False
for
i
in
x
:
for
i
in
x
:
if
not
isinstance
(
i
,
bool
):
if
not
isinstance
(
i
,
bool
):
return
False
return
False
...
@@ -246,17 +248,6 @@ def getitem(tensor, index):
...
@@ -246,17 +248,6 @@ def getitem(tensor, index):
if
len
(
try_result
)
==
2
:
if
len
(
try_result
)
==
2
:
return
try_result
[
0
]
return
try_result
[
0
]
tensor
,
tensors
,
items
,
use_subtensor
,
ret_scalar
=
unpack_getitem
(
tensor
,
index
)
tensor
,
tensors
,
items
,
use_subtensor
,
ret_scalar
=
unpack_getitem
(
tensor
,
index
)
for
v
in
tensors
:
if
v
.
shape
is
None
:
break
if
isinstance
(
v
.
shape
,
v
.
__class__
):
break
if
len
(
v
.
shape
)
>
0
and
v
.
shape
[
0
]
==
0
:
(
empty_tensor
,)
=
Const
([],
dtype
=
tensor
.
dtype
,
device
=
tensor
.
device
)(
tensor
)
return
empty_tensor
if
use_subtensor
:
if
use_subtensor
:
op
=
builtin
.
Subtensor
(
items
=
items
)
op
=
builtin
.
Subtensor
(
items
=
items
)
else
:
else
:
...
...
imperative/python/test/unit/core/test_indexing_op.py
浏览文件 @
5431929e
...
@@ -610,6 +610,25 @@ def test_subtensor_on_empty_tensor(symbolic):
...
@@ -610,6 +610,25 @@ def test_subtensor_on_empty_tensor(symbolic):
run_test
(
lambda
x
:
x
[
100
:
200
,
300
:
400
,
500
:
600
])
run_test
(
lambda
x
:
x
[
100
:
200
,
300
:
400
,
500
:
600
])
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
True
,
False
,
None
])
def
test_indexingMultiAxisVec_on_empty_tensor
(
symbolic
):
np_x
=
np
.
array
([],
dtype
=
np
.
float32
).
reshape
(
10
,
10
,
0
)
mge_x
=
megengine
.
tensor
(
np_x
)
def
run_test
(
fn
):
out_ref
=
fn
(
np_x
)
if
symbolic
is
not
None
:
fn
=
jit
.
trace
(
symbolic
=
symbolic
)(
fn
)
for
i
in
range
(
3
):
out
=
fn
(
mge_x
)
np
.
testing
.
assert_equal
(
out
.
numpy
(),
out_ref
)
run_test
(
lambda
x
:
x
[[
1
,
2
,
3
]])
run_test
(
lambda
x
:
x
[[
1
,
2
,
3
],
[
4
,
5
,
6
]])
run_test
(
lambda
x
:
x
[[]])
run_test
(
lambda
x
:
x
[[],
[],
[]])
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
True
,
False
,
None
])
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
True
,
False
,
None
])
def
test_setsubtensor_on_empty_tensor
(
symbolic
):
def
test_setsubtensor_on_empty_tensor
(
symbolic
):
def
run_test
(
inp_shp
,
fn
):
def
run_test
(
inp_shp
,
fn
):
...
@@ -655,3 +674,39 @@ def test_setsubtensor_on_empty_tensor(symbolic):
...
@@ -655,3 +674,39 @@ def test_setsubtensor_on_empty_tensor(symbolic):
run_test
((
10
,
10
,
10
),
test4
)
run_test
((
10
,
10
,
10
),
test4
)
run_test
((
10
,
10
,
10
),
test5
)
run_test
((
10
,
10
,
10
),
test5
)
run_test
((
10
,
10
,
10
),
test6
)
run_test
((
10
,
10
,
10
),
test6
)
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
True
,
False
,
None
])
def
test_indexingSetMultiAxisVec_on_empty_tensor
(
symbolic
):
def
run_test
(
inp_shp
,
fn
):
np_x
=
np
.
random
.
randn
(
*
inp_shp
).
astype
(
np
.
float32
)
mge_x
=
megengine
.
tensor
(
np_x
)
out_ref
=
fn
(
np_x
)
if
symbolic
is
not
None
:
fn
=
jit
.
trace
(
symbolic
=
symbolic
)(
fn
)
for
i
in
range
(
3
):
out
=
fn
(
mge_x
)
np
.
testing
.
assert_equal
(
out
.
numpy
(),
out_ref
)
def
test1
(
x
):
x
[[
1
,
2
,
3
]]
=
x
[[
1
,
2
,
3
]]
return
x
def
test2
(
x
):
x
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
=
x
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
return
x
def
test3
(
x
):
x
[[]]
=
x
[[]]
return
x
def
test4
(
x
):
x
[[],
[],
[]]
=
x
[[],
[],
[]]
return
x
run_test
((
10
,
10
,
0
),
test1
)
run_test
((
10
,
10
,
0
),
test2
)
run_test
((
10
,
10
,
0
),
test3
)
run_test
((
10
,
10
,
0
),
test4
)
run_test
((
10
,
10
,
10
),
test3
)
run_test
((
10
,
10
,
10
),
test4
)
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
5431929e
...
@@ -860,8 +860,8 @@ def test_condtake():
...
@@ -860,8 +860,8 @@ def test_condtake():
np
.
testing
.
assert_equal
(
idx
.
numpy
(),
np
.
where
(
y
.
reshape
(
-
1
))[
0
])
np
.
testing
.
assert_equal
(
idx
.
numpy
(),
np
.
where
(
y
.
reshape
(
-
1
))[
0
])
#
@pytest.mark.parametrize("is_symbolic", [None, False, True])
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
,
False
,
True
])
def
test_condtake
(
is_symbolic
=
None
):
def
test_condtake
(
is_symbolic
):
shapes
=
[
shapes
=
[
(
3
,
3
,
3
),
(
3
,
3
,
3
),
(
0
,),
(
0
,),
...
...
src/opr/impl/indexing.cpp
浏览文件 @
5431929e
...
@@ -292,8 +292,6 @@ cg::OperatorNodeBase::NodeProp*
...
@@ -292,8 +292,6 @@ cg::OperatorNodeBase::NodeProp*
IndexingMultiAxisVecBase
<
Opr
>::
do_make_node_prop
()
const
{
IndexingMultiAxisVecBase
<
Opr
>::
do_make_node_prop
()
const
{
auto
prop
=
Super
::
do_make_node_prop
();
auto
prop
=
Super
::
do_make_node_prop
();
using
DT
=
NodeProp
::
DepType
;
using
DT
=
NodeProp
::
DepType
;
// TODO: should also allow input shape is empty if any
// indexer's shape is empty
prop
->
add_dep_type_existing_var
(
input
(
0
),
DT
::
VALUE_ALLOW_EMPTY
);
prop
->
add_dep_type_existing_var
(
input
(
0
),
DT
::
VALUE_ALLOW_EMPTY
);
for
(
auto
i
:
m_input2idxonly_axis_indexer
)
{
for
(
auto
i
:
m_input2idxonly_axis_indexer
)
{
if
(
i
)
{
if
(
i
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录