Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fb49a283
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fb49a283
编写于
9月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mgb/dnn): refactor enum used in serializing
GitOrigin-RevId: e57af4a59c9b4e090f3972b4d0cf01a2737f8355
上级
d69b5903
变更
10
展开全部
显示空白变更内容
内联
并排
Showing
10 changed file
with
410 addition
and
370 deletion
+410
-370
dnn/scripts/gen_flatbuffers_schema.py
dnn/scripts/gen_flatbuffers_schema.py
+10
-2
dnn/scripts/gen_param_defs.py
dnn/scripts/gen_param_defs.py
+35
-24
dnn/scripts/gen_tablegen.py
dnn/scripts/gen_tablegen.py
+5
-3
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+281
-281
imperative/tablegen/helper.h
imperative/tablegen/helper.h
+11
-8
imperative/tablegen/targets/cpp_class.cpp
imperative/tablegen/targets/cpp_class.cpp
+7
-3
imperative/tablegen/targets/pybind11.cpp
imperative/tablegen/targets/pybind11.cpp
+8
-7
imperative/tablegen/targets/python_c_extension.cpp
imperative/tablegen/targets/python_c_extension.cpp
+13
-3
tools/gen_header_for_bin_reduce.py
tools/gen_header_for_bin_reduce.py
+2
-1
tools/param_defs/mgb_opr_param_defs.py
tools/param_defs/mgb_opr_param_defs.py
+38
-38
未找到文件。
dnn/scripts/gen_flatbuffers_schema.py
浏览文件 @
fb49a283
...
@@ -23,8 +23,14 @@ def _cname_to_fbname(cname):
...
@@ -23,8 +23,14 @@ def _cname_to_fbname(cname):
}[
cname
]
}[
cname
]
def
scramble_enum_member_name
(
name
):
def
scramble_enum_member_name
(
name
):
s
=
name
.
find
(
'<<'
)
if
s
!=
-
1
:
name
=
name
[
0
:
name
.
find
(
'='
)
+
1
]
+
' '
+
name
[
s
+
2
:]
if
name
in
(
"MIN"
,
"MAX"
):
if
name
in
(
"MIN"
,
"MAX"
):
return
name
+
"_"
return
name
+
"_"
o_name
=
name
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
o_name
in
(
"MIN"
,
"MAX"
):
return
name
.
replace
(
o_name
,
o_name
+
"_"
)
return
name
return
name
class
FlatBuffersWriter
(
IndentWriterBase
):
class
FlatBuffersWriter
(
IndentWriterBase
):
...
@@ -97,7 +103,8 @@ class FlatBuffersWriter(IndentWriterBase):
...
@@ -97,7 +103,8 @@ class FlatBuffersWriter(IndentWriterBase):
if
e
.
combined
:
if
e
.
combined
:
default
=
e
.
compose_combined_enum
(
e
.
default
)
default
=
e
.
compose_combined_enum
(
e
.
default
)
else
:
else
:
default
=
scramble_enum_member_name
(
str
(
e
.
members
[
e
.
default
]))
default
=
scramble_enum_member_name
(
str
(
e
.
members
[
e
.
default
]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_write
(
"%s:%s%s = %s;"
,
e
.
name_field
,
p
.
name
,
e
.
name
,
default
)
self
.
_write
(
"%s:%s%s = %s;"
,
e
.
name_field
,
p
.
name
,
e
.
name
,
default
)
def
_resolve_const
(
self
,
v
):
def
_resolve_const
(
self
,
v
):
...
@@ -124,7 +131,8 @@ class FlatBuffersWriter(IndentWriterBase):
...
@@ -124,7 +131,8 @@ class FlatBuffersWriter(IndentWriterBase):
if
s
.
combined
:
if
s
.
combined
:
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
else
:
else
:
default
=
scramble_enum_member_name
(
str
(
s
.
members
[
e
.
get_default
()]))
default
=
scramble_enum_member_name
(
str
(
s
.
members
[
e
.
get_default
()]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_write
(
"%s:%s = %s;"
,
e
.
name_field
,
enum_name
,
default
)
self
.
_write
(
"%s:%s = %s;"
,
e
.
name_field
,
enum_name
,
default
)
def
_get_fb_default
(
self
,
cppdefault
):
def
_get_fb_default
(
self
,
cppdefault
):
...
...
dnn/scripts/gen_param_defs.py
浏览文件 @
fb49a283
...
@@ -121,10 +121,12 @@ class member_defs:
...
@@ -121,10 +121,12 @@ class member_defs:
def
normalize_enum_value
(
self
,
value
):
def
normalize_enum_value
(
self
,
value
):
def
normalize
(
v
):
def
normalize
(
v
):
if
isinstance
(
v
,
str
):
if
isinstance
(
v
,
str
):
if
v
not
in
self
.
members
:
for
idx
,
m
in
enumerate
(
self
.
members
):
m
=
str
(
m
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
v
==
m
:
return
idx
raise
ValueError
(
raise
ValueError
(
"enum member '{}' does not exist."
.
format
(
v
))
"enum member '{}' does not exist."
.
format
(
v
))
v
=
self
.
members
.
index
(
v
)
assert
isinstance
(
v
,
int
)
assert
isinstance
(
v
,
int
)
return
v
return
v
if
self
.
combined
:
if
self
.
combined
:
...
@@ -524,21 +526,25 @@ class SerializedDType(_ParamDefBase):
...
@@ -524,21 +526,25 @@ class SerializedDType(_ParamDefBase):
self
.
_write_doc
(
e
.
name
)
self
.
_write_doc
(
e
.
name
)
for
idx
,
emem
in
enumerate
(
e
.
members
)
:
for
emem
in
e
.
members
:
if
e
.
combined
:
if
e
.
combined
:
self
.
_write
(
'%s
= 1 << %d'
,
emem
,
idx
)
self
.
_write
(
'%s
'
,
emem
)
self
.
_write_doc
(
emem
)
self
.
_write_doc
(
emem
)
else
:
else
:
self
.
_write
(
'%s = "%s"'
,
emem
,
emem
)
v
=
str
(
emem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
n
=
int
(
str
(
emem
).
split
(
'='
)[
1
])
self
.
_write
(
'%s = "%s"'
,
v
,
v
)
self
.
_write_doc
(
emem
)
self
.
_write_doc
(
emem
)
self
.
_enum_member2num
.
append
(
'id({}.{}):{}'
.
format
(
self
.
_enum_member2num
.
append
(
'id({}.{}):{}'
.
format
(
qualname
,
emem
,
idx
))
qualname
,
v
,
n
))
for
emem
,
emem_alias
in
e
.
member_alias
:
for
emem
,
emem_alias
in
e
.
member_alias
:
em_a
=
emem_alias
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
e
.
combined
:
if
e
.
combined
:
self
.
_write
(
'%s = %s'
,
em
em_alias
,
e
.
compose_combined_enum
(
emem
))
self
.
_write
(
'%s = %s'
,
em
_a
,
e
.
compose_combined_enum
(
emem
))
else
:
else
:
self
.
_write
(
'%s = %s'
,
emem_alias
,
emem
)
em
=
str
(
emem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
self
.
_write
(
'%s = %s'
,
em_a
,
em
)
self
.
_unindent
()
self
.
_unindent
()
self
.
_write
(
''
)
self
.
_write
(
''
)
...
@@ -546,7 +552,7 @@ class SerializedDType(_ParamDefBase):
...
@@ -546,7 +552,7 @@ class SerializedDType(_ParamDefBase):
if
e
.
combined
:
if
e
.
combined
:
default
=
e
.
compose_combined_enum
(
e
.
default
)
default
=
e
.
compose_combined_enum
(
e
.
default
)
else
:
else
:
default
=
"'{}'"
.
format
(
e
.
members
[
e
.
default
])
default
=
"'{}'"
.
format
(
str
(
e
.
members
[
e
.
default
]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
name
=
e
.
name_field
,
name
=
e
.
name_field
,
...
@@ -564,7 +570,7 @@ class SerializedDType(_ParamDefBase):
...
@@ -564,7 +570,7 @@ class SerializedDType(_ParamDefBase):
if
s
.
combined
:
if
s
.
combined
:
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
else
:
else
:
default
=
"'{}'"
.
format
(
s
.
members
[
e
.
get_default
()
])
default
=
"'{}'"
.
format
(
s
tr
(
s
.
members
[
e
.
get_default
()]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
name
=
e
.
name_field
,
name
=
e
.
name_field
,
cvt
=
'{}.convert({})'
.
format
(
qualname
,
e
.
name_field
),
cvt
=
'{}.convert({})'
.
format
(
qualname
,
e
.
name_field
),
...
@@ -700,11 +706,9 @@ class CPPWriter(IndentWriterBase):
...
@@ -700,11 +706,9 @@ class CPPWriter(IndentWriterBase):
def
_on_member_enum
(
self
,
e
):
def
_on_member_enum
(
self
,
e
):
self
.
_write_doc
(
e
.
name
)
self
.
_write_doc
(
e
.
name
)
self
.
_write
(
'enum class %s: uint32_t {'
,
e
.
name
,
indent
=
1
)
self
.
_write
(
'enum class %s: uint32_t {'
,
e
.
name
,
indent
=
1
)
for
i
dx
,
i
in
enumerate
(
e
.
members
)
:
for
i
in
e
.
members
:
self
.
_write_doc
(
i
)
self
.
_write_doc
(
i
)
v
=
'{} = {}'
.
format
(
i
,
idx
)
v
=
str
(
i
)
if
e
.
combined
:
v
=
'{} = 1 << {}'
.
format
(
i
,
idx
)
if
i
is
not
e
.
members
[
-
1
]
or
e
.
member_alias
:
if
i
is
not
e
.
members
[
-
1
]
or
e
.
member_alias
:
v
+=
','
v
+=
','
self
.
_write
(
v
)
self
.
_write
(
v
)
...
@@ -712,7 +716,7 @@ class CPPWriter(IndentWriterBase):
...
@@ -712,7 +716,7 @@ class CPPWriter(IndentWriterBase):
if
e
.
combined
:
if
e
.
combined
:
self
.
_write
(
'%s = %s,'
,
alias
,
e
.
compose_combined_enum
(
mem
))
self
.
_write
(
'%s = %s,'
,
alias
,
e
.
compose_combined_enum
(
mem
))
else
:
else
:
self
.
_write
(
'%s = %s,'
,
alias
,
mem
)
self
.
_write
(
'%s = %s,'
,
str
(
alias
).
split
(
' '
)[
0
].
split
(
'='
)[
0
],
str
(
mem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
)
self
.
_write
(
'};'
,
indent
=-
1
)
self
.
_write
(
'};'
,
indent
=-
1
)
self
.
_non_static_members
.
append
(
e
)
self
.
_non_static_members
.
append
(
e
)
self
.
_write
(
'static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;'
,
self
.
_write
(
'static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;'
,
...
@@ -720,7 +724,9 @@ class CPPWriter(IndentWriterBase):
...
@@ -720,7 +724,9 @@ class CPPWriter(IndentWriterBase):
if
e
.
combined
:
if
e
.
combined
:
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
else
:
else
:
default
=
'{}::{}'
.
format
(
e
.
name
,
e
.
members
[
e
.
default
])
value
=
str
(
e
.
members
[
e
.
default
])
value
=
value
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
default
=
'{}::{}'
.
format
(
e
.
name
,
value
)
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
def
_on_member_enum_alias
(
self
,
e
):
def
_on_member_enum_alias
(
self
,
e
):
...
@@ -732,7 +738,9 @@ class CPPWriter(IndentWriterBase):
...
@@ -732,7 +738,9 @@ class CPPWriter(IndentWriterBase):
if
s
.
combined
:
if
s
.
combined
:
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
s
.
compose_combined_enum
(
e
.
default
))
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
s
.
compose_combined_enum
(
e
.
default
))
else
:
else
:
default
=
'{}::{}'
.
format
(
e
.
name
,
s
.
members
[
e
.
get_default
()])
value
=
str
(
s
.
members
[
e
.
get_default
()])
value
=
value
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
default
=
'{}::{}'
.
format
(
e
.
name
,
value
)
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
def
_on_member_field
(
self
,
f
):
def
_on_member_field
(
self
,
f
):
...
@@ -754,11 +762,12 @@ class CPPEnumValueWriter(CPPWriter):
...
@@ -754,11 +762,12 @@ class CPPEnumValueWriter(CPPWriter):
def
_on_member_enum
(
self
,
e
):
def
_on_member_enum
(
self
,
e
):
self
.
_write_doc
(
e
.
name
)
self
.
_write_doc
(
e
.
name
)
self
.
_write
(
'struct %s {'
,
e
.
name
,
indent
=
1
)
self
.
_write
(
'struct %s {'
,
e
.
name
,
indent
=
1
)
for
idx
,
val
in
enumerate
(
e
.
members
)
:
for
val
in
e
.
members
:
self
.
_write_doc
(
val
)
self
.
_write_doc
(
val
)
self
.
_write
(
'static const uint32_t %s = %d;'
,
val
,
idx
)
v
=
str
(
val
)
self
.
_write
(
'static const uint32_t %s;'
,
v
)
for
mem
,
alias
in
e
.
member_alias
:
for
mem
,
alias
in
e
.
member_alias
:
self
.
_write
(
'static const uint32_t %s = %s;'
,
alias
,
mem
)
self
.
_write
(
'static const uint32_t %s = %s;'
,
str
(
alias
).
split
(
' '
)[
0
].
split
(
'='
)[
0
],
str
(
mem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
)
self
.
_write
(
'};'
,
indent
=-
1
)
self
.
_write
(
'};'
,
indent
=-
1
)
def
_on_member_enum_alias
(
self
,
e
):
def
_on_member_enum_alias
(
self
,
e
):
...
@@ -848,9 +857,11 @@ class CPPParamJsonFuncWriter(IndentWriterBase):
...
@@ -848,9 +857,11 @@ class CPPParamJsonFuncWriter(IndentWriterBase):
members
=
e
.
src_enum
.
members
members
=
e
.
src_enum
.
members
else
:
else
:
members
=
e
.
members
members
=
e
.
members
for
idx
,
i
in
enumerate
(
members
):
for
i
in
members
:
v
=
str
(
i
)
v
=
v
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
self
.
_write
(
'case %s::%s::%s: return "%s";'
,
self
.
_write
(
'case %s::%s::%s: return "%s";'
,
self
.
_param_name
,
e
.
name
,
i
,
i
,
indent
=
0
)
self
.
_param_name
,
e
.
name
,
v
,
v
,
indent
=
0
)
self
.
_write
(
'default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));'
,
self
.
_write
(
'default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));'
,
self
.
_param_name
,
e
.
name
,
indent
=
0
)
self
.
_param_name
,
e
.
name
,
indent
=
0
)
self
.
_write
(
'}'
,
indent
=-
1
)
self
.
_write
(
'}'
,
indent
=-
1
)
...
...
dnn/scripts/gen_tablegen.py
浏览文件 @
fb49a283
...
@@ -89,7 +89,7 @@ class ConverterWriter(IndentWriterBase):
...
@@ -89,7 +89,7 @@ class ConverterWriter(IndentWriterBase):
fullname
=
"::megdnn::param::{}"
.
format
(
p
.
name
)
fullname
=
"::megdnn::param::{}"
.
format
(
p
.
name
)
enum_def
=
"MgbEnumAttr<
\"
{}
\"
,
\"
{}
\"
, ["
.
format
(
fullname
,
e
.
name
)
enum_def
=
"MgbEnumAttr<
\"
{}
\"
,
\"
{}
\"
, ["
.
format
(
fullname
,
e
.
name
)
def
format
(
v
):
def
format
(
v
):
return
'
\"
{}
\"
'
.
format
(
str
(
v
))
return
'
\"
{}
\"
'
.
format
(
str
(
v
)
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
)
enum_def
+=
','
.
join
(
format
(
i
)
for
i
in
e
.
members
)
enum_def
+=
','
.
join
(
format
(
i
)
for
i
in
e
.
members
)
if
e
.
combined
:
if
e
.
combined
:
...
@@ -110,7 +110,8 @@ class ConverterWriter(IndentWriterBase):
...
@@ -110,7 +110,8 @@ class ConverterWriter(IndentWriterBase):
default_val
=
"static_cast<{}::{}>({})"
.
format
(
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
fullname
,
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
else
:
else
:
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
e
.
members
[
e
.
default
])
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
str
(
e
.
members
[
e
.
default
]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
...
@@ -134,7 +135,8 @@ class ConverterWriter(IndentWriterBase):
...
@@ -134,7 +135,8 @@ class ConverterWriter(IndentWriterBase):
default_val
=
"static_cast<{}::{}>({})"
.
format
(
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
s
.
compose_combined_enum
(
e
.
get_default
()))
fullname
,
e
.
name
,
s
.
compose_combined_enum
(
e
.
get_default
()))
else
:
else
:
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
s
.
members
[
e
.
get_default
()])
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
str
(
s
.
members
[
e
.
get_default
()]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
...
...
dnn/scripts/opr_param_defs.py
浏览文件 @
fb49a283
此差异已折叠。
点击以展开。
imperative/tablegen/helper.h
浏览文件 @
fb49a283
...
@@ -241,14 +241,17 @@ private:
...
@@ -241,14 +241,17 @@ private:
if
(
auto
*
enumAttr
=
llvm
::
dyn_cast
<
MgbEnumAttrMixin
>
(
&
it
.
attr
))
{
if
(
auto
*
enumAttr
=
llvm
::
dyn_cast
<
MgbEnumAttrMixin
>
(
&
it
.
attr
))
{
body
+=
formatv
(
" switch ({0}){{
\n
"
,
"$_self."
+
it
.
name
);
body
+=
formatv
(
" switch ({0}){{
\n
"
,
"$_self."
+
it
.
name
);
for
(
auto
&&
enumMember
:
enumAttr
->
getEnumMembers
())
{
for
(
auto
&&
enumMember
:
enumAttr
->
getEnumMembers
())
{
body
+=
formatv
(
size_t
d1
=
enumMember
.
find
(
' '
);
" case {0}::{1}::{2}:
\n
"
,
size_t
d2
=
enumMember
.
find
(
'='
);
getCppClassName
(),
enumAttr
->
getEnumName
(),
enumMember
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
);
body
+=
formatv
(
" case {0}::{1}::{2}:
\n
"
,
body
+=
formatv
(
getCppClassName
(),
" props_.emplace_back(
\"
{0}
\"
,
\"
{1}
\"
);
\n
"
,
enumAttr
->
getEnumName
(),
it
.
name
,
enumMember
enumMember
.
substr
(
0
,
d
));
);
body
+=
formatv
(
" props_.emplace_back(
\"
{0}
\"
, "
"
\"
{1}
\"
);
\n
"
,
it
.
name
,
enumMember
.
substr
(
0
,
d
));
body
+=
" break;
\n
"
;
body
+=
" break;
\n
"
;
}
}
body
+=
" default: break;
\n
"
;
body
+=
" default: break;
\n
"
;
...
...
imperative/tablegen/targets/cpp_class.cpp
浏览文件 @
fb49a283
...
@@ -177,9 +177,13 @@ void OpDefEmitter::emit_tpl_spl() {
...
@@ -177,9 +177,13 @@ void OpDefEmitter::emit_tpl_spl() {
std
::
vector
<
std
::
string
>
case_body
;
std
::
vector
<
std
::
string
>
case_body
;
std
::
string
ename
=
formatv
(
"{0}::{1}"
,
std
::
string
ename
=
formatv
(
"{0}::{1}"
,
op
.
getCppClassName
(),
attr
->
getEnumName
());
op
.
getCppClassName
(),
attr
->
getEnumName
());
llvm
::
for_each
(
attr
->
getEnumMembers
(),
[
&
](
auto
&&
v
){
llvm
::
for_each
(
attr
->
getEnumMembers
(),
[
&
](
auto
&&
v
)
{
case_body
.
push_back
(
formatv
(
size_t
d1
=
v
.
find
(
' '
);
"case {0}::{1}: return
\"
{1}
\"
;"
,
ename
,
v
));
size_t
d2
=
v
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
case_body
.
push_back
(
formatv
(
"case {0}::{1}: return
\"
{1}
\"
;"
,
ename
,
v
.
substr
(
0
,
d
)));
});
});
os
<<
formatv
(
R"(
os
<<
formatv
(
R"(
template <>
template <>
...
...
imperative/tablegen/targets/pybind11.cpp
浏览文件 @
fb49a283
...
@@ -50,14 +50,15 @@ void OpDefEmitter::emit() {
...
@@ -50,14 +50,15 @@ void OpDefEmitter::emit() {
);
);
std
::
vector
<
std
::
string
>
body
;
std
::
vector
<
std
::
string
>
body
;
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
os
<<
formatv
(
size_t
d1
=
i
.
find
(
' '
);
"
\n
.value(
\"
{2}
\"
, {0}::{1}::{2})"
,
size_t
d2
=
i
.
find
(
'='
);
className
,
attr
->
getEnumName
(),
i
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
);
os
<<
formatv
(
"
\n
.value(
\"
{2}
\"
, {0}::{1}::{2})"
,
className
,
attr
->
getEnumName
(),
i
.
substr
(
0
,
d
));
body
.
push_back
(
formatv
(
body
.
push_back
(
formatv
(
"if (str ==
\"
{2}
\"
) return {0}::{1}::{2};"
,
"if (str ==
\"
{2}
\"
) return {0}::{1}::{2};"
,
className
,
attr
->
getEnumName
(),
i
className
,
attr
->
getEnumName
(),
i
.
substr
(
0
,
d
)));
));
}
}
if
(
attr
->
getEnumCombinedFlag
())
{
if
(
attr
->
getEnumCombinedFlag
())
{
//! define operator |
//! define operator |
...
...
imperative/tablegen/targets/python_c_extension.cpp
浏览文件 @
fb49a283
...
@@ -102,7 +102,10 @@ void EnumAttrEmitter::emit_tpl_spl() {
...
@@ -102,7 +102,10 @@ void EnumAttrEmitter::emit_tpl_spl() {
&
ctx
);
&
ctx
);
auto
quote
=
[
&
](
auto
&&
i
)
->
std
::
string
{
auto
quote
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
formatv
(
"
\"
{0}
\"
"
,
i
);
size_t
d1
=
i
.
find
(
' '
);
size_t
d2
=
i
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
return
formatv
(
"
\"
{0}
\"
"
,
i
.
substr
(
0
,
d
));
};
};
os
<<
tgfmt
(
R"(
os
<<
tgfmt
(
R"(
template<> const char*
template<> const char*
...
@@ -110,7 +113,11 @@ $enumTpl<$opClass::$enumClass>::members[] = {$0};
...
@@ -110,7 +113,11 @@ $enumTpl<$opClass::$enumClass>::members[] = {$0};
)"
,
&
ctx
,
llvm
::
join
(
llvm
::
map_range
(
attr
->
getEnumMembers
(),
quote
),
", "
));
)"
,
&
ctx
,
llvm
::
join
(
llvm
::
map_range
(
attr
->
getEnumMembers
(),
quote
),
", "
));
auto
mem2value
=
[
&
](
auto
&&
i
)
->
std
::
string
{
auto
mem2value
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
tgfmt
(
"{normalize_enum(
\"
$0
\"
), $opClass::$enumClass::$0}"
,
&
ctx
,
i
);
size_t
d1
=
i
.
find
(
' '
);
size_t
d2
=
i
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
return
tgfmt
(
"{normalize_enum(
\"
$0
\"
), $opClass::$enumClass::$0}"
,
&
ctx
,
i
.
substr
(
0
,
d
));
};
};
os
<<
tgfmt
(
R"(
os
<<
tgfmt
(
R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
template<> std::unordered_map<std::string, $opClass::$enumClass>
...
@@ -192,12 +199,15 @@ os << tgfmt(R"(
...
@@ -192,12 +199,15 @@ os << tgfmt(R"(
auto
&&
members
=
attr
->
getEnumMembers
();
auto
&&
members
=
attr
->
getEnumMembers
();
for
(
size_t
idx
=
0
;
idx
<
members
.
size
();
++
idx
)
{
for
(
size_t
idx
=
0
;
idx
<
members
.
size
();
++
idx
)
{
size_t
d1
=
members
[
idx
].
find
(
' '
);
size_t
d2
=
members
[
idx
].
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
os
<<
tgfmt
(
R"({
os
<<
tgfmt
(
R"({
PyObject* inst = e_type->tp_alloc(e_type, 0);
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})"
,
&
ctx
,
members
[
idx
],
idx
);
})"
,
&
ctx
,
members
[
idx
]
.
substr
(
0
,
d
)
,
idx
);
}
}
}
}
...
...
tools/gen_header_for_bin_reduce.py
浏览文件 @
fb49a283
...
@@ -136,12 +136,13 @@ class HeaderGen:
...
@@ -136,12 +136,13 @@ class HeaderGen:
mode_list
=
[
i
.
strip
()
for
i
in
fin
]
mode_list
=
[
i
.
strip
()
for
i
in
fin
]
for
i
in
mode_list
:
for
i
in
mode_list
:
i
=
i
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
i
in
self
.
_elemwise_modes
:
if
i
in
self
.
_elemwise_modes
:
content
=
'_cb({})'
.
format
(
i
)
content
=
'_cb({})'
.
format
(
i
)
else
:
else
:
content
=
''
content
=
''
self
.
_write_def
(
self
.
_write_def
(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'
.
format
(
i
),
content
)
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'
.
format
(
i
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
),
content
)
self
.
_write_def
(
'MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)'
,
self
.
_write_def
(
'MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)'
,
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)'
)
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)'
)
...
...
tools/param_defs/mgb_opr_param_defs.py
浏览文件 @
fb49a283
...
@@ -20,14 +20,14 @@ pdef('PersistentOutputStorage').add_fields(
...
@@ -20,14 +20,14 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'ExecutionPolicy'
,
version
=
0
,
is_legacy
=
True
).
(
pdef
(
'ExecutionPolicy'
,
version
=
0
,
is_legacy
=
True
).
add_enum
(
'Strategy'
,
add_enum
(
'Strategy'
,
Doc
(
'HEURISTIC'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'HEURISTIC
= 0
'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'HEURISTIC_REPRODUCIBLE'
,
'use heuristic to choose the fastest algorithm, '
Doc
(
'HEURISTIC_REPRODUCIBLE
= 1
'
,
'use heuristic to choose the fastest algorithm, '
'and the chosen algorithm is reproducible'
),
'and the chosen algorithm is reproducible'
),
Doc
(
'PROFILE'
,
Doc
(
'PROFILE
= 2
'
,
'run possible algorithms on real device to find the best'
),
'run possible algorithms on real device to find the best'
),
Doc
(
'PROFILE_REPRODUCIBLE'
,
Doc
(
'PROFILE_REPRODUCIBLE
= 3
'
,
'the fastest of profile result that is also reproducible'
),
'the fastest of profile result that is also reproducible'
),
Doc
(
'PROFILE_HEURISTIC'
,
Doc
(
'PROFILE_HEURISTIC
= 4
'
,
'use profile result and heuristic to choose the fastest algorithm'
)).
'use profile result and heuristic to choose the fastest algorithm'
)).
add_fields
(
'uint64'
,
add_fields
(
'uint64'
,
Doc
(
'workspace_limit'
,
'workspace limit in bytes'
),
Doc
(
'workspace_limit'
,
'workspace limit in bytes'
),
...
@@ -35,13 +35,13 @@ pdef('PersistentOutputStorage').add_fields(
...
@@ -35,13 +35,13 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'ExecutionPolicy'
,
'specify how to select an algorithm for an operator'
,
version
=
1
).
(
pdef
(
'ExecutionPolicy'
,
'specify how to select an algorithm for an operator'
,
version
=
1
).
add_bit_combination_enum
(
'Strategy'
,
add_bit_combination_enum
(
'Strategy'
,
Doc
(
'HEURISTIC'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'HEURISTIC
= 1 << 0
'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'PROFILE'
,
Doc
(
'PROFILE
= 1 << 1
'
,
'run possible algorithms on real device to find the best'
),
'run possible algorithms on real device to find the best'
),
Doc
(
'REPRODUCIBLE'
,
Doc
(
'REPRODUCIBLE
= 1 << 2
'
,
'when profile or heuristic algo selection it require the algos'
'when profile or heuristic algo selection it require the algos'
'must be reproducible'
),
'must be reproducible'
),
Doc
(
'OPTIMIZED'
,
Doc
(
'OPTIMIZED
= 1 << 3
'
,
'profile require algos are optmized to achieve fast-profile'
),
'profile require algos are optmized to achieve fast-profile'
),
default
=
(
'HEURISTIC'
,),
default
=
(
'HEURISTIC'
,),
member_alias
=
[((
'HEURISTIC'
,
'REPRODUCIBLE'
),
'HEURISTIC_REPRODUCIBLE'
),
member_alias
=
[((
'HEURISTIC'
,
'REPRODUCIBLE'
),
'HEURISTIC_REPRODUCIBLE'
),
...
@@ -66,19 +66,19 @@ pdef('PersistentOutputStorage').add_fields(
...
@@ -66,19 +66,19 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'CollectiveComm'
,
'collective communication between multiple computing '
(
pdef
(
'CollectiveComm'
,
'collective communication between multiple computing '
'nodes on localhost'
)
'nodes on localhost'
)
.
add_enum
(
Doc
(
'Mode'
,
'mode of collective communication'
),
.
add_enum
(
Doc
(
'Mode'
,
'mode of collective communication'
),
Doc
(
'REDUCE_SUM'
,
'reduce by sum to output computing node'
),
Doc
(
'REDUCE_SUM
= 0
'
,
'reduce by sum to output computing node'
),
Doc
(
'BROADCAST'
,
'copy input value to each output computing node'
),
Doc
(
'BROADCAST
= 1
'
,
'copy input value to each output computing node'
),
Doc
(
'ALL_GATHER'
,
'each output comp node gets the concatenated '
Doc
(
'ALL_GATHER
= 2
'
,
'each output comp node gets the concatenated '
'value of all inputs'
),
'value of all inputs'
),
Doc
(
'REDUCE_SCATTER_SUM'
,
Doc
(
'REDUCE_SCATTER_SUM
= 3
'
,
'reduce inputs by sum and each output gets one part of it'
),
'reduce inputs by sum and each output gets one part of it'
),
Doc
(
'ALL_REDUCE_SUM'
,
'every output gets the sum of all inputs'
),
Doc
(
'ALL_REDUCE_SUM
= 4
'
,
'every output gets the sum of all inputs'
),
Doc
(
'ALL_REDUCE_MAX'
,
'every output gets the max of all inputs'
),
Doc
(
'ALL_REDUCE_MAX
= 5
'
,
'every output gets the max of all inputs'
),
Doc
(
'ALL_REDUCE_MIN'
,
'every output gets the min of all inputs'
),
Doc
(
'ALL_REDUCE_MIN
= 6
'
,
'every output gets the min of all inputs'
),
Doc
(
'ALL_REDUCE_PROD'
,
'every output gets the prod of all inputs'
),
Doc
(
'ALL_REDUCE_PROD
= 7
'
,
'every output gets the prod of all inputs'
),
Doc
(
'GATHER'
,
'concat inputs to one node'
),
Doc
(
'GATHER
= 8
'
,
'concat inputs to one node'
),
Doc
(
'SCATTER'
,
'scatter input to each output computing node'
),
Doc
(
'SCATTER
= 9
'
,
'scatter input to each output computing node'
),
Doc
(
'ALL_TO_ALL'
,
'scatter inputs and gather them on each computing node'
),
Doc
(
'ALL_TO_ALL
= 10
'
,
'scatter inputs and gather them on each computing node'
),
name_field
=
'mode'
))
name_field
=
'mode'
))
(
pdef
(
'FakeSerializedDType'
,
(
pdef
(
'FakeSerializedDType'
,
...
@@ -91,13 +91,13 @@ pdef('PersistentOutputStorage').add_fields(
...
@@ -91,13 +91,13 @@ pdef('PersistentOutputStorage').add_fields(
'evaluate a predicate and branch keys to setup ExecutionMask objects '
'evaluate a predicate and branch keys to setup ExecutionMask objects '
'with associated predicate proxy vars (PPVs)'
)
'with associated predicate proxy vars (PPVs)'
)
.
add_enum
(
Doc
(
'Mode'
,
'how to compare predicate var with branch keys'
),
.
add_enum
(
Doc
(
'Mode'
,
'how to compare predicate var with branch keys'
),
Doc
(
'CASE'
,
Doc
(
'CASE
= 0
'
,
'The outputs correspond to branch keys, '
'The outputs correspond to branch keys, '
'and the one which equals predicate would be activated. '
'and the one which equals predicate would be activated. '
'This behaves like a case-statement in many languages.'
),
'This behaves like a case-statement in many languages.'
),
Doc
(
'CASE_FALLBACK'
,
'like :attr:`CASE`, but add an extra output '
Doc
(
'CASE_FALLBACK
= 1
'
,
'like :attr:`CASE`, but add an extra output '
'that would be activated if no branch is matched'
),
'that would be activated if no branch is matched'
),
Doc
(
'PIECEWISE'
,
'One more outputs would be produced than the '
Doc
(
'PIECEWISE
= 2
'
,
'One more outputs would be produced than the '
'number of branch keys, representing the interval in which the '
'number of branch keys, representing the interval in which the '
'predicate var fits in. The intervals are defined as '
'predicate var fits in. The intervals are defined as '
r
':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, '
r
':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, '
...
@@ -112,20 +112,20 @@ pdef('PersistentOutputStorage').add_fields(
...
@@ -112,20 +112,20 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'CondExecPredLogical'
,
(
pdef
(
'CondExecPredLogical'
,
'compute a logical function over a set of PPVs'
)
'compute a logical function over a set of PPVs'
)
.
add_enum
(
'Mode'
,
Doc
(
'OR'
,
'logical or'
),
.
add_enum
(
'Mode'
,
Doc
(
'OR
= 0
'
,
'logical or'
),
Doc
(
'AND'
,
'logical and'
),
Doc
(
'AND
= 1
'
,
'logical and'
),
Doc
(
'XOR'
,
'exclusive-or'
),
Doc
(
'XOR
= 2
'
,
'exclusive-or'
),
Doc
(
'NOR'
,
'not or(inputs)'
),
Doc
(
'NOR
= 3
'
,
'not or(inputs)'
),
Doc
(
'NAND'
,
'not and(inputs)'
),
Doc
(
'NAND
= 4
'
,
'not and(inputs)'
),
Doc
(
'XNOR'
,
'not xor(inputs)'
))
Doc
(
'XNOR
= 5
'
,
'not xor(inputs)'
))
)
)
(
pdef
(
'CondExecMark'
,
(
pdef
(
'CondExecMark'
,
'add ExecutionMask of the input PPV to this opr and readers of the '
'add ExecutionMask of the input PPV to this opr and readers of the '
'outputs of this opr'
)
'outputs of this opr'
)
.
add_enum
(
Doc
(
'GradMode'
,
'mode for computing the gradient'
),
.
add_enum
(
Doc
(
'GradMode'
,
'mode for computing the gradient'
),
Doc
(
'SUM'
,
'normal gradient mode: sum all the activated components'
),
Doc
(
'SUM
= 0
'
,
'normal gradient mode: sum all the activated components'
),
Doc
(
'SUM_COND_OUT'
,
'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
Doc
(
'SUM_COND_OUT
= 1
'
,
'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
'oprs that depend on the gradient opr would not be executed '
'oprs that depend on the gradient opr would not be executed '
'if the forward var is not used.'
),
'if the forward var is not used.'
),
name_field
=
'grad_mode'
)
name_field
=
'grad_mode'
)
...
@@ -135,10 +135,10 @@ pdef('PersistentOutputStorage').add_fields(
...
@@ -135,10 +135,10 @@ pdef('PersistentOutputStorage').add_fields(
execution into account, this option can be used to bypass static
execution into account, this option can be used to bypass static
inference errors. This is currently only used by automatically
inference errors. This is currently only used by automatically
generated gradient oprs."""
),
generated gradient oprs."""
),
Doc
(
'SHAPE_VALUE'
,
'enable both shape and value inference'
),
Doc
(
'SHAPE_VALUE
= 0
'
,
'enable both shape and value inference'
),
Doc
(
'SHAPE_ONLY'
,
Doc
(
'SHAPE_ONLY
= 1
'
,
'only enable shape inference (disable value inference)'
),
'only enable shape inference (disable value inference)'
),
Doc
(
'NONE'
,
'disable both shape and value inference'
),
Doc
(
'NONE
= 2
'
,
'disable both shape and value inference'
),
name_field
=
'static_infer'
)
name_field
=
'static_infer'
)
)
)
...
@@ -147,17 +147,17 @@ pdef('PersistentOutputStorage').add_fields(
...
@@ -147,17 +147,17 @@ pdef('PersistentOutputStorage').add_fields(
'number of output vars (i.e. vars per branch)'
),
'number of output vars (i.e. vars per branch)'
),
1
)
1
)
.
add_enum
(
'Mode'
,
.
add_enum
(
'Mode'
,
Doc
(
'EXACT_ONE'
,
'copy the var whose mask is activated to the output'
Doc
(
'EXACT_ONE
= 0
'
,
'copy the var whose mask is activated to the output'
', requiring that exactly one branch is active'
),
', requiring that exactly one branch is active'
),
Doc
(
'EXACT_ONE_SAME_SHAPE'
,
'like :attr:`EXACT_ONE` with the '
Doc
(
'EXACT_ONE_SAME_SHAPE
= 1
'
,
'like :attr:`EXACT_ONE` with the '
'requirement that all branches have the same shape, so shape '
'requirement that all branches have the same shape, so shape '
'inference can be easier'
),
'inference can be easier'
),
Doc
(
'SUM'
,
'sum all the active branches into output var; require '
Doc
(
'SUM
= 2
'
,
'sum all the active branches into output var; require '
'all branches to have the same shape. Extra shape vars are '
'all branches to have the same shape. Extra shape vars are '
'needed in this mod, so the outputs can be initialized to zero '
'needed in this mod, so the outputs can be initialized to zero '
'when no input is active (and their shapes are probably '
'when no input is active (and their shapes are probably '
'unknown).'
),
'unknown).'
),
Doc
(
'SUM_COND_OUT'
,
'like :attr:`SUM` but also add an ExecutionMask'
Doc
(
'SUM_COND_OUT
= 3
'
,
'like :attr:`SUM` but also add an ExecutionMask'
' to the readers of output vars, so they would be skipped if '
' to the readers of output vars, so they would be skipped if '
' no branch is taken'
)
' no branch is taken'
)
)
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录