Commit 3e206d89
Authored Mar 03, 2022 by Megvii Engine Team
perf(mge/functional): speed up Split
GitOrigin-RevId: 43550a0706a2794421de56067a11864c10b85c67
Parent: 730ddc2d
Showing 7 changed files with 111 additions and 54 deletions (+111 −54)
imperative/python/megengine/functional/tensor.py    +2   -44
imperative/python/src/tensor.cpp                    +2   -0
imperative/python/src/tensor_utils.cpp              +89  -0
imperative/python/src/tensor_utils.h                +2   -0
imperative/src/impl/ops/tensor_manip.cpp            +13  -8
src/core/include/megbrain/ir/ops.td                 +2   -1
src/opr/include/megbrain/opr/tensor_manip.h         +1   -1
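This commit moves the partition bookkeeping of megengine.functional.split out of Python and into a new C++ entry point, split_cpp, exported from megengine.core._imperative_rt.core2; the user-facing API does not change. As a reminder of that API (and of the two argument forms the new code distinguishes), here is a minimal usage sketch assuming a MegEngine build that ships functional.split; the tensor name and shapes are illustrative only and mirror the docstring example in tensor.py:

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.random((10, 20)), dtype="float32")

# Scalar form: split into 3 near-equal sections along axis 0.
parts = F.split(x, 3)
print([p.shape for p in parts])   # per the docstring: [(4, 20), (3, 20), (3, 20)]

# Sequence form: split at the given indices along axis 1.
parts = F.split(x, [6, 17], axis=1)
print([p.shape for p in parts])   # per the docstring: [(10, 6), (10, 11), (10, 3)]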
imperative/python/megengine/functional/tensor.py
@@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Tuple, Union
 import numpy as np
 
 from ..core._imperative_rt import CompNode
-from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion
+from ..core._imperative_rt.core2 import SymbolVar, apply, dtype_promotion, split_cpp
 from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
@@ -477,50 +477,8 @@ def split(inp, nsplits_or_sections, axis=0):
         [(4, 20), (3, 20), (3, 20)]
         [(10, 6), (10, 11), (10, 3)]
     """
-    ndim = len(inp.shape)
-    if axis >= ndim:
-        raise ValueError("Invalid axis {}".format(axis))
-    Ntotal = inp.shape[axis]
-
-    if isinstance(nsplits_or_sections, Sequence):
-        Nsections = len(nsplits_or_sections) + 1
-        is_array = True
-    else:
-        Nsections = int(nsplits_or_sections)
-        is_array = False
-
-    if is_array:
-        partitions = []
-        div_points = [0] + list(nsplits_or_sections) + [Ntotal]
-        for i in range(1, len(div_points)):
-            if div_points[i - 1] > div_points[i]:
-                raise ValueError(
-                    "Invalid nsplits_or_secions: {}".format(nsplits_or_sections)
-                )
-            partitions.append(div_points[i] - div_points[i - 1])
-    else:  # scalar
-        if Nsections <= 0:
-            raise ValueError("Number sections must be larger than 0")
-        if Nsections > Ntotal:
-            raise ValueError(
-                "The size {} at dim {} cannot be split into {} sections".format(
-                    Ntotal, axis, Nsections
-                )
-            )
-        partitions = []
-        for i in range(Nsections):
-            section_size = (Ntotal + Nsections - i - 1) // Nsections
-            partitions.append(section_size)
-
-    partitions = [
-        part
-        if isinstance(part, (SymbolVar, Tensor))
-        else Const(part, dtype="int32", device=inp.device)(inp)[0]
-        for part in partitions
-    ]
-    op = builtin.Split(axis=axis)
-    return apply(op, inp, *partitions)
+    return split_cpp(inp, nsplits_or_sections, axis)
 
 
 def _get_idx(index, axis):
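For reference, the section sizes that the removed Python branch computed, and that the new C++ path (_split_cpp, plus opr::Split::Options::make_average for the scalar case) reproduces, follow two rules: an integer argument yields near-equal sections, and a sequence of split points yields the differences of consecutive boundaries. Below is a standalone sketch of that arithmetic, kept purely to document the semantics; the helper name _split_partitions is ours, not part of MegEngine:

from typing import Sequence

def _split_partitions(n_total, nsplits_or_sections):
    """Section sizes along one axis, matching the removed Python implementation."""
    if isinstance(nsplits_or_sections, Sequence):
        # Split points -> sizes are differences of consecutive boundaries.
        points = [0] + list(nsplits_or_sections) + [n_total]
        if any(a > b for a, b in zip(points, points[1:])):
            raise ValueError("Invalid nsplits_or_sections: {}".format(nsplits_or_sections))
        return [b - a for a, b in zip(points, points[1:])]
    n = int(nsplits_or_sections)
    if n <= 0:
        raise ValueError("Number sections must be larger than 0")
    if n > n_total:
        raise ValueError("size {} cannot be split into {} sections".format(n_total, n))
    # Near-equal sections: the first (n_total % n) sections get one extra element.
    return [(n_total + n - i - 1) // n for i in range(n)]

assert _split_partitions(10, 3) == [4, 3, 3]
assert _split_partitions(20, [6, 17]) == [6, 11, 3]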
imperative/python/src/tensor.cpp
@@ -633,6 +633,7 @@ WRAP_FUNC_PY35(get_device);
 WRAP_FUNC_PY35(make_shape_tuple);
 WRAP_FUNC_PY35(getitem_cpp);
 WRAP_FUNC_PY35(setitem_cpp);
+WRAP_FUNC_PY35(split_cpp);
 #undef WRAP_FUNC_PY35
 #define MGE_PY_INTERFACE(NAME, FUNC) \
     { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -765,6 +766,7 @@ void init_tensor(py::module m) {
             MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
             MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
             MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
+            MGE_PY_INTERFACE(split_cpp, split_cpp),
             {nullptr, nullptr, 0, nullptr}};
     for (auto&& def : method_defs) {
         if (def.ml_meth != nullptr) {
imperative/python/src/tensor_utils.cpp
@@ -603,6 +603,86 @@ py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_h
     return res;
 }
 
+bool is_tensor_or_symbolvar(py::handle arg) {
+    return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance<PySymbolVar>(arg);
+}
+
+bool is_py_sequence(py::handle arg) {
+    if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) ||
+        py::isinstance<PySymbolVar>(arg)) {
+        return false;
+    }
+    return PySequence_Check(arg.ptr());
+}
+
+py::object _split_cpp(
+        py::handle inp_hdl, py::handle nsplits_or_sections_hdl, py::handle axis_hdl) {
+    py::object shape_obj = getattr(inp_hdl, "shape");
+    py::object n_total = shape_obj[axis_hdl];
+    int ndim = shape_obj.attr("__len__")().cast<int>();
+    int axis = axis_hdl.cast<int>();
+    if (axis >= ndim) {
+        throw py::value_error("Invalid axis " + std::to_string(axis));
+    }
+    int n_sections;
+    bool is_array;
+    if (is_py_sequence(nsplits_or_sections_hdl)) {
+        n_sections = PySequence_Length(nsplits_or_sections_hdl.ptr()) + 1;
+        is_array = true;
+    } else {
+        n_sections = getattr(nsplits_or_sections_hdl, "__int__")().cast<int>();
+        is_array = false;
+    }
+    py::list partitions;
+    std::shared_ptr<OpDef> op;
+    std::vector<PyObject*> p;
+    if (is_array) {
+        py::list div_points;
+        py::list sections = py::reinterpret_borrow<py::object>(nsplits_or_sections_hdl);
+        div_points.append(0);
+        for (size_t i = 0; i < sections.size(); ++i) {
+            div_points.append(sections[i]);
+        }
+        div_points.append(n_total);
+        for (size_t i = 1; i < div_points.size(); ++i) {
+            if (div_points[i - 1] > div_points[i]) {
+                throw py::value_error(
+                        "Invalid nsplits_or_secions: " +
+                        repr(nsplits_or_sections_hdl).cast<std::string>());
+            }
+            py::object pos = div_points[i] - div_points[i - 1];
+            if (is_tensor_or_symbolvar(pos)) {
+                partitions.append(pos);
+            } else {
+                partitions.append(_Const(
+                        pos, py::cast((mgb::DType)dtype::Int32()),
+                        getattr(inp_hdl, "device"), inp_hdl));
+            }
+        }
+        op = Split::make(axis, 0);
+        p.resize(partitions.size() + 2);
+        for (size_t i = 0; i < partitions.size(); ++i) {
+            p[i + 2] = partitions[i].ptr();
+        }
+    } else {
+        if (n_sections <= 0) {
+            throw py::value_error("Number sections must be larger than 0");
+        }
+        if (py::int_(n_sections) > n_total) {
+            throw py::value_error(
+                    "The size " + repr(n_total).cast<std::string>() + " at dim " +
+                    std::to_string(axis) + " cannot be split into " +
+                    std::to_string(n_sections) + " sections");
+        }
+        op = Split::make(axis, n_sections);
+        p.resize(2);
+    }
+    py::object Op = py::cast(op);
+    p[0] = Op.ptr();
+    p[1] = inp_hdl.ptr();
+    return py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
+}
+
 PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -627,4 +707,13 @@ PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }
 
+PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
+    try {
+        return _split_cpp(
+                       py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
+                .release()
+                .ptr();
+    }
+    PYEXT17_TRANSLATE_EXC_RET(nullptr)
+}
+
 }  // namespace mgb::imperative::python
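The argument vector p assembled by _split_cpp follows the convention py_apply expects: slot 0 is the Split op, slot 1 the input tensor, and slots 2..N the per-section size tensors, which exist only in the explicit-partition case. In other words, the sequence form still applies a Split op with extra partition inputs (as the removed Python code did), while the scalar form now passes the section count through the op's new nsections attribute and sends no partition inputs at all. A hedged Python-level illustration of which user call hits which branch (tensor and shapes are illustrative only):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.random((10, 20)), dtype="float32")

# Scalar form -> Split::make(axis, n_sections); p holds only [op, inp].
a, b, c = F.split(x, 3)

# Sequence form -> Split::make(axis, 0); p holds [op, inp, part_0, part_1, part_2],
# where each part_i is an int32 tensor built by _Const from the boundary differences.
d, e, f = F.split(x, [6, 17], axis=1)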
imperative/python/src/tensor_utils.h
@@ -8,4 +8,6 @@ PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
 
+PyObject* split_cpp(PyObject* self, PyObject* const* args, size_t nargs);
+
 }  // namespace mgb::imperative::python
\ No newline at end of file
imperative/src/impl/ops/tensor_manip.cpp
@@ -285,7 +285,7 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
             opt.method == Options::Method::SPECIFY,
             "only Split with SPECIFY output shapes is supported");
     mgb_assert(opt.partition.size() == opt.nr_part);
-    return Split::make(axis);
+    return Split::make(axis, 0);
 }
 
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
@@ -293,13 +293,18 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& sp = static_cast<const Split&>(def);
     OperatorNodeConfig config{sp.make_name()};
     opr::Split::Options opt;
-    opt.axis = sp.axis;
-    opt.method = Options::Method::SPECIFY;
-    mgb_assert(inputs.size() > 1);
-    opt.nr_part = inputs.size() - 1;
-    opt.partition.resize(opt.nr_part);
-    for (size_t i = 1; i < inputs.size(); ++i)
-        opt.partition[i - 1] = inputs[i];
+    if (sp.nsections) {
+        opt = Options::make_average(sp.axis, sp.nsections);
+        opt.method = Options::Method::CALL_BACK;
+    } else {
+        opt.axis = sp.axis;
+        opt.method = Options::Method::SPECIFY;
+        mgb_assert(inputs.size() > 1);
+        opt.nr_part = inputs.size() - 1;
+        opt.partition.resize(opt.nr_part);
+        for (size_t i = 1; i < inputs.size(); ++i)
+            opt.partition[i - 1] = inputs[i];
+    }
     return opr::Split::make(inputs[0], opt, config);
 }
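On the graph side, a nonzero nsections now selects Options::make_average with Method::CALL_BACK, so the near-equal partition is derived inside the operator instead of being fed in as extra constant inputs; only explicit split points keep the old Method::SPECIFY path. A quick sanity check one might run against NumPy, assuming a build that includes this change (this snippet is not part of the commit itself):

import numpy as np
import megengine as mge
import megengine.functional as F

data = np.arange(200, dtype="float32").reshape(10, 20)
x = mge.tensor(data)

# Scalar form (nsections / make_average path) vs. numpy.array_split.
for got, want in zip(F.split(x, 3), np.array_split(data, 3)):
    np.testing.assert_array_equal(got.numpy(), want)

# Sequence form (SPECIFY path) vs. numpy.split with the same boundaries.
for got, want in zip(F.split(x, [6, 17], axis=1), np.split(data, [6, 17], axis=1)):
    np.testing.assert_array_equal(got.numpy(), want)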
src/core/include/megbrain/ir/ops.td
@@ -426,7 +426,8 @@ def Cumsum: MgbHashableOp<"Cumsum", [CumsumParam]>;
 
 def Split: MgbHashableOp<"Split", [EmptyParam]> {
   let extraArguments = (ins
-    MgbI32Attr:$axis
+    MgbI32Attr:$axis,
+    MgbI32Attr:$nsections
   );
 }
src/opr/include/megbrain/opr/tensor_manip.h
@@ -422,7 +422,7 @@ public:
         /*!
          * \brief make split option by splitting into average parts
          */
-        static Options make_average(int axis, size_t nr_part);
+        MGE_WIN_DECLSPEC_FUC static Options make_average(int axis, size_t nr_part);
 
         static Options make_partition(int axis, const SymbolVarArray& partition);
         static Options make_partition(