Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a8108522
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a8108522
编写于
4月 14, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(imperative): refactor enum param type caster
GitOrigin-RevId: 1aae07f143b8d1c0176a41de790fee8d6b2f1a25
上级
dcff115e
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
281 addition
and
163 deletion
+281
-163
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+3
-12
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+121
-116
imperative/python/src/ops.h
imperative/python/src/ops.h
+21
-0
imperative/tablegen/CMakeLists.txt
imperative/tablegen/CMakeLists.txt
+2
-1
imperative/tablegen/autogen.cpp
imperative/tablegen/autogen.cpp
+9
-2
imperative/tablegen/targets/macros.cpp
imperative/tablegen/targets/macros.cpp
+56
-0
imperative/tablegen/targets/macros.h
imperative/tablegen/targets/macros.h
+19
-0
imperative/tablegen/targets/python_c_extension.cpp
imperative/tablegen/targets/python_c_extension.cpp
+50
-32
未找到文件。
imperative/python/src/graph_rt.cpp
浏览文件 @
a8108522
...
...
@@ -21,6 +21,7 @@
#include "./helper.h"
#include "megbrain/plugin/profiler.h"
#include "./common.h"
#include "./ops.h"
#include "megbrain/gopt/inference.h"
...
...
@@ -265,18 +266,8 @@ void init_graph_rt(py::module m) {
});
m
.
def
(
"modify_opr_algo_strategy_inplace"
,
[](
const
VarNodeArray
&
dest_vars
,
const
std
::
string
&
strategy
)
{
_AlgoStrategy
stg
;
const
std
::
unordered_map
<
std
::
string
,
std
::
function
<
void
()
>>
m
{
{
"HEURISTIC"
,
[
&
]()
{
stg
=
_AlgoStrategy
::
HEURISTIC
;
}},
{
"PROFILE"
,
[
&
]()
{
stg
=
_AlgoStrategy
::
PROFILE
;
}},
{
"REPRODUCIBLE"
,
[
&
]()
{
stg
=
_AlgoStrategy
::
REPRODUCIBLE
;
}},
{
"OPTIMIZED"
,
[
&
]()
{
stg
=
_AlgoStrategy
::
OPTIMIZED
;
}},
};
auto
it
=
m
.
find
(
strategy
);
mgb_assert
(
it
!=
m
.
end
(),
"Invalid strategy string!"
);
it
->
second
();
mgb
::
gopt
::
modify_opr_algo_strategy_inplace
(
dest_vars
,
stg
);
const
_AlgoStrategy
&
strategy
)
{
mgb
::
gopt
::
modify_opr_algo_strategy_inplace
(
dest_vars
,
strategy
);
});
m
.
def
(
"get_info_for_strip"
,
[](
const
std
::
vector
<
VarNode
*>&
dest_vars
)
{
...
...
imperative/python/src/ops.cpp
浏览文件 @
a8108522
...
...
@@ -73,29 +73,6 @@ PyTypeObject PyOpType(name);
} \
} while (0)
template
<
typename
T
,
typename
SFINAE
=
void
>
struct
pyobj_convert_generic
{
static
T
from
(
PyObject
*
obj
)
{
// TODO: remove this guard which is used for pybind11 implicit conversion
py
::
detail
::
loader_life_support
guard
{};
return
py
::
cast
<
T
>
(
py
::
handle
(
obj
));
}
template
<
typename
U
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
std
::
decay_t
<
U
>
>>>
static
PyObject
*
to
(
U
&&
t
)
{
return
py
::
cast
(
std
::
forward
<
U
>
(
t
)).
release
().
ptr
();
}
};
template
<
typename
T
,
typename
SFINAE
=
void
>
struct
EnumTrait
;
template
<
typename
T
>
struct
EnumTrait
<
T
,
std
::
enable_if_t
<
std
::
is_enum_v
<
T
>>>
{
static
constexpr
bool
is_bit_combined
=
false
;
static
constexpr
std
::
underlying_type_t
<
T
>
max
=
0
;
};
template
<
typename
T
>
PyObject
*
py_new_generic
(
PyTypeObject
*
type
,
PyObject
*
,
PyObject
*
)
{
PyObject
*
obj
=
type
->
tp_alloc
(
type
,
0
);
...
...
@@ -115,7 +92,7 @@ void py_dealloc_generic(PyObject* obj) {
template
<
typename
T
,
typename
U
,
U
T
::
Ty
::*
attr
>
PyObject
*
py_get_generic_impl
(
PyObject
*
obj
,
void
*
/* closure */
)
{
auto
&
op
=
reinterpret_cast
<
T
*>
(
obj
)
->
inst
();
return
py
obj_convert_generic
<
U
>::
to
(
op
.
*
attr
);
return
py
::
cast
(
op
.
*
attr
).
release
().
ptr
(
);
}
#define py_get_generic(name, attr) \
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>
...
...
@@ -128,7 +105,9 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
}
auto
&
op
=
reinterpret_cast
<
T
*>
(
obj
)
->
inst
();
try
{
op
.
*
attr
=
pyobj_convert_generic
<
U
>::
from
(
value
);
// TODO: remove this guard which is used for pybind11 implicit conversion
py
::
detail
::
loader_life_support
guard
{};
op
.
*
attr
=
py
::
cast
<
U
>
(
py
::
handle
(
value
));
}
CATCH_ALL
(
-
1
)
return
0
;
}
...
...
@@ -148,8 +127,8 @@ PyTypeObject PyOpType(OpDef);
std
::
unordered_map
<
mgb
::
Typeinfo
*
,
PyTypeObject
*>
PyOp
(
OpDef
)
::
ctype2pytype
;
PyObject
*
py_get_scope
(
PyObject
*
obj
,
void
*
/* closure */
)
{
return
py
obj_convert_generic
<
std
::
string
>::
to
(
reinterpret_cast
<
PyOp
(
OpDef
)
*>
(
obj
)
->
op
->
scope
()
);
return
py
::
cast
(
reinterpret_cast
<
PyOp
(
OpDef
)
*>
(
obj
)
->
op
->
scope
()).
release
().
ptr
(
);
}
int
py_set_scope
(
PyObject
*
obj
,
PyObject
*
value
,
void
*
/* closure */
)
{
...
...
@@ -159,7 +138,7 @@ int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
}
try
{
reinterpret_cast
<
PyOp
(
OpDef
)
*>
(
obj
)
->
op
->
set_scope
(
py
obj_convert_generic
<
std
::
string
>::
from
(
value
));
->
set_scope
(
py
::
cast
<
std
::
string
>
(
py
::
handle
(
value
)
));
}
CATCH_ALL
(
-
1
)
return
0
;
}
...
...
@@ -183,24 +162,29 @@ PyObject* PyOp(OpDef)::tp_richcompare(PyObject *self, PyObject *other, int op) {
Py_RETURN_NOTIMPLEMENTED
;
}
template
<
typename
T
>
struct
EnumTrait
;
#define PyEnumHead \
static_assert(std::is_enum_v<T>); \
PyObject_HEAD \
T value; \
constexpr static const char *name = EnumTrait<T>::name; \
static PyTypeObject type; \
static const char* members[]; \
static std::unordered_map<std::string, T> mem2value; \
static PyObject* pyobj_insts[];
template
<
typename
T
>
struct
EnumWrapper
{
static_assert
(
std
::
is_enum_v
<
T
>
);
PyObject_HEAD
T
value
;
static
const
char
*
name
;
static
PyTypeObject
type
;
static
std
::
unordered_map
<
T
,
std
::
string
>
type2str
;
static
std
::
unordered_map
<
std
::
string
,
T
>
str2type
;
EnumWrapper
()
=
default
;
EnumWrapper
(
T
v
)
:
value
(
v
)
{}
EnumWrapper
(
std
::
string
&&
str
)
:
EnumWrapper
(
str2type
.
at
(
normalize_enum
(
str
)))
{}
PyEnumHead
std
::
string
to_string
()
const
{
return
type2str
.
at
(
value
)
;
return
members
[
static_cast
<
size_t
>
(
value
)]
;
}
static
PyObject
*
py_repr
(
PyObject
*
self
)
{
return
pyobj_convert_generic
<
std
::
string
>::
to
(
std
::
string
(
name
)
+
"."
+
reinterpret_cast
<
EnumWrapper
*>
(
self
)
->
to_string
());
return
py
::
cast
(
std
::
string
(
name
)
+
"."
+
reinterpret_cast
<
EnumWrapper
*>
(
self
)
->
to_string
())
.
release
().
ptr
();
}
static
PyObject
*
tp_richcompare
(
PyObject
*
self
,
PyObject
*
other
,
int
op
)
{
T
lhs
=
reinterpret_cast
<
EnumWrapper
*>
(
self
)
->
value
,
...
...
@@ -210,59 +194,52 @@ struct EnumWrapper {
}
Py_RETURN_NOTIMPLEMENTED
;
}
};
template
<
typename
T
>
struct
pyobj_convert_generic
<
T
,
std
::
enable_if_t
<
std
::
is_enum_v
<
std
::
decay_t
<
T
>>
&&
!
EnumTrait
<
T
>::
is_bit_combined
>>
{
using
Wrapper
=
EnumWrapper
<
T
>
;
static
T
from
(
PyObject
*
obj
)
{
if
(
PyObject_TypeCheck
(
obj
,
&
Wrapper
::
type
))
{
return
reinterpret_cast
<
Wrapper
*>
(
obj
)
->
value
;
static
bool
load
(
py
::
handle
src
,
T
&
value
)
{
PyObject
*
obj
=
src
.
ptr
();
if
(
PyObject_TypeCheck
(
obj
,
&
type
))
{
value
=
reinterpret_cast
<
EnumWrapper
*>
(
obj
)
->
value
;
return
true
;
}
if
(
py
::
isinstance
<
py
::
str
>
(
src
))
{
auto
&&
iter
=
mem2value
.
find
(
normalize_enum
(
py
::
cast
<
std
::
string
>
(
src
)));
if
(
iter
!=
mem2value
.
end
())
{
value
=
iter
->
second
;
return
true
;
}
else
{
return
false
;
}
}
// try as string
// TODO: type checkcd
return
Wrapper
(
pyobj_convert_generic
<
std
::
string
>::
from
(
obj
)).
value
;
return
false
;
}
static
PyObject
*
to
(
T
t
)
{
PyTypeObject
*
pytype
=
&
Wrapper
::
type
;
PyObject
*
obj
=
pytype
->
tp_alloc
(
pytype
,
0
);
reinterpret_cast
<
Wrapper
*>
(
obj
)
->
value
=
t
;
static
PyObject
*
cast
(
const
T
&
value
)
{
auto
v
=
static_cast
<
std
::
underlying_type_t
<
T
>>
(
value
);
mgb_assert
(
v
<=
EnumTrait
<
T
>::
max
);
PyObject
*
obj
=
pyobj_insts
[
v
];
Py_INCREF
(
obj
);
return
obj
;
}
};
template
<
typename
T
>
struct
BitCombinedEnumWrapper
{
static_assert
(
std
::
is_enum_v
<
T
>
);
PyObject_HEAD
T
value
;
static
const
char
*
name
;
static
PyTypeObject
type
;
static
std
::
unordered_map
<
T
,
std
::
string
>
type2str
;
static
std
::
unordered_map
<
std
::
string
,
T
>
str2type
;
PyEnumHead
static
PyNumberMethods
number_methods
;
BitCombinedEnumWrapper
()
=
default
;
BitCombinedEnumWrapper
(
T
v
)
:
value
(
v
)
{}
BitCombinedEnumWrapper
(
std
::
string
&&
str
)
:
BitCombinedEnumWrapper
(
str2type
.
at
(
normalize_enum
(
str
)))
{}
std
::
string
to_string
()
const
{
if
(
static_cast
<
uint32_t
>
(
value
)
==
0
)
{
uint32_t
value_int
=
static_cast
<
uint32_t
>
(
value
);
if
(
value_int
==
0
)
{
return
"None"
;
}
else
{
auto
ret
=
std
::
string
()
;
std
::
string
ret
;
bool
first
=
true
;
for
(
uint32_t
i
=
0
;
i
<
32
;
i
++
)
{
uint32_t
value_int
=
static_cast
<
uint32_t
>
(
value
);
auto
it
=
type2str
.
find
(
static_cast
<
T
>
((
1
<<
i
)
&
value_int
));
if
(
it
!=
type2str
.
end
())
{
if
(
value_int
>>
i
&
1
)
{
if
(
!
first
)
{
ret
+=
" + "
;
}
else
{
first
=
false
;
}
ret
+=
(
std
::
string
(
name
)
+
"."
+
it
->
second
);
ret
+=
(
std
::
string
(
name
)
+
"."
+
members
[
i
]
);
}
}
return
ret
;
...
...
@@ -280,17 +257,20 @@ struct BitCombinedEnumWrapper {
return
nullptr
;
}
T
value
;
try
{
value
=
pyobj_convert_generic
<
T
>::
from
(
input
);
}
CATCH_ALL
(
nullptr
);
PyObject
*
obj
=
type
->
tp_alloc
(
type
,
0
);
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
value
;
return
obj
;
if
(
load
(
input
,
value
))
{
return
cast
(
value
);
}
else
{
PyErr_SetString
(
PyExc_RuntimeError
,
mgb
::
ssprintf
(
"Cannot convert type %s to type %s
\n
"
,
input
->
ob_type
->
tp_name
,
name
).
c_str
());
return
nullptr
;
}
}
}
static
PyObject
*
py_repr
(
PyObject
*
self
)
{
return
pyobj_convert_generic
<
std
::
string
>::
to
(
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
to_string
());
return
py
::
cast
(
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
to_string
())
.
release
().
ptr
();
}
static
PyObject
*
py_or
(
PyObject
*
self
,
PyObject
*
other
)
{
if
(
!
(
self
->
ob_type
==
other
->
ob_type
)){
...
...
@@ -298,12 +278,9 @@ struct BitCombinedEnumWrapper {
PyExc_RuntimeError
,
"Operand in or operator must be the same type."
);
}
PyObject
*
obj
=
type
.
tp_alloc
(
&
type
,
0
);
T
lhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
value
,
rhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
other
)
->
value
;
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
static_cast
<
T
>
(
static_cast
<
uint32_t
>
(
lhs
)
|
static_cast
<
uint32_t
>
(
rhs
));
return
obj
;
return
cast
(
lhs
|
rhs
);
}
static
PyObject
*
py_and
(
PyObject
*
self
,
PyObject
*
other
)
{
if
(
!
(
self
->
ob_type
==
other
->
ob_type
))
{
...
...
@@ -311,12 +288,9 @@ struct BitCombinedEnumWrapper {
PyExc_RuntimeError
,
"Operand in and operator must be the same type."
);
}
PyObject
*
obj
=
type
.
tp_alloc
(
&
type
,
0
);
T
lhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
value
,
rhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
other
)
->
value
;
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
static_cast
<
T
>
(
static_cast
<
uint32_t
>
(
lhs
)
&
static_cast
<
uint32_t
>
(
rhs
));
return
obj
;
return
cast
(
lhs
&
rhs
);
}
static
PyObject
*
tp_richcompare
(
PyObject
*
self
,
PyObject
*
other
,
int
op
)
{
T
lhs
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
value
,
...
...
@@ -326,32 +300,45 @@ struct BitCombinedEnumWrapper {
}
Py_RETURN_NOTIMPLEMENTED
;
}
};
template
<
typename
T
>
struct
pyobj_convert_generic
<
T
,
std
::
enable_if_t
<
std
::
is_enum_v
<
std
::
decay_t
<
T
>>
&&
EnumTrait
<
T
>::
is_bit_combined
>>
{
using
Wrapper
=
BitCombinedEnumWrapper
<
T
>
;
static
T
from
(
PyObject
*
obj
)
{
if
(
PyObject_TypeCheck
(
obj
,
&
Wrapper
::
type
))
{
return
reinterpret_cast
<
Wrapper
*>
(
obj
)
->
value
;
}
else
if
(
PyLong_Check
(
obj
))
{
auto
value
=
pyobj_convert_generic
<
std
::
underlying_type_t
<
T
>>::
from
(
obj
);
mgb_throw_if
(
value
>
EnumTrait
<
T
>::
max
,
mgb
::
MegBrainError
,
"out of range, cannot convert %zu to %s"
,
static_cast
<
uint32_t
>
(
value
),
Wrapper
::
name
);
return
static_cast
<
T
>
(
value
);
static
bool
load
(
py
::
handle
src
,
T
&
value
)
{
PyObject
*
obj
=
src
.
ptr
();
if
(
PyObject_TypeCheck
(
obj
,
&
type
))
{
value
=
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
;
return
true
;
}
if
(
py
::
isinstance
<
py
::
str
>
(
src
))
{
auto
&&
iter
=
mem2value
.
find
(
normalize_enum
(
py
::
cast
<
std
::
string
>
(
src
)));
if
(
iter
!=
mem2value
.
end
())
{
value
=
iter
->
second
;
return
true
;
}
else
{
return
false
;
}
}
if
(
py
::
isinstance
<
py
::
int_
>
(
obj
))
{
auto
v
=
py
::
cast
<
std
::
underlying_type_t
<
T
>>
(
src
);
if
(
v
>
EnumTrait
<
T
>::
max
)
{
return
false
;
}
value
=
static_cast
<
T
>
(
v
);
return
true
;
}
// try as string
// TODO: type checkcd
return
Wrapper
(
pyobj_convert_generic
<
std
::
string
>::
from
(
obj
)).
value
;
return
false
;
}
static
PyObject
*
to
(
T
t
)
{
PyTypeObject
*
pytype
=
&
Wrapper
::
type
;
PyObject
*
obj
=
pytype
->
tp_alloc
(
pytype
,
0
);
reinterpret_cast
<
Wrapper
*>
(
obj
)
->
value
=
t
;
return
obj
;
static
PyObject
*
cast
(
const
T
&
value
)
{
auto
v
=
static_cast
<
std
::
underlying_type_t
<
T
>>
(
value
);
mgb_assert
(
v
<=
EnumTrait
<
T
>::
max
);
if
((
!
v
)
||
(
v
&
(
v
-
1
)))
{
PyTypeObject
*
pytype
=
&
type
;
PyObject
*
obj
=
pytype
->
tp_alloc
(
pytype
,
0
);
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
value
;
return
obj
;
}
else
{
PyObject
*
obj
=
pyobj_insts
[
__builtin_ctz
(
v
)];
Py_INCREF
(
obj
);
return
obj
;
}
}
};
...
...
@@ -443,7 +430,6 @@ void _init_py_op_base(py::module m) {
#include "opdef.cpy.inl"
#undef CATCH_ALL
}
// anonymous namespace
namespace
PYBIND11_NAMESPACE
{
...
...
@@ -478,6 +464,25 @@ handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
reinterpret_cast
<
PyOp
(
OpDef
)
*>
(
obj
)
->
op
=
const_cast
<
OpDef
&>
(
op
).
shared_from_this
();
return
py
::
handle
(
obj
);
}
#define ENUM_CASTER_IMPL(T) \
bool type_caster<T>::load(handle src, bool) { \
return EnumWrapper<T>::load(src, value); \
} \
handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
return EnumWrapper<T>::cast(value); \
}
FOR_EACH_ENUM_PARAM
(
ENUM_CASTER_IMPL
)
#define BIT_COMBINED_ENUM_CASTER_IMPL(T) \
bool type_caster<T>::load(handle src, bool) { \
return BitCombinedEnumWrapper<T>::load(src, value); \
} \
handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
return BitCombinedEnumWrapper<T>::cast(value); \
}
FOR_EACH_BIT_COMBINED_ENUM_PARAM
(
BIT_COMBINED_ENUM_CASTER_IMPL
)
}
// detail
}
// PYBIND11_NAMESPACE
...
...
imperative/python/src/ops.h
浏览文件 @
a8108522
...
...
@@ -12,5 +12,26 @@
#pragma once
#include "./helper.h"
#include "./enum_macro.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/param_defs.h"
namespace
PYBIND11_NAMESPACE
{
namespace
detail
{
#define ENUM_CASTER_DEF(name) \
template<> struct type_caster<name> { \
PYBIND11_TYPE_CASTER(name, _(#name)); \
public: \
bool load(handle src, bool); \
static handle cast(const name& v, return_value_policy, handle); \
};
FOR_EACH_ENUM_PARAM
(
ENUM_CASTER_DEF
)
FOR_EACH_BIT_COMBINED_ENUM_PARAM
(
ENUM_CASTER_DEF
)
}
// detail
}
// PYBIND11_NAMESPACE
void
init_ops
(
pybind11
::
module
m
);
imperative/tablegen/CMakeLists.txt
浏览文件 @
a8108522
...
...
@@ -12,5 +12,6 @@ tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header")
tablegen
(
MGB opdef.cpp.inl
${
MGE_IR_INCLUDE_DIRS
}
"--gen-cpp-body"
)
tablegen
(
MGB opdef.py.inl
${
MGE_IR_INCLUDE_DIRS
}
"--gen-python-binding"
)
tablegen
(
MGB opdef.cpy.inl
${
MGE_IR_INCLUDE_DIRS
}
"--gen-python-c-extension"
)
add_custom_target
(
mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl param_defs_tblgen
)
tablegen
(
MGB enum_macro.h
${
MGE_IR_INCLUDE_DIRS
}
"--gen-enum-list-macro"
)
add_custom_target
(
mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl opdef.cpy.inl enum_macro.h param_defs_tblgen
)
set
(
MGB_OPDEF_OUT_DIR
${
CMAKE_CURRENT_BINARY_DIR
}
PARENT_SCOPE
)
imperative/tablegen/autogen.cpp
浏览文件 @
a8108522
...
...
@@ -12,6 +12,7 @@
#include "./targets/cpp_class.h"
#include "./targets/pybind11.h"
#include "./targets/python_c_extension.h"
#include "./targets/macros.h"
using
llvm
::
raw_ostream
;
using
llvm
::
RecordKeeper
;
...
...
@@ -21,7 +22,8 @@ enum ActionType {
CppHeader
,
CppBody
,
Pybind
,
CPython
CPython
,
EnumListMacro
};
// NOLINTNEXTLINE
...
...
@@ -34,7 +36,9 @@ llvm::cl::opt<ActionType> action(
clEnumValN
(
Pybind
,
"gen-python-binding"
,
"Generate pybind11 python bindings"
),
clEnumValN
(
CPython
,
"gen-python-c-extension"
,
"Generate python c extensions"
)));
"Generate python c extensions"
),
clEnumValN
(
EnumListMacro
,
"gen-enum-list-macro"
,
"Generate enum param list macro"
)));
using
namespace
mlir
::
tblgen
;
...
...
@@ -53,5 +57,8 @@ int main(int argc, char **argv) {
if
(
action
==
ActionType
::
CPython
)
{
return
TableGenMain
(
argv
[
0
],
&
gen_op_def_python_c_extension
);
}
if
(
action
==
ActionType
::
EnumListMacro
)
{
return
TableGenMain
(
argv
[
0
],
&
gen_enum_param_list_macro
);
}
return
-
1
;
}
imperative/tablegen/targets/macros.cpp
0 → 100644
浏览文件 @
a8108522
/**
* \file imperative/tablegen/targets/macros.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./cpp_class.h"
#include "../emitter.h"
namespace
mlir
::
tblgen
{
bool
gen_enum_param_list_macro
(
raw_ostream
&
os
,
llvm
::
RecordKeeper
&
keeper
)
{
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
enums
;
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>>
bit_enums
;
Environment
env
;
foreach_operator
(
keeper
,
[
&
](
MgbOp
&
op
)
{
for
(
auto
&&
i
:
op
.
getAttributes
())
{
if
(
auto
attr
=
llvm
::
dyn_cast
<
MgbEnumAttr
>
(
&
i
.
attr
))
{
auto
insert
=
[
&
](
const
MgbEnumAttr
&
attr
)
{
auto
&&
item
=
std
::
make_pair
(
attr
.
getParentNamespace
(),
attr
.
getEnumName
());
if
(
env
.
enumAlias
.
emplace
(
attr
.
getBaseRecord
()
->
getID
(),
std
::
move
(
item
)).
second
)
{
if
(
attr
.
getEnumCombinedFlag
())
{
bit_enums
.
emplace_back
(
item
);
}
else
{
enums
.
emplace_back
(
item
);
}
}
};
if
(
auto
alias
=
llvm
::
dyn_cast
<
MgbAliasAttr
>
(
attr
))
{
auto
&&
aliasBase
=
alias
->
getAliasBase
();
insert
(
llvm
::
cast
<
MgbEnumAttr
>
(
aliasBase
));
}
else
{
insert
(
*
attr
);
}
}
}
});
os
<<
"#define FOR_EACH_ENUM_PARAM(cb)"
;
for
(
auto
&&
i
:
enums
)
{
os
<<
formatv
(
"
\\\n
cb({0}::{1});"
,
i
.
first
,
i
.
second
);
}
os
<<
"
\n
"
;
os
<<
"#define FOR_EACH_BIT_COMBINED_ENUM_PARAM(cb)"
;
for
(
auto
&&
i
:
bit_enums
)
{
os
<<
formatv
(
"
\\\n
cb({0}::{1});"
,
i
.
first
,
i
.
second
);
}
os
<<
"
\n
"
;
return
false
;
}
}
// namespace mlir::tblgen
imperative/tablegen/targets/macros.h
0 → 100644
浏览文件 @
a8108522
/**
* \file imperative/tablegen/targets/macros.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "../helper.h"
namespace
mlir
::
tblgen
{
bool
gen_enum_param_list_macro
(
raw_ostream
&
os
,
llvm
::
RecordKeeper
&
keeper
);
}
// namespace mlir::tblgen
imperative/tablegen/targets/python_c_extension.cpp
浏览文件 @
a8108522
...
...
@@ -60,6 +60,7 @@ public:
Initproc
emit
();
protected:
void
emit_trait
();
void
emit_tpl_spl
();
Initproc
emit_initproc
();
...
...
@@ -69,50 +70,63 @@ protected:
};
Initproc
EnumAttrEmitter
::
emit
()
{
emit_trait
();
emit_tpl_spl
();
return
emit_initproc
();
}
void
EnumAttrEmitter
::
emit_trait
()
{
if
(
!
firstOccur
)
return
;
auto
enumMax
=
[
&
]
{
if
(
attr
->
getEnumCombinedFlag
())
{
return
formatv
(
"(1llu << {0}) - 1"
,
attr
->
getEnumMembers
().
size
());
}
else
{
return
formatv
(
"{0} - 1"
,
attr
->
getEnumMembers
().
size
());
}
};
os
<<
tgfmt
(
R"(
template<> struct EnumTrait<$opClass::$enumClass> {
static constexpr const char *name = "$opClass.$enumClass";
static constexpr std::underlying_type_t<$opClass::$enumClass> max = $0;
};
)"
,
&
ctx
,
enumMax
());
}
void
EnumAttrEmitter
::
emit_tpl_spl
()
{
if
(
!
firstOccur
)
return
;
os
<<
tgfmt
(
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type
=
{};
\n
"
,
"template<> PyTypeObject $enumTpl<$opClass::$enumClass>::type
=
{};
\n
"
,
&
ctx
);
auto
quote
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
formatv
(
"
\"
{0}
\"
"
,
i
);
};
os
<<
tgfmt
(
R"(
template<> const char*
$enumTpl<$opClass::$enumClass>::members[] = {$0};
)"
,
&
ctx
,
llvm
::
join
(
llvm
::
map_range
(
attr
->
getEnumMembers
(),
quote
),
", "
));
auto
mem2value
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
tgfmt
(
"{normalize_enum(
\"
$0
\"
), $opClass::$enumClass::$0}"
,
&
ctx
,
i
);
};
os
<<
tgfmt
(
R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
$enumTpl<$opClass::$enumClass>::mem2value = {$0};
)"
,
&
ctx
,
llvm
::
join
(
llvm
::
map_range
(
attr
->
getEnumMembers
(),
mem2value
),
", "
));
os
<<
tgfmt
(
"template<>
const char* $enumTpl<$opClass::$enumClass>::name =
"
"
\"
$opClass.$enumClass
\"
;
\n
"
,
&
ctx
);
"template<>
PyObject*
"
"
$enumTpl<$opClass::$enumClass>::pyobj_insts[$0] = {nullptr};
\n
"
,
&
ctx
,
attr
->
getEnumMembers
().
size
()
);
if
(
attr
->
getEnumCombinedFlag
())
{
os
<<
tgfmt
(
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods
=
{};
\n
"
,
"$enumTpl<$opClass::$enumClass>::number_methods
=
{};
\n
"
,
&
ctx
);
os
<<
tgfmt
(
R"(
template<> struct EnumTrait<$opClass::$enumClass> {
static constexpr bool is_bit_combined = true;
static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1;
};
)"
,
&
ctx
,
attr
->
getEnumMembers
().
size
());
}
auto
str2type
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
tgfmt
(
"{normalize_enum(
\"
$0
\"
), $opClass::$enumClass::$0}"
,
&
ctx
,
i
);
};
os
<<
tgfmt
(
R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
$enumTpl<$opClass::$enumClass>::str2type = {$0};
)"
,
&
ctx
,
llvm
::
join
(
llvm
::
map_range
(
attr
->
getEnumMembers
(),
str2type
),
", "
));
auto
type2str
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
tgfmt
(
"{$opClass::$enumClass::$0, normalize_enum(
\"
$0
\"
)}"
,
&
ctx
,
i
);
};
os
<<
tgfmt
(
R"(
template<> std::unordered_map<$opClass::$enumClass, std::string>
$enumTpl<$opClass::$enumClass>::type2str = {$0};
)"
,
&
ctx
,
llvm
::
join
(
llvm
::
map_range
(
attr
->
getEnumMembers
(),
type2str
),
", "
));
}
Initproc
EnumAttrEmitter
::
emit_initproc
()
{
...
...
@@ -150,14 +164,16 @@ void $0(PyTypeObject& py_type) {
os
<<
" mgb_assert(PyType_Ready(&e_type) >= 0);
\n
"
;
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
auto
&&
members
=
attr
->
getEnumMembers
();
for
(
size_t
idx
=
0
;
idx
<
members
.
size
();
++
idx
)
{
os
<<
tgfmt
(
R"({
PyObject* inst = e_type.tp_alloc(&e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type.tp_dict, "$0", inst) >= 0);
PyType_Modified(&e_type)
;
})"
,
&
ctx
,
i
);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst
;
})"
,
&
ctx
,
members
[
idx
],
idx
);
}
os
<<
" PyType_Modified(&e_type);
\n
"
;
}
os
<<
tgfmt
(
R"(
...
...
@@ -225,8 +241,10 @@ void OpDefEmitter::emit_py_init() {
initBody
+=
tgfmt
(
R"(
if ($0) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp($_self)*>(self)->inst().$0 =
pyobj_convert_generic<decltype($_self::$0)>::from($0
);
py::cast<decltype($_self::$0)>(py::handle($0)
);
} CATCH_ALL(-1)
}
)"
,
&
ctx
,
attr
.
name
);
...
...
@@ -236,7 +254,7 @@ void OpDefEmitter::emit_py_init() {
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py
obj_convert_generic<std::string>::from(scope
));
->set_scope(py
::cast<std::string>(py::handle(scope)
));
} CATCH_ALL(-1)
}
)"
,
&
ctx
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录