Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a4d473c9
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a4d473c9
编写于
3月 03, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(mge/functional): speed up AddAxis
GitOrigin-RevId: 92a3e1bdd3c4f0d1d68d8571cad78c1c8ea0f634
上级
3e206d89
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
74 addition
and
22 deletion
+74
-22
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+8
-22
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+2
-0
imperative/python/src/tensor_utils.cpp
imperative/python/src/tensor_utils.cpp
+62
-0
imperative/python/src/tensor_utils.h
imperative/python/src/tensor_utils.h
+2
-0
未找到文件。
imperative/python/megengine/functional/tensor.py
浏览文件 @
a4d473c9
...
...
@@ -12,7 +12,13 @@ from typing import Iterable, Optional, Sequence, Tuple, Union
import
numpy
as
np
from
..core._imperative_rt
import
CompNode
from
..core._imperative_rt.core2
import
SymbolVar
,
apply
,
dtype_promotion
,
split_cpp
from
..core._imperative_rt.core2
import
(
SymbolVar
,
apply
,
dtype_promotion
,
expand_dims_cpp
,
split_cpp
,
)
from
..core._wrap
import
as_device
from
..core.ops
import
builtin
from
..core.ops.builtin
import
Copy
,
Identity
...
...
@@ -959,27 +965,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
(1, 2)
"""
def
get_axes
():
try
:
return
[
int
(
axis
)]
except
(
TypeError
,
ValueError
):
pass
return
list
(
map
(
int
,
axis
))
axis
=
get_axes
()
try
:
ndim
=
inp
.
ndim
+
len
(
axis
)
axis
=
sorted
(
i
+
ndim
if
i
<
0
else
i
for
i
in
axis
)
except
ValueError
:
if
any
([
ind
<
0
for
ind
in
axis
]):
raise
IndexError
(
"Does not support negative index when tensor's ndim is unknown"
)
axis
=
sorted
(
axis
)
assert
axis
,
"axis could not be empty"
op
=
builtin
.
AddAxis
(
axis
=
axis
)
(
result
,)
=
apply
(
op
,
inp
)
return
result
return
expand_dims_cpp
(
inp
,
axis
)
def
squeeze
(
inp
:
Tensor
,
axis
:
Optional
[
Union
[
int
,
Sequence
[
int
]]]
=
None
)
->
Tensor
:
...
...
imperative/python/src/tensor.cpp
浏览文件 @
a4d473c9
...
...
@@ -634,6 +634,7 @@ WRAP_FUNC_PY35(make_shape_tuple);
WRAP_FUNC_PY35
(
getitem_cpp
);
WRAP_FUNC_PY35
(
setitem_cpp
);
WRAP_FUNC_PY35
(
split_cpp
);
WRAP_FUNC_PY35
(
expand_dims_cpp
);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
...
...
@@ -767,6 +768,7 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE
(
getitem_cpp
,
getitem_cpp
),
MGE_PY_INTERFACE
(
setitem_cpp
,
setitem_cpp
),
MGE_PY_INTERFACE
(
split_cpp
,
split_cpp
),
MGE_PY_INTERFACE
(
expand_dims_cpp
,
expand_dims_cpp
),
{
nullptr
,
nullptr
,
0
,
nullptr
}};
for
(
auto
&&
def
:
method_defs
)
{
if
(
def
.
ml_meth
!=
nullptr
)
{
...
...
imperative/python/src/tensor_utils.cpp
浏览文件 @
a4d473c9
...
...
@@ -683,6 +683,59 @@ py::object _split_cpp(
return
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
.
data
(),
p
.
size
()));
}
// Inserts new axes into a tensor by building an AddAxis op and running it
// through py_apply.
//
// inp_hdl:  the input object. It is first tried as a TensorWrapper; otherwise
//           it is cast to PySymbolVar* (the cast throws if it is neither).
// axis_hdl: either a Python sequence of integer-like objects or a single
//           integer-like object. Every element is converted by calling its
//           __int__ method.
//
// Returns element 0 of the tuple produced by py_apply.
//
// Throws py::index_error when a negative axis is given but the input's rank
// cannot be determined, or when the axis list is empty. A Python error raised
// while listing the sequence or while applying the op is re-thrown as
// py::error_already_set.
py::object _expand_dims_cpp(py::handle inp_hdl, py::handle axis_hdl) {
    // Collect the requested axes as 32-bit integers.
    std::vector<int32_t> axis;
    if (is_py_sequence(axis_hdl.ptr())) {
        py::list axis_list =
                py::reinterpret_steal<py::list>(PySequence_List(axis_hdl.ptr()));
        if (!axis_list) {
            // PySequence_List returns NULL with a Python error set on failure.
            throw py::error_already_set();
        }
        axis.reserve(axis_list.size());
        for (size_t i = 0; i < axis_list.size(); ++i) {
            axis.push_back(axis_list[i].attr("__int__")().cast<int32_t>());
        }
    } else {
        axis.push_back(getattr(axis_hdl, "__int__")().cast<int32_t>());
    }

    // Negative axes are resolved against (input rank + number of new axes),
    // so the running total starts from the number of axes being added.
    bool unknown_ndim = true;
    size_t ndim = axis.size();
    if (auto tensor = TensorWrapper::try_cast(inp_hdl.ptr())) {
        auto&& shape = tensor->m_tensor->shape();
        if (shape) {
            unknown_ndim = false;
            ndim += shape->ndim;
        }
    } else {
        auto&& var = inp_hdl.cast<PySymbolVar*>();
        auto&& mgr = var->m_node->owner_graph()->static_infer_manager();
        // Fallible inference: the result is empty when the shape is not
        // available, in which case the rank stays unknown.
        auto&& shape = mgr.infer_shape_fallible(var->m_node);
        if (shape) {
            unknown_ndim = false;
            ndim += shape->ndim;
        }
    }

    // Turn negative axes into their non-negative equivalents.
    for (auto& ax : axis) {
        if (ax < 0) {
            if (unknown_ndim) {
                throw py::index_error(
                        "Does not support negative index when tensor's ndim is "
                        "unknown");
            }
            // Explicit cast keeps the arithmetic signed instead of silently
            // promoting the axis to size_t and narrowing it back.
            ax += static_cast<int32_t>(ndim);
        }
    }
    if (axis.empty()) {
        throw py::index_error("axis could not be empty");
    }
    std::sort(axis.begin(), axis.end());

    // Build the op and invoke it with the argument layout (op, input).
    std::shared_ptr<OpDef> op = AddAxis::make(axis);
    py::object op_obj = py::cast(op);
    std::vector<PyObject*> call_args{op_obj.ptr(), inp_hdl.ptr()};
    py::object applied = py::reinterpret_steal<py::object>(
            py_apply(nullptr, call_args.data(), call_args.size()));
    if (!applied) {
        // NOTE(review): presumably py_apply reports failure by returning NULL
        // with the Python error indicator set -- confirm against py_apply.
        throw py::error_already_set();
    }
    py::tuple ret = applied;
    return ret[0];
}
PyObject
*
make_shape_tuple
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
)
{
try
{
return
_make_shape_tuple
(
py
::
handle
(
args
[
0
])).
release
().
ptr
();
...
...
@@ -716,4 +769,13 @@ PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET
(
nullptr
)
}
// CPython-level entry point for expand_dims: forwards args[0] (the input) and
// args[1] (the axis or axes) to _expand_dims_cpp and hands the resulting
// object back as a raw PyObject* whose reference is released to the caller.
//
// self:  unused.
// args:  array of positional arguments; at least two are required.
// nargs: number of entries in args.
//
// C++ exceptions raised inside the try block are handled by
// PYEXT17_TRANSLATE_EXC_RET(nullptr).
PyObject* expand_dims_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        // Guard the fixed-index reads below; without it a call with fewer
        // than two arguments would read past the end of args.
        if (nargs < 2) {
            throw py::type_error(
                    "expand_dims_cpp requires 2 arguments: (inp, axis)");
        }
        return _expand_dims_cpp(py::handle(args[0]), py::handle(args[1]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
}
// namespace mgb::imperative::python
imperative/python/src/tensor_utils.h
浏览文件 @
a4d473c9
...
...
@@ -10,4 +10,6 @@ PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject
*
split_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
// Python-callable entry point for expand_dims.
// args[0] is the input and args[1] is the axis (or sequence of axes);
// nargs is the number of entries in args.
// Returns a PyObject* for the result.
// NOTE(review): presumably returns nullptr with a Python error set on
// failure -- confirm against PYEXT17_TRANSLATE_EXC_RET.
PyObject
*
expand_dims_cpp
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
);
}
// namespace mgb::imperative::python
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录