Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
282dfc62
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
282dfc62
编写于
4月 20, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(imperative): alloc enum type class on heap
GitOrigin-RevId: d2b2acea229df68151f04ce17c1e73621dd7fb60
上级
1e6ef377
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
62 addition
and
32 deletion
+62
-32
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+4
-6
imperative/python/test/unit/core/test_imperative_rt.py
imperative/python/test/unit/core/test_imperative_rt.py
+13
-0
imperative/tablegen/targets/python_c_extension.cpp
imperative/tablegen/targets/python_c_extension.cpp
+44
-26
imperative/test/CMakeLists.txt
imperative/test/CMakeLists.txt
+1
-0
未找到文件。
imperative/python/src/ops.cpp
浏览文件 @
282dfc62
...
@@ -170,7 +170,7 @@ struct EnumTrait;
...
@@ -170,7 +170,7 @@ struct EnumTrait;
PyObject_HEAD \
PyObject_HEAD \
T value; \
T value; \
constexpr static const char *name = EnumTrait<T>::name; \
constexpr static const char *name = EnumTrait<T>::name; \
static PyTypeObject type; \
static PyTypeObject
*
type; \
static const char* members[]; \
static const char* members[]; \
static std::unordered_map<std::string, T> mem2value; \
static std::unordered_map<std::string, T> mem2value; \
static PyObject* pyobj_insts[];
static PyObject* pyobj_insts[];
...
@@ -196,7 +196,7 @@ struct EnumWrapper {
...
@@ -196,7 +196,7 @@ struct EnumWrapper {
}
}
static
bool
load
(
py
::
handle
src
,
T
&
value
)
{
static
bool
load
(
py
::
handle
src
,
T
&
value
)
{
PyObject
*
obj
=
src
.
ptr
();
PyObject
*
obj
=
src
.
ptr
();
if
(
PyObject_TypeCheck
(
obj
,
&
type
))
{
if
(
PyObject_TypeCheck
(
obj
,
type
))
{
value
=
reinterpret_cast
<
EnumWrapper
*>
(
obj
)
->
value
;
value
=
reinterpret_cast
<
EnumWrapper
*>
(
obj
)
->
value
;
return
true
;
return
true
;
}
}
...
@@ -224,7 +224,6 @@ struct EnumWrapper {
...
@@ -224,7 +224,6 @@ struct EnumWrapper {
template
<
typename
T
>
template
<
typename
T
>
struct
BitCombinedEnumWrapper
{
struct
BitCombinedEnumWrapper
{
PyEnumHead
PyEnumHead
static
PyNumberMethods
number_methods
;
std
::
string
to_string
()
const
{
std
::
string
to_string
()
const
{
uint32_t
value_int
=
static_cast
<
uint32_t
>
(
value
);
uint32_t
value_int
=
static_cast
<
uint32_t
>
(
value
);
if
(
value_int
==
0
)
{
if
(
value_int
==
0
)
{
...
@@ -302,7 +301,7 @@ struct BitCombinedEnumWrapper {
...
@@ -302,7 +301,7 @@ struct BitCombinedEnumWrapper {
}
}
static
bool
load
(
py
::
handle
src
,
T
&
value
)
{
static
bool
load
(
py
::
handle
src
,
T
&
value
)
{
PyObject
*
obj
=
src
.
ptr
();
PyObject
*
obj
=
src
.
ptr
();
if
(
PyObject_TypeCheck
(
obj
,
&
type
))
{
if
(
PyObject_TypeCheck
(
obj
,
type
))
{
value
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
;
value
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
;
return
true
;
return
true
;
}
}
...
@@ -330,8 +329,7 @@ struct BitCombinedEnumWrapper {
...
@@ -330,8 +329,7 @@ struct BitCombinedEnumWrapper {
auto
v
=
static_cast
<
std
::
underlying_type_t
<
T
>>
(
value
);
auto
v
=
static_cast
<
std
::
underlying_type_t
<
T
>>
(
value
);
mgb_assert
(
v
<=
EnumTrait
<
T
>::
max
);
mgb_assert
(
v
<=
EnumTrait
<
T
>::
max
);
if
((
!
v
)
||
(
v
&
(
v
-
1
)))
{
if
((
!
v
)
||
(
v
&
(
v
-
1
)))
{
PyTypeObject
*
pytype
=
&
type
;
PyObject
*
obj
=
type
->
tp_alloc
(
type
,
0
);
PyObject
*
obj
=
pytype
->
tp_alloc
(
pytype
,
0
);
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
value
;
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
value
;
return
obj
;
return
obj
;
}
else
{
}
else
{
...
...
imperative/python/test/unit/core/test_imperative_rt.py
浏览文件 @
282dfc62
...
@@ -69,3 +69,16 @@ def test_raw_tensor():
...
@@ -69,3 +69,16 @@ def test_raw_tensor():
np
.
testing
.
assert_allclose
(
x
*
x
,
yy
.
numpy
())
np
.
testing
.
assert_allclose
(
x
*
x
,
yy
.
numpy
())
(
yy
,)
=
apply
(
Elemwise
(
Elemwise
.
Mode
.
MUL
),
xx
,
xx
)
(
yy
,)
=
apply
(
Elemwise
(
Elemwise
.
Mode
.
MUL
),
xx
,
xx
)
np
.
testing
.
assert_allclose
(
x
*
x
,
yy
.
numpy
())
np
.
testing
.
assert_allclose
(
x
*
x
,
yy
.
numpy
())
def
test_opdef_path
():
from
megengine.core.ops.builtin
import
Elemwise
assert
Elemwise
.
__module__
==
"megengine.core._imperative_rt.ops"
assert
Elemwise
.
__name__
==
"Elemwise"
assert
Elemwise
.
__qualname__
==
"Elemwise"
Mode
=
Elemwise
.
Mode
assert
Mode
.
__module__
==
"megengine.core._imperative_rt.ops"
assert
Mode
.
__name__
==
"Mode"
assert
Mode
.
__qualname__
==
"Elemwise.Mode"
imperative/tablegen/targets/python_c_extension.cpp
浏览文件 @
282dfc62
...
@@ -97,7 +97,7 @@ void EnumAttrEmitter::emit_tpl_spl() {
...
@@ -97,7 +97,7 @@ void EnumAttrEmitter::emit_tpl_spl() {
if
(
!
firstOccur
)
return
;
if
(
!
firstOccur
)
return
;
os
<<
tgfmt
(
os
<<
tgfmt
(
"template<> PyTypeObject
$enumTpl<$opClass::$enumClass>::type = {}
;
\n
"
,
"template<> PyTypeObject
* $enumTpl<$opClass::$enumClass>::type = nullptr
;
\n
"
,
&
ctx
);
&
ctx
);
auto
quote
=
[
&
](
auto
&&
i
)
->
std
::
string
{
auto
quote
=
[
&
](
auto
&&
i
)
->
std
::
string
{
...
@@ -120,13 +120,6 @@ $enumTpl<$opClass::$enumClass>::mem2value = {$0};
...
@@ -120,13 +120,6 @@ $enumTpl<$opClass::$enumClass>::mem2value = {$0};
"template<> PyObject* "
"template<> PyObject* "
"$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};
\n
"
,
"$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};
\n
"
,
&
ctx
,
attr
->
getEnumMembers
().
size
());
&
ctx
,
attr
->
getEnumMembers
().
size
());
if
(
attr
->
getEnumCombinedFlag
())
{
os
<<
tgfmt
(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods = {};
\n
"
,
&
ctx
);
}
}
}
Initproc
EnumAttrEmitter
::
emit_initproc
()
{
Initproc
EnumAttrEmitter
::
emit_initproc
()
{
...
@@ -140,45 +133,70 @@ void $0(PyTypeObject& py_type) {
...
@@ -140,45 +133,70 @@ void $0(PyTypeObject& py_type) {
if
(
firstOccur
)
{
if
(
firstOccur
)
{
os
<<
tgfmt
(
R"(
os
<<
tgfmt
(
R"(
e_type = {PyVarObject_HEAD_INIT(NULL, 0)};
static PyType_Slot slots[] = {
e_type.tp_name = "megengine.core._imperative_rt.ops.$opClass.$enumClass";
{Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
e_type.tp_basicsize = sizeof($enumTpl<$opClass::$enumClass>);
{Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
e_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
e_type.tp_doc = "$opClass.$enumClass";
e_type.tp_base = &PyBaseObject_Type;
e_type.tp_repr = $enumTpl<$opClass::$enumClass>::py_repr;
e_type.tp_richcompare = $enumTpl<$opClass::$enumClass>::tp_richcompare;
)"
,
&
ctx
);
)"
,
&
ctx
);
if
(
attr
->
getEnumCombinedFlag
())
{
if
(
attr
->
getEnumCombinedFlag
())
{
// only bit combined enum could new instance because bitwise operation,
// only bit combined enum could new instance because bitwise operation,
// others should always use singleton
// others should always use singleton
os
<<
tgfmt
(
R"(
os
<<
tgfmt
(
R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
{Py_tp_new, (void*)$enumTpl<$opClass::$enumClass>::py_new_combined_enum},
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
{Py_nb_or, (void*)$enumTpl<$opClass::$enumClass>::py_or},
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
{Py_nb_and, (void*)$enumTpl<$opClass::$enumClass>::py_and},
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
e_type.tp_as_number = &number_method;
)"
,
&
ctx
);
)"
,
&
ctx
);
}
}
os
<<
R"(
{0, NULL}
};)"
;
os
<<
tgfmt
(
R"(
static PyType_Spec spec = {
// name
"megengine.core._imperative_rt.ops.$opClass.$enumClass",
// basicsize
sizeof($enumTpl<$opClass::$enumClass>),
// itemsize
0,
// flags
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
// slots
slots
};)"
,
&
ctx
);
os
<<
tgfmt
(
R"(
e_type = reinterpret_cast<PyTypeObject*>(PyType_FromSpec(&spec));
)"
,
&
ctx
);
os
<<
" mgb_assert(PyType_Ready(&e_type) >= 0);
\n
"
;
for
(
auto
&&
i
:
{
std
::
pair
<
std
::
string
,
std
::
string
>
{
"__name__"
,
tgfmt
(
"$enumClass"
,
&
ctx
)},
{
"__module__"
,
"megengine.core._imperative_rt.ops"
},
{
"__qualname__"
,
tgfmt
(
"$opClass.$enumClass"
,
&
ctx
)}})
{
os
<<
formatv
(
R"(
mgb_assert(
e_type->tp_setattro(
reinterpret_cast<PyObject*>(e_type),
py::cast("{0}").release().ptr(),
py::cast("{1}").release().ptr()) >= 0);
)"
,
i
.
first
,
i
.
second
);
}
auto
&&
members
=
attr
->
getEnumMembers
();
auto
&&
members
=
attr
->
getEnumMembers
();
for
(
size_t
idx
=
0
;
idx
<
members
.
size
();
++
idx
)
{
for
(
size_t
idx
=
0
;
idx
<
members
.
size
();
++
idx
)
{
os
<<
tgfmt
(
R"({
os
<<
tgfmt
(
R"({
PyObject* inst = e_type
.tp_alloc(&
e_type, 0);
PyObject* inst = e_type
->tp_alloc(
e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type
.
tp_dict, "$0", inst) >= 0);
mgb_assert(PyDict_SetItemString(e_type
->
tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})"
,
&
ctx
,
members
[
idx
],
idx
);
})"
,
&
ctx
,
members
[
idx
],
idx
);
}
}
os
<<
" PyType_Modified(&e_type);
\n
"
;
}
}
os
<<
tgfmt
(
R"(
os
<<
tgfmt
(
R"(
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(
&
e_type)) >= 0);
py_type.tp_dict, "$enumClass", reinterpret_cast<PyObject*>(e_type)) >= 0);
)"
,
&
ctx
);
)"
,
&
ctx
);
os
<<
"}
\n
"
;
os
<<
"}
\n
"
;
return
initproc
;
return
initproc
;
...
...
imperative/test/CMakeLists.txt
浏览文件 @
282dfc62
...
@@ -11,6 +11,7 @@ endif()
...
@@ -11,6 +11,7 @@ endif()
# TODO: turn python binding into a static/object library
# TODO: turn python binding into a static/object library
add_executable
(
imperative_test
${
SOURCES
}
${
SRCS
}
)
add_executable
(
imperative_test
${
SOURCES
}
${
SRCS
}
)
add_dependencies
(
imperative_test mgb_opdef
)
target_include_directories
(
imperative_test PRIVATE
${
MGB_TEST_DIR
}
/include ../src/include
${
MGB_OPDEF_OUT_DIR
}
)
target_include_directories
(
imperative_test PRIVATE
${
MGB_TEST_DIR
}
/include ../src/include
${
MGB_OPDEF_OUT_DIR
}
)
# Python binding
# Python binding
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录