Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9279104b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9279104b
编写于
8月 20, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge): add opdef serialization and apply_module_trace
GitOrigin-RevId: 5b45bded1de8e1fb36447d4469423ef68ff627e8
上级
aa204040
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
337 addition
and
14 deletion
+337
-14
imperative/python/megengine/experimental/traced_module/__init__.py
...e/python/megengine/experimental/traced_module/__init__.py
+7
-0
imperative/python/megengine/experimental/traced_module/serialization.py
...hon/megengine/experimental/traced_module/serialization.py
+34
-0
imperative/python/src/module_trace.cpp
imperative/python/src/module_trace.cpp
+41
-0
imperative/python/src/module_trace.h
imperative/python/src/module_trace.h
+20
-0
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+78
-0
imperative/python/src/pyext17.h
imperative/python/src/pyext17.h
+1
-8
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+38
-1
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+9
-4
imperative/python/test/unit/core/test_serialization.py
imperative/python/test/unit/core/test_serialization.py
+27
-0
imperative/tablegen/targets/python_c_extension.cpp
imperative/tablegen/targets/python_c_extension.cpp
+82
-1
未找到文件。
imperative/python/megengine/experimental/traced_module/__init__.py
0 → 100644
浏览文件 @
9279104b
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
imperative/python/megengine/experimental/traced_module/serialization.py
0 → 100644
浏览文件 @
9279104b
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
typing
import
Dict
from
...core._imperative_rt
import
OpDef
from
...core.ops
import
builtin
from
...version
import
__version__
OPDEF_PARAM_LOADER
=
{}
def
get_opdef_state
(
obj
:
OpDef
)
->
Dict
:
state
=
obj
.
__getstate__
()
state
[
"type"
]
=
type
(
obj
)
state
[
"version"
]
=
__version__
return
state
def
load_opdef_from_state
(
state
:
Dict
)
->
OpDef
:
assert
"type"
in
state
and
issubclass
(
state
[
"type"
],
OpDef
)
assert
"version"
in
state
opdef_type
=
state
.
pop
(
"type"
)
if
opdef_type
in
OPDEF_PARAM_LOADER
:
loader
=
OPDEF_PARAM_LOADER
[
opdef_type
]
state
=
loader
(
state
)
state
.
pop
(
"version"
)
opdef_obj
=
opdef_type
()
opdef_obj
.
__setstate__
(
state
)
return
opdef_obj
imperative/python/src/module_trace.cpp
0 → 100644
浏览文件 @
9279104b
/**
* \file imperative/python/src/module_trace.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./module_trace.h"
#include "./helper.h" // include op pybind11 caster
namespace
py
=
pybind11
;
namespace
mgb
::
imperative
::
python
{
apply_result_t
apply_module_trace
(
ApplyContext
&
ctx
)
{
apply_result_t
outputs
;
auto
args
=
py
::
tuple
(
ctx
.
nargs
+
1
);
args
[
0
]
=
py
::
cast
(
ctx
.
op
);
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
i
++
)
{
args
[
i
+
1
]
=
TensorWrapper
::
make
(
ctx
.
args
[
i
]
->
shared_from_this
());
}
auto
pyout
=
PyObject_Call
(
cpp_apply_module_trace
,
args
.
ptr
(),
nullptr
);
if
(
!
pyout
)
throw
py
::
error_already_set
();
auto
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
pyout
);
// assumption: python function always returns PyList
auto
tup
=
py
::
reinterpret_borrow
<
py
::
list
>
(
ret
);
for
(
auto
i
=
0
;
i
<
tup
.
size
();
i
++
)
{
auto
tw
=
TensorWrapper
::
try_cast
(
tup
[
i
].
ptr
());
outputs
.
emplace_back
(
tw
->
m_tensor
);
}
return
outputs
;
}
}
// namespace mgb::imperative::python
imperative/python/src/module_trace.h
0 → 100644
浏览文件 @
9279104b
/**
* \file imperative/python/src/module_trace.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "./tensor.h"
namespace
mgb
::
imperative
::
python
{
apply_result_t
apply_module_trace
(
ApplyContext
&
ctx
);
}
// namespace mgb::imperative::python
imperative/python/src/ops.cpp
浏览文件 @
9279104b
...
...
@@ -88,6 +88,19 @@ PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
return
obj
;
}
template
<
typename
T
,
typename
SNIFAE
=
void
>
struct
serialization
{
static
T
load
(
py
::
object
obj
)
{
return
py
::
cast
<
T
>
(
obj
);
}
template
<
typename
U
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
std
::
decay_t
<
U
>
>>>
static
py
::
object
dump
(
U
&&
t
)
{
return
py
::
cast
(
std
::
forward
<
U
>
(
t
));
}
};
template
<
typename
T
>
void
py_dealloc_generic
(
PyObject
*
obj
)
{
reinterpret_cast
<
T
*>
(
obj
)
->
op
.
reset
();
...
...
@@ -127,6 +140,13 @@ struct PyOpDef {
static
PyGetSetDef
py_getsetters
[];
static
Py_hash_t
tp_hash
(
PyObject
*
obj
);
static
PyObject
*
tp_richcompare
(
PyObject
*
self
,
PyObject
*
other
,
int
op
);
static
PyObject
*
py_repr
(
PyObject
*
self
)
{
return
py
::
cast
(
reinterpret_cast
<
PyOpDef
*>
(
self
)
->
op
->
make_name
())
.
release
()
.
ptr
();
}
};
PyTypeObject
PyOpType
(
OpDef
);
std
::
unordered_map
<
mgb
::
Typeinfo
*
,
PyTypeObject
*>
PyOp
(
OpDef
)
::
ctype2pytype
;
...
...
@@ -191,6 +211,13 @@ struct EnumWrapper {
std
::
string
(
name
)
+
"."
+
reinterpret_cast
<
EnumWrapper
*>
(
self
)
->
to_string
())
.
release
().
ptr
();
}
static
PyObject
*
py_dump
(
PyObject
*
self
)
{
return
py
::
cast
(
reinterpret_cast
<
EnumWrapper
*>
(
self
)
->
to_string
())
.
release
()
.
ptr
();
}
static
PyObject
*
tp_richcompare
(
PyObject
*
self
,
PyObject
*
other
,
int
op
)
{
if
(
op
==
Py_EQ
||
op
==
Py_NE
)
{
T
lhs
,
rhs
;
...
...
@@ -279,6 +306,19 @@ struct BitCombinedEnumWrapper {
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
to_string
())
.
release
().
ptr
();
}
static
PyObject
*
py_dump
(
PyObject
*
self
)
{
std
::
vector
<
std
::
string
>
result
;
auto
value
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
value
;
uint32_t
value_int
=
static_cast
<
uint32_t
>
(
value
);
for
(
uint32_t
i
=
0
;
i
<
32
;
i
++
)
{
if
(
value_int
>>
i
&
1
)
{
result
.
push_back
(
members
[
i
]);
}
}
return
py
::
tuple
(
py
::
cast
(
result
)).
release
().
ptr
();
}
static
PyObject
*
py_or
(
PyObject
*
self
,
PyObject
*
other
)
{
if
(
!
(
self
->
ob_type
==
other
->
ob_type
)){
return
PyErr_Format
(
...
...
@@ -326,6 +366,24 @@ struct BitCombinedEnumWrapper {
return
false
;
}
}
if
(
py
::
isinstance
<
py
::
tuple
>
(
src
))
{
auto
params
=
py
::
cast
<
std
::
vector
<
std
::
string
>>
(
src
);
bool
first
=
true
;
for
(
auto
s
:
params
){
auto
&&
iter
=
mem2value
.
find
(
normalize_enum
(
s
));
if
(
iter
!=
mem2value
.
end
())
{
if
(
first
)
{
value
=
iter
->
second
;
first
=
false
;
}
else
{
value
|=
iter
->
second
;
}
}
else
{
return
false
;
}
}
return
true
;
}
if
(
py
::
isinstance
<
py
::
int_
>
(
obj
))
{
auto
v
=
py
::
cast
<
std
::
underlying_type_t
<
T
>>
(
src
);
if
(
v
>
EnumTrait
<
T
>::
max
)
{
...
...
@@ -351,6 +409,25 @@ struct BitCombinedEnumWrapper {
}
};
template
<
typename
T
>
struct
serialization
<
T
,
std
::
enable_if_t
<
std
::
is_enum_v
<
std
::
decay_t
<
T
>>>>
{
static
T
load
(
py
::
object
obj
)
{
auto
caster
=
pybind11
::
detail
::
type_caster
<
T
>
();
if
(
caster
.
load
(
obj
,
true
))
{
return
caster
;
}
else
{
PyErr_SetString
(
PyExc_RuntimeError
,
"load faild
\n
"
);
return
caster
;
}
}
static
py
::
object
dump
(
T
t
)
{
return
py
::
cast
(
t
).
attr
(
"dump"
)();
}
};
void
_init_py_op_def
(
py
::
module
m
)
{
using
py_op
=
PyOp
(
OpDef
);
auto
&
py_type
=
PyOpType
(
OpDef
);
...
...
@@ -363,6 +440,7 @@ void _init_py_op_def(py::module m) {
py_type
.
tp_hash
=
PyOp
(
OpDef
)
::
tp_hash
;
py_type
.
tp_richcompare
=
PyOp
(
OpDef
)
::
tp_richcompare
;
py_type
.
tp_getset
=
py_op
::
py_getsetters
;
py_type
.
tp_repr
=
py_op
::
py_repr
;
mgb_assert
(
PyType_Ready
(
&
py_type
)
>=
0
);
m
.
add_object
(
"OpDef"
,
reinterpret_cast
<
PyObject
*>
(
&
py_type
));
}
...
...
imperative/python/src/pyext17.h
浏览文件 @
9279104b
...
...
@@ -451,18 +451,11 @@ public:
template
<
typename
...
Args
>
static
PyObject
*
cnew
(
Args
&&
...
args
)
{
auto
*
pytype
=
type
().
operator
->
();
auto
*
self
=
pytype
->
tp_alloc
(
pytype
,
0
);
auto
*
inst
=
reinterpret_cast
<
wrap_t
*>
(
self
)
->
inst
();
if
constexpr
(
has_vectorcall
&&
tp_vectorcall
::
valid
)
{
reinterpret_cast
<
wrap_t
*>
(
self
)
->
vectorcall_slot
=
&
tp_vectorcall
::
template
impl
<
>;
}
new
(
inst
)
T
(
std
::
forward
<
Args
>
(
args
)...);
return
self
;
return
cnew_with_type
(
pytype
,
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
...
Args
>
static
PyObject
*
cnew_with_type
(
PyTypeObject
*
pytype
,
Args
&&
...
args
)
{
auto
*
self
=
pytype
->
tp_alloc
(
pytype
,
0
);
auto
*
inst
=
reinterpret_cast
<
wrap_t
*>
(
self
)
->
inst
();
if
constexpr
(
has_vectorcall
&&
tp_vectorcall
::
valid
)
{
...
...
imperative/python/src/tensor.cpp
浏览文件 @
9279104b
...
...
@@ -20,6 +20,7 @@
#include "./tensor.h"
#include "./grad.h"
#include "./trace.h"
#include "./module_trace.h"
#include "./common.h"
#include "./numpy_dtypes.h"
#include "./graph_rt.h"
...
...
@@ -41,6 +42,7 @@ interpreter::Interpreter::Channel* interpreter_for_py;
PyObject
*
cpp_apply_with_tracing
,
*
cpp_apply_const_with_tracing
;
PyObject
*
cpp_apply_backward_varnode
;
PyObject
*
cpp_apply_module_trace
;
std
::
shared_ptr
<
Tensor
>
make_const
(
imperative
::
TensorPtr
value
)
{
if
(
!
(
ApplyContext
::
global_enable
&
Tensor
::
Flags
::
TRACE
))
{
...
...
@@ -70,6 +72,7 @@ std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
REGISTE_APPLY_FUNC
(
cpp_apply_with_tracing
)
REGISTE_APPLY_FUNC
(
cpp_apply_const_with_tracing
)
REGISTE_APPLY_FUNC
(
cpp_apply_backward_varnode
)
REGISTE_APPLY_FUNC
(
cpp_apply_module_trace
)
#undef REGISTE_APPLY_FUNC
...
...
@@ -79,6 +82,14 @@ Tensor::flags_t ApplyContext::global_enable = 0;
void
set_tracing
()
{
ApplyContext
::
global_enable
|=
Tensor
::
Flags
::
TRACE
;
}
void
unset_tracing
()
{
ApplyContext
::
global_enable
&=
~
Tensor
::
Flags
::
TRACE
;
}
void
set_module_tracing
()
{
ApplyContext
::
global_enable
|=
Tensor
::
Flags
::
MODULE_TRACE
;
}
void
unset_module_tracing
()
{
ApplyContext
::
global_enable
&=
~
Tensor
::
Flags
::
MODULE_TRACE
;
}
bool
is_tracing_module
()
{
return
ApplyContext
::
global_enable
&
Tensor
::
Flags
::
MODULE_TRACE
;
}
bool
skip_tracing
=
false
;
apply_result_t
apply
(
ApplyContext
&
ctx
)
{
...
...
@@ -117,6 +128,11 @@ apply_result_t apply(ApplyContext& ctx) {
return
ret
;
}
if
(
flags
&
Tensor
::
Flags
::
MODULE_TRACE
)
{
return
apply_module_trace
(
ctx
);
}
if
(
flags
&
Tensor
::
Flags
::
TRACE
)
{
return
apply_trace
(
ctx
);
}
else
{
...
...
@@ -310,6 +326,21 @@ REGISTE_TENSORWRAPPER_FUNC(bool, recording)
#undef REGISTE_TENSORWRAPPER_FUNC
PyObject
*
TensorWrapper
::
module_trace_info
()
{
if
(
!
m_tensor
->
m_module_trace_info
.
ptr
())
{
PyErr_SetString
(
PyExc_AttributeError
,
"Has no attribute named
\'
_NodeMixin__node
\'
, please "
"set it first"
);
return
nullptr
;
}
return
m_tensor
->
m_module_trace_info
.
inc_ref
().
ptr
();
}
void
TensorWrapper
::
set_module_trace_info
(
PyObject
*
obj
)
{
m_tensor
->
m_module_trace_info
=
py
::
reinterpret_borrow
<
py
::
object
>
(
obj
);
}
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \
...
...
@@ -495,7 +526,9 @@ void TensorWrapper::reset(PyObject* tensor) {
}
std
::
string
user_custom_name
=
m_tensor
->
user_custom_name
;
std
::
string
automatic_name
=
m_tensor
->
automatic_name
;
auto
module_trace_info
=
m_tensor
->
m_module_trace_info
;
m_tensor
=
t
->
m_tensor
;
m_tensor
->
m_module_trace_info
=
module_trace_info
;
m_tensor
->
user_custom_name
=
user_custom_name
;
m_tensor
->
automatic_name
=
automatic_name
;
}
...
...
@@ -856,6 +889,7 @@ void init_tensor(py::module m) {
.
def_getset
<&
TensorWrapper
::
trace_mixin_info
,
&
TensorWrapper
::
set_trace_mixin_info
>
(
"_trace_mixin_info"
)
.
def_getset
<&
TensorWrapper
::
user_custom_name
,
&
TensorWrapper
::
set_user_custom_name
>
(
"c_name"
)
.
def_getset
<&
TensorWrapper
::
automatic_name
,
&
TensorWrapper
::
set_automatic_name
>
(
"_name"
)
.
def_getset
<&
TensorWrapper
::
module_trace_info
,
&
TensorWrapper
::
set_module_trace_info
>
(
"_NodeMixin__node"
)
.
finalize
();
if
(
!
tensor_type
)
throw
py
::
error_already_set
();
py
::
setattr
(
m
,
"Tensor"
,
tensor_type
);
...
...
@@ -998,7 +1032,7 @@ void init_tensor(py::module m) {
m
.
def
(
"set_cpp_apply_with_tracing"
,
&
set_cpp_apply_with_tracing
);
m
.
def
(
"set_cpp_apply_const_with_tracing"
,
&
set_cpp_apply_const_with_tracing
);
m
.
def
(
"set_cpp_apply_backward_varnode"
,
&
set_cpp_apply_backward_varnode
);
m
.
def
(
"set_cpp_apply_module_trace"
,
&
set_cpp_apply_module_trace
);
m
.
attr
(
"skip_tracing"
)
=
&
skip_tracing
;
py
::
class_
<
SharedHandle
>
(
m
,
"SharedHandle"
)
...
...
@@ -1016,6 +1050,9 @@ void init_tensor(py::module m) {
m
.
def
(
"set_allow_higher_order_directive"
,
[](
bool
value
){
GradKey
::
allow_higher_order_directive
=
value
;
});
m
.
def
(
"set_module_tracing"
,
&
set_module_tracing
);
m
.
def
(
"unset_module_tracing"
,
&
unset_module_tracing
);
m
.
def
(
"is_tracing_module"
,
&
is_tracing_module
);
}
#undef MGE_PY_INTERFACE
...
...
imperative/python/src/tensor.h
浏览文件 @
9279104b
...
...
@@ -96,6 +96,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
static
constexpr
flags_t
SCALAR
=
1
;
static
constexpr
flags_t
GRAD
=
1
<<
1
;
static
constexpr
flags_t
TRACE
=
1
<<
2
;
static
constexpr
flags_t
MODULE_TRACE
=
1
<<
3
;
};
flags_t
m_flags
=
0
;
...
...
@@ -106,6 +107,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
std
::
string
user_custom_name
;
std
::
string
automatic_name
;
cg
::
VarNode
*
m_var
;
pybind11
::
object
m_module_trace_info
;
using
Handle
=
interpreter
::
Interpreter
::
Handle
;
...
...
@@ -158,10 +160,10 @@ struct TensorWrapper {
using
wrap_t
=
pyext17
::
wrap
<
TensorWrapper
>
;
friend
wrap_t
;
inline
static
TensorWrapper
*
cast
(
PyObject
*
o
p
)
{
return
reinterpret_cast
<
wrap_t
*>
(
op
)
->
inst
();}
inline
static
TensorWrapper
*
try_cast
(
PyObject
*
o
p
)
{
if
(
!
wrap_t
::
type
().
isinstance
(
o
p
))
return
nullptr
;
return
cast
(
o
p
);
inline
static
TensorWrapper
*
cast
(
PyObject
*
o
bj
)
{
return
reinterpret_cast
<
wrap_t
*>
(
obj
)
->
inst
();}
inline
static
TensorWrapper
*
try_cast
(
PyObject
*
o
bj
)
{
if
(
!
wrap_t
::
type
().
isinstance
(
o
bj
))
return
nullptr
;
return
cast
(
o
bj
);
}
inline
ObjectPtr
<
TensorWrapper
,
pybind11
::
handle
>
self
()
{
return
wrap_t
::
pycast
(
this
);}
...
...
@@ -206,6 +208,8 @@ struct TensorWrapper {
void
set_compiled_info
(
PyObject
*
);
PyObject
*
trace_mixin_info
();
void
set_trace_mixin_info
(
PyObject
*
);
PyObject
*
module_trace_info
();
void
set_module_trace_info
(
PyObject
*
);
PyObject
*
user_custom_name
();
void
set_user_custom_name
(
PyObject
*
);
PyObject
*
automatic_name
();
...
...
@@ -331,6 +335,7 @@ void init_tensor(pybind11::module);
extern
PyObject
*
cpp_apply_with_tracing
;
extern
PyObject
*
cpp_apply_backward_varnode
;
extern
PyObject
*
cpp_apply_module_trace
;
}
// namespace mgb::imperative::python
...
...
imperative/python/test/unit/core/test_serialization.py
浏览文件 @
9279104b
...
...
@@ -14,6 +14,11 @@ import numpy as np
import
megengine
as
mge
from
megengine
import
Parameter
,
Tensor
from
megengine.core.ops
import
builtin
from
megengine.experimental.traced_module.serialization
import
(
get_opdef_state
,
load_opdef_from_state
,
)
def
test_tensor_serialization
():
...
...
@@ -86,3 +91,25 @@ def test_compatibility():
test_old_tensor
(
"tensor_v1_1.mge"
)
test_old_tensor
(
"tensor_v1_2.mge"
)
def
test_opdef_serialization
():
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Elemwise
(
mode
=
"Add"
)
pickle
.
dump
(
get_opdef_state
(
x
),
f
)
f
.
seek
(
0
)
load_x
=
load_opdef_from_state
(
pickle
.
load
(
f
))
assert
x
==
load_x
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Convolution
(
stride_h
=
9
,
compute_mode
=
"float32"
)
x
.
strategy
=
(
builtin
.
Convolution
.
Strategy
.
PROFILE
|
builtin
.
Convolution
.
Strategy
.
HEURISTIC
|
builtin
.
Convolution
.
Strategy
.
REPRODUCIBLE
)
pickle
.
dump
(
get_opdef_state
(
x
),
f
)
f
.
seek
(
0
)
load_x
=
load_opdef_from_state
(
pickle
.
load
(
f
))
assert
x
.
strategy
==
load_x
.
strategy
assert
x
==
load_x
imperative/tablegen/targets/python_c_extension.cpp
浏览文件 @
9279104b
...
...
@@ -34,6 +34,7 @@ private:
void
emit_class
();
void
emit_py_init
();
void
emit_py_getsetters
();
void
emit_py_methods
();
Initproc
emit_initproc
();
MgbOp
&
op
;
...
...
@@ -133,9 +134,16 @@ void $0(PyTypeObject& py_type) {
if
(
firstOccur
)
{
os
<<
tgfmt
(
R"(
static PyMethodDef tp_methods[] = {
{const_cast<char*>("dump"), (PyCFunction)$enumTpl<$opClass::$enumClass>::py_dump, METH_NOARGS, NULL},
{NULL} /* Sentinel */
};
)"
,
&
ctx
);
os
<<
tgfmt
(
R"(
static PyType_Slot slots[] = {
{Py_tp_repr, (void*)$enumTpl<$opClass::$enumClass>::py_repr},
{Py_tp_richcompare, (void*)$enumTpl<$opClass::$enumClass>::tp_richcompare},
{Py_tp_methods, tp_methods},
)"
,
&
ctx
);
if
(
attr
->
getEnumCombinedFlag
())
{
// only bit combined enum could new instance because bitwise operation,
...
...
@@ -212,17 +220,62 @@ Initproc OpDefEmitter::emit() {
emit_class
();
emit_py_init
();
emit_py_getsetters
();
emit_py_methods
();
return
emit_initproc
();
}
void
OpDefEmitter
::
emit_class
()
{
auto
&&
className
=
op
.
getCppClassName
();
std
::
string
method_defs
;
std
::
vector
<
std
::
string
>
body
;
llvm
::
for_each
(
op
.
getMgbAttributes
(),
[
&
](
auto
&&
attr
)
{
body
.
push_back
(
formatv
(
R"(
{{"{0}", serialization<decltype(opdef.{0})>::dump(opdef.{0})})"
,
attr
.
name
));
});
method_defs
+=
formatv
(
R"(
static PyObject* getstate(PyObject* self, PyObject*) {{
auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {{
{1}
};
return py::cast(state).release().ptr();
})"
,
className
,
llvm
::
join
(
body
,
","
));
body
.
clear
();
llvm
::
for_each
(
op
.
getMgbAttributes
(),
[
&
](
auto
&&
attr
)
{
body
.
push_back
(
formatv
(
R"(
{{
auto&& iter = state.find("{0}");
if (iter != state.end()) {
opdef.{0} = serialization<decltype(opdef.{0})>::load(iter->second);
}
})"
,
attr
.
name
));
});
method_defs
+=
formatv
(
R"(
static PyObject* setstate(PyObject* self, PyObject* args) {{
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp({0})*>(self)->inst();
static_cast<void>(opdef);
{1}
Py_RETURN_NONE;
})"
,
className
,
llvm
::
join
(
body
,
"
\n
"
));
os
<<
tgfmt
(
R"(
PyOpDefBegin($_self) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
$0
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd($_self)
)"
,
&
ctx
);
)"
,
&
ctx
,
method_defs
);
}
void
OpDefEmitter
::
emit_py_init
()
{
...
...
@@ -302,6 +355,33 @@ PyGetSetDef PyOp($_self)::py_getsetters[] = {
)"
,
&
ctx
,
llvm
::
join
(
llvm
::
map_range
(
op
.
getMgbAttributes
(),
f
),
"
\n
"
));
}
void
OpDefEmitter
::
emit_py_methods
(){
// generate methods
std
::
string
method_defs
;
std
::
vector
<
std
::
string
>
method_items
;
{
auto
&&
className
=
op
.
getCppClassName
();
// generate getstate
method_items
.
push_back
(
formatv
(
"{{const_cast<char*>(
\"
__getstate__
\"
), PyOp({0})::getstate, METH_NOARGS,
\"
{0} getstate
\"
},"
,
className
));
// generate setstate
method_items
.
push_back
(
formatv
(
"{{const_cast<char*>(
\"
__setstate__
\"
), PyOp({0})::setstate, METH_VARARGS,
\"
{0} setstate
\"
},"
,
className
));
}
os
<<
tgfmt
(
R"(
PyMethodDef PyOp($_self)::tp_methods[] = {
$0
{NULL} /* Sentinel */
};
)"
,
&
ctx
,
llvm
::
join
(
method_items
,
"
\n
"
));
}
Initproc
OpDefEmitter
::
emit_initproc
()
{
std
::
string
initproc
=
formatv
(
"_init_py_{0}"
,
op
.
getCppClassName
());
std
::
string
subclass_init_call
;
...
...
@@ -321,6 +401,7 @@ void $0(py::module m) {
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
$1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录