Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
bc9aa47a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
bc9aa47a
编写于
2月 28, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/indexing): support newaxis
GitOrigin-RevId: 8338c4b47542671f07cee9d68bbe8c35c25c5f16
上级
9779bc7f
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
78 addition
and
60 deletion
+78
-60
imperative/python/src/tensor_utils.cpp
imperative/python/src/tensor_utils.cpp
+53
-57
imperative/python/test/unit/core/test_indexing_op.py
imperative/python/test/unit/core/test_indexing_op.py
+25
-3
未找到文件。
imperative/python/src/tensor_utils.cpp
浏览文件 @
bc9aa47a
...
...
@@ -459,12 +459,8 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
if
(
!
dtype_equal
(
cur
,
descr
))
{
std
::
shared_ptr
<
OpDef
>
op
=
TypeCvt
::
make
(
npy
::
dtype_np2mgb_descr
(
descr
));
py
::
object
Op
=
py
::
cast
(
op
);
std
::
vector
<
PyObject
*>
p
;
p
.
resize
(
2
);
p
[
0
]
=
Op
.
ptr
();
p
[
1
]
=
tensor
.
ptr
();
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
PyObject
*
p
[
2
]
=
{
Op
.
ptr
(),
tensor
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
2
));
return
ret
[
0
];
}
else
{
return
py
::
reinterpret_borrow
<
py
::
object
>
(
tensor
);
...
...
@@ -514,7 +510,7 @@ py::object _convert_inputs_cpp(
}
}
auto
convert
=
[
&
](
py
::
object
value
)
{
if
(
value
.
ptr
()
==
Py_None
)
{
if
(
value
.
is_none
()
)
{
return
value
;
}
return
_convert_single_value_cpp
(
value
,
dtype
,
device
);
...
...
@@ -545,12 +541,9 @@ py::object _astensor1d_cpp(
if
(
device
.
ptr
()
!=
Py_None
)
{
std
::
shared_ptr
<
OpDef
>
op
=
Copy
::
make
(
device_obj
.
cast
<
CompNode
>
());
py
::
object
Op
=
py
::
cast
(
op
);
std
::
vector
<
PyObject
*>
p
;
p
.
resize
(
2
);
p
[
0
]
=
Op
.
ptr
();
p
[
1
]
=
ret
.
ptr
();
py
::
tuple
copy_ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
PyObject
*
p
[
2
]
=
{
Op
.
ptr
(),
ret
.
ptr
()};
py
::
tuple
copy_ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
2
));
return
copy_ret
[
0
];
}
return
ret
;
...
...
@@ -590,7 +583,7 @@ py::object _astensor1d_cpp(
c_args
[
lis
.
size
()]
=
Py_None
;
py
::
tuple
inp_tup
=
py
::
reinterpret_steal
<
py
::
tuple
>
(
convert_inputs_cpp
(
NULL
,
c_args
.
data
(),
c_args
.
size
()));
if
(
device_obj
.
ptr
()
==
Py_None
)
{
if
(
device_obj
.
is_none
()
)
{
std
::
vector
<
PyObject
*>
inp
(
inp_tup
.
size
());
for
(
size_t
i
=
0
;
i
<
inp_tup
.
size
();
++
i
)
{
inp
[
i
]
=
inp_tup
[
i
].
ptr
();
...
...
@@ -637,15 +630,10 @@ py::object _get_index(py::object tensor, py::object src) {
return
tensor
;
}
}
static
std
::
shared_ptr
<
OpDef
>
op
=
CondTake
::
make
();
std
::
vector
<
PyObject
*>
p
;
p
.
resize
(
3
);
std
::
shared_ptr
<
OpDef
>
op
=
CondTake
::
make
();
py
::
object
Op
=
py
::
cast
(
op
);
p
[
0
]
=
Op
.
ptr
();
p
[
1
]
=
tensor
.
ptr
();
p
[
2
]
=
tensor
.
ptr
();
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
PyObject
*
p
[
3
]
=
{
Op
.
ptr
(),
tensor
.
ptr
(),
tensor
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
3
));
return
ret
[
1
];
}
...
...
@@ -666,15 +654,10 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) {
}
else
{
iobj
=
py
::
reinterpret_borrow
<
py
::
object
>
(
index
);
}
static
std
::
shared_ptr
<
OpDef
>
op
=
CondTake
::
make
();
std
::
vector
<
PyObject
*>
p
;
p
.
resize
(
3
);
std
::
shared_ptr
<
OpDef
>
op
=
CondTake
::
make
();
py
::
object
Op
=
py
::
cast
(
op
);
p
[
0
]
=
Op
.
ptr
();
p
[
1
]
=
tensor
.
ptr
();
p
[
2
]
=
iobj
.
ptr
();
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
PyObject
*
p
[
3
]
=
{
Op
.
ptr
(),
tensor
.
ptr
(),
iobj
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
3
));
return
ret
;
}
...
...
@@ -685,7 +668,9 @@ py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
bool
has_unknown_ndim_bool_index
=
false
;
for
(
size_t
i
=
0
;
i
<
tuple_size
;
++
i
)
{
py
::
object
handle
=
tuple_val
[
i
];
if
(
handle
.
ptr
()
==
Py_Ellipsis
)
{
if
(
handle
.
is_none
())
{
continue
;
}
else
if
(
handle
.
ptr
()
==
Py_Ellipsis
)
{
pos
=
static_cast
<
int
>
(
i
);
for
(
size_t
j
=
0
;
j
<
i
;
++
j
)
{
py
::
object
t
=
tuple_val
[
j
];
...
...
@@ -749,8 +734,14 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
size_t
offset
=
0
;
size_t
tdim
=
0
;
size_t
nonedim
=
0
;
for
(
size_t
i
=
0
;
i
<
tuple_val
.
size
();
++
i
)
{
py
::
handle
k
=
tuple_val
[
i
];
if
(
k
.
ptr
()
==
Py_None
)
{
nonedim
++
;
new_tuple_val
.
append
(
k
);
continue
;
}
if
(
is_bool_dtype
(
k
.
ptr
()))
{
size_t
ndim
=
getattr
(
k
,
"ndim"
).
cast
<
size_t
>
();
if
(
ndim
>
1
)
{
...
...
@@ -777,7 +768,7 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
Py_XDECREF
(
sym
);
if
(
is_sym
)
{
py
::
object
tshape
=
getattr
(
tensor
,
"shape"
);
for
(
size_t
j
=
0
;
j
<
i
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
i
-
nonedim
;
++
j
)
{
new_shape
.
append
(
tshape
[
py
::
int_
(
j
)]);
}
new_shape
.
append
(
kshape
[
py
::
int_
(
0
)]);
...
...
@@ -789,7 +780,7 @@ py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
tensor
=
_reshape_cpp
(
tensor
,
shape_tensor
);
cur_shape
=
_make_shape_tuple
(
shape_tensor
);
}
else
{
for
(
size_t
j
=
0
;
j
<
i
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
i
-
nonedim
;
++
j
)
{
new_shape
.
append
(
cur_shape
[
j
]);
}
new_shape
.
append
(
py
::
reinterpret_borrow
<
py
::
tuple
>
(
kshape
)[
0
]);
...
...
@@ -838,8 +829,8 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
size_t
idx_ndim
=
0
;
for
(
size_t
i
=
0
;
i
<
tuple_val
.
size
();
++
i
)
{
py
::
object
k
=
tuple_val
[
i
];
if
(
k
.
ptr
()
==
Py_None
)
{
throw
py
::
index_error
(
"newaxis is not allowed here"
)
;
if
(
k
.
is_none
()
)
{
continue
;
}
else
if
(
k
.
ptr
()
==
Py_Ellipsis
)
{
need_remove_ellipsis
=
true
;
}
else
{
...
...
@@ -878,6 +869,20 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
}
}
std
::
vector
<
int32_t
>
axis
;
for
(
size_t
i
=
0
;
i
<
tuple_val
.
size
();
++
i
)
{
if
(
tuple_val
[
i
].
is_none
())
{
axis
.
push_back
(
i
);
}
}
if
(
axis
.
size
())
{
std
::
shared_ptr
<
OpDef
>
op
=
AddAxis
::
make
(
axis
);
py
::
object
Op
=
py
::
cast
(
op
);
PyObject
*
p
[
2
]
=
{
Op
.
ptr
(),
inp
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
2
));
inp
=
ret
[
0
];
}
py
::
list
items
;
py
::
list
tensors
;
int
cur_axis
=
-
1
;
...
...
@@ -885,6 +890,9 @@ py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
for
(
size_t
i
=
0
;
i
<
tuple_val
.
size
();
++
i
)
{
py
::
object
handle
=
tuple_val
[
i
];
cur_axis
++
;
if
(
handle
.
is_none
())
{
continue
;
}
if
(
!
is_scalar
(
handle
.
ptr
())
&&
!
PySlice_Check
(
handle
.
ptr
()))
{
use_subtensor
=
false
;
}
...
...
@@ -970,11 +978,11 @@ py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
if
(
PyList_Check
(
shape_hdl
.
ptr
())
||
PyTuple_Check
(
shape_hdl
.
ptr
()))
{
lis
=
py
::
reinterpret_steal
<
py
::
list
>
(
PySequence_List
(
shape_hdl
.
ptr
()));
for
(
size_t
i
=
0
;
i
<
lis
.
size
();
++
i
)
{
if
(
lis
[
i
].
ptr
()
==
Py_None
)
{
if
(
lis
[
i
].
is_none
()
)
{
auto_infer
=
true
;
size_t
right
=
lis
.
size
()
-
i
;
py
::
object
tshp
=
getattr
(
inp_hdl
,
"_tuple_shape"
);
if
(
tshp
.
ptr
()
==
Py_None
)
{
if
(
tshp
.
is_none
()
)
{
throw
py
::
index_error
(
"does not support `None` with unknown shape"
);
}
py
::
tuple
inp_shape
=
py
::
reinterpret_borrow
<
py
::
tuple
>
(
tshp
);
...
...
@@ -1116,7 +1124,7 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
{
item
[
0
].
cast
<
int8_t
>
(),
item
[
1
].
cast
<
bool
>
(),
item
[
2
].
cast
<
bool
>
(),
item
[
3
].
cast
<
bool
>
(),
item
[
4
].
cast
<
bool
>
()});
}
st
atic
st
d
::
shared_ptr
<
OpDef
>
op
;
std
::
shared_ptr
<
OpDef
>
op
;
if
(
up
[
3
].
cast
<
bool
>
())
{
op
=
Subtensor
::
make
(
cpp_items
);
}
else
{
...
...
@@ -1155,7 +1163,7 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
{
item
[
0
].
cast
<
int8_t
>
(),
item
[
1
].
cast
<
bool
>
(),
item
[
2
].
cast
<
bool
>
(),
item
[
3
].
cast
<
bool
>
(),
item
[
4
].
cast
<
bool
>
()});
}
st
atic
st
d
::
shared_ptr
<
OpDef
>
op
,
set_op
;
std
::
shared_ptr
<
OpDef
>
op
,
set_op
;
if
(
up
[
3
].
cast
<
bool
>
())
{
op
=
Subtensor
::
make
(
cpp_items
);
}
else
{
...
...
@@ -1340,13 +1348,9 @@ py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
}
std
::
sort
(
axis
.
begin
(),
axis
.
end
());
std
::
shared_ptr
<
OpDef
>
op
=
AddAxis
::
make
(
axis
=
axis
);
std
::
vector
<
PyObject
*>
p
;
p
.
resize
(
2
);
py
::
object
Op
=
py
::
cast
(
op
);
p
[
0
]
=
Op
.
ptr
();
p
[
1
]
=
inp_hdl
.
ptr
();
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
PyObject
*
p
[
2
]
=
{
Op
.
ptr
(),
inp_hdl
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
2
));
return
ret
[
0
];
}
...
...
@@ -1390,13 +1394,9 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) {
axis
[
i
]
-=
static_cast
<
int32_t
>
(
i
);
}
std
::
shared_ptr
<
OpDef
>
op
=
RemoveAxis
::
make
(
axis
=
axis
);
std
::
vector
<
PyObject
*>
p
;
p
.
resize
(
2
);
py
::
object
Op
=
py
::
cast
(
op
);
p
[
0
]
=
Op
.
ptr
();
p
[
1
]
=
inp_hdl
.
ptr
();
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
PyObject
*
p
[
2
]
=
{
Op
.
ptr
(),
inp_hdl
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
2
));
return
ret
[
0
];
}
py
::
object
_transpose_cpp
(
py
::
handle
inp_hdl
,
py
::
handle
args
)
{
...
...
@@ -1437,13 +1437,9 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
}
}
std
::
shared_ptr
<
OpDef
>
op
=
Dimshuffle
::
make
(
pattern
);
std
::
vector
<
PyObject
*>
p
;
p
.
resize
(
2
);
py
::
object
Op
=
py
::
cast
(
op
);
p
[
0
]
=
Op
.
ptr
();
p
[
1
]
=
inp_hdl
.
ptr
();
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
PyObject
*
p
[
2
]
=
{
Op
.
ptr
(),
inp_hdl
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
2
));
return
ret
[
0
];
}
...
...
imperative/python/test/unit/core/test_indexing_op.py
浏览文件 @
bc9aa47a
...
...
@@ -436,6 +436,8 @@ def test_advance_indexing_high_level(test_varnode):
x
=
np
.
arange
(
27
).
reshape
(
3
,
3
,
3
).
astype
(
"int32"
)
xx
=
make_tensor
(
x
,
network
)
y
=
np
.
array
([
0
,
2
],
dtype
=
np
.
int32
)
z
=
np
.
array
([[
0
,
1
],
[
1
,
2
]],
dtype
=
np
.
int32
)
np
.
testing
.
assert_equal
(
x
[
1
,
:,
:],
get_value
(
xx
[
1
,
:,
:]))
np
.
testing
.
assert_equal
(
x
[
1
,
:,
1
],
get_value
(
xx
[
1
,
:,
1
]))
...
...
@@ -444,6 +446,21 @@ def test_advance_indexing_high_level(test_varnode):
np
.
testing
.
assert_equal
(
x
[:,
1
,
1
],
get_value
(
xx
[:,
1
,
1
]))
np
.
testing
.
assert_equal
(
x
[:,
1
],
get_value
(
xx
[:,
1
]))
np
.
testing
.
assert_equal
(
x
[
1
,
1
:
2
],
get_value
(
xx
[
1
,
1
:
2
]))
np
.
testing
.
assert_equal
(
x
[:
2
,
y
,
[
0
,
1
]],
get_value
(
xx
[:
2
,
y
,
[
0
,
1
]]))
np
.
testing
.
assert_equal
(
x
[
None
,
None
],
get_value
(
xx
[
None
,
None
]))
np
.
testing
.
assert_equal
(
x
[:,
None
,
...],
get_value
(
xx
[:,
None
,
...]))
np
.
testing
.
assert_equal
(
x
[
1
,
None
,
:,
1
],
get_value
(
xx
[
1
,
None
,
:,
1
]))
np
.
testing
.
assert_equal
(
x
[:,
None
,
1
,
None
],
get_value
(
xx
[:,
None
,
1
,
None
]))
np
.
testing
.
assert_equal
(
x
[:
2
,
y
,
None
,
[
0
,
1
]],
get_value
(
xx
[:
2
,
y
,
None
,
[
0
,
1
]]))
np
.
testing
.
assert_equal
(
x
[
None
,
:,
None
,
[
0
,
2
],
None
,
[
1
,
2
]],
get_value
(
xx
[
None
,
:,
None
,
[
0
,
2
],
None
,
[
1
,
2
]]),
)
np
.
testing
.
assert_equal
(
x
[
z
],
get_value
(
xx
[
z
]))
np
.
testing
.
assert_equal
(
x
[
z
,
None
],
get_value
(
xx
[
z
,
None
]))
np
.
testing
.
assert_equal
(
x
[
None
,
z
],
get_value
(
xx
[
None
,
z
]))
np
.
testing
.
assert_equal
(
x
[
z
,
None
,
z
],
get_value
(
xx
[
z
,
None
,
z
]))
np
.
testing
.
assert_equal
(
x
[
None
,
z
,
None
],
get_value
(
xx
[
None
,
z
,
None
]))
x_
=
x
.
copy
()
x_
[
1
,
1
,
1
]
=
-
1
...
...
@@ -592,16 +609,24 @@ def test_advance_indexing_with_bool(test_varnode):
b
=
(
np
.
random
.
sample
((
2
,
3
,
4
))
>
0.5
).
astype
(
"bool"
)
bb
=
make_tensor
(
b
,
network
)
np
.
testing
.
assert_equal
(
a
[
b
,
:,
0
:
4
:
2
],
get_value
(
aa
[
bb
,
:,
0
:
4
:
2
]))
np
.
testing
.
assert_equal
(
a
[
None
,
b
,
:,
0
:
4
:
2
],
get_value
(
aa
[
None
,
bb
,
:,
0
:
4
:
2
]))
b
=
(
np
.
random
.
sample
((
4
,
3
,
4
))
>
0.5
).
astype
(
"bool"
)
bb
=
make_tensor
(
b
,
network
)
np
.
testing
.
assert_equal
(
a
[...,
b
,
0
:
2
],
get_value
(
aa
[...,
bb
,
0
:
2
]))
np
.
testing
.
assert_equal
(
a
[
None
,
...,
b
,
None
,
0
:
2
],
get_value
(
aa
[
None
,
...,
bb
,
None
,
0
:
2
])
)
b
=
(
np
.
random
.
sample
((
3
,
4
,
3
))
>
0.5
).
astype
(
"bool"
)
bb
=
make_tensor
(
b
,
network
)
np
.
testing
.
assert_equal
(
a
[:,
b
,
0
:
2
,
[
True
,
False
]],
get_value
(
aa
[:,
bb
,
0
:
2
,
[
True
,
False
]])
)
np
.
testing
.
assert_equal
(
a
[:,
b
,
None
,
0
:
2
,
[
True
,
False
]],
get_value
(
aa
[:,
bb
,
None
,
0
:
2
,
[
True
,
False
]]),
)
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
True
,
False
,
None
])
...
...
@@ -781,9 +806,6 @@ def test_indexing_error(test_varnode):
aa
=
make_tensor
(
a
,
network
)
bb
=
make_tensor
(
b
,
network
)
with
pytest
.
raises
(
IndexError
):
aa
[
None
]
# newaxis is not allowed
with
pytest
.
raises
(
IndexError
):
aa
[...,
...]
# only one ellipsis is allowed
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录