Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
ad87f78a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
ad87f78a
编写于
2月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
chore(imperative): refine tblgen for generating op name
GitOrigin-RevId: f47ceae726aeb8b385901dd9d2964982da3fe447
上级
4b98c721
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
66 addition
and
38 deletion
+66
-38
dnn/scripts/gen_tablegen.py
dnn/scripts/gen_tablegen.py
+9
-1
imperative/src/include/megbrain/imperative/ops/autogen.h
imperative/src/include/megbrain/imperative/ops/autogen.h
+1
-0
imperative/tablegen/autogen.cpp
imperative/tablegen/autogen.cpp
+36
-1
imperative/tablegen/helper.h
imperative/tablegen/helper.h
+9
-30
src/core/include/megbrain/ir/base.td
src/core/include/megbrain/ir/base.td
+5
-4
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+6
-2
未找到文件。
dnn/scripts/gen_tablegen.py
浏览文件 @
ad87f78a
...
...
@@ -11,6 +11,11 @@ import io
from
gen_param_defs
import
member_defs
,
ParamDef
,
IndentWriterBase
# FIXME: move supportToString flag definition into the param def source file
ENUM_TO_STRING_SPECIAL_RULES
=
[
(
"Elemwise"
,
"Mode"
),
(
"ElemwiseMultiType"
,
"Mode"
)
]
class
ConverterWriter
(
IndentWriterBase
):
_skip_current_param
=
False
...
...
@@ -86,7 +91,10 @@ class ConverterWriter(IndentWriterBase):
def
format
(
v
):
return
'
\"
{}
\"
'
.
format
(
str
(
v
))
enum_def
+=
','
.
join
(
format
(
i
)
for
i
in
e
.
members
)
enum_def
+=
"]>"
enum_def
+=
"]"
if
ENUM_TO_STRING_SPECIAL_RULES
.
count
((
p
.
name
,
e
.
name
)):
enum_def
+=
", 1"
# whether generate ToStringTrait
enum_def
+=
">"
self
.
_write
(
"def {} : {};"
.
format
(
td_class
,
enum_def
))
if
self
.
_skip_current_param
:
...
...
imperative/src/include/megbrain/imperative/ops/autogen.h
浏览文件 @
ad87f78a
...
...
@@ -12,6 +12,7 @@
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/param_defs.h"
...
...
imperative/tablegen/autogen.cpp
浏览文件 @
ad87f78a
...
...
@@ -179,6 +179,34 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
);
}
static
void
gen_to_string_trait_for_enum
(
raw_ostream
&
os
,
MgbOp
&
op
)
{
for
(
auto
&&
i
:
op
.
getMgbAttributes
())
{
if
(
auto
attr
=
llvm
::
dyn_cast
<
MgbEnumAttr
>
(
&
i
.
attr
))
{
if
(
attr
->
supportToString
())
{
std
::
vector
<
std
::
string
>
case_body
;
std
::
string
ename
=
formatv
(
"{0}::{1}"
,
op
.
getCppClassName
(),
attr
->
getEnumName
());
llvm
::
for_each
(
attr
->
getEnumMembers
(),
[
&
](
auto
&&
v
){
case_body
.
push_back
(
formatv
(
"case {0}::{1}: return
\"
{1}
\"
;"
,
ename
,
v
));
});
os
<<
formatv
(
R"(
template <>
struct ToStringTrait<{0}> {
std::string operator()({0} e) const {
switch (e) {
{1}
default:
return "{0}::Unknown";
}
}
};
)"
,
ename
,
llvm
::
join
(
case_body
,
"
\n
"
));
}
}
}
}
static
void
gen_op_def_c_body_single
(
raw_ostream
&
os
,
MgbOp
&
op
)
{
auto
&&
className
=
op
.
getCppClassName
();
os
<<
formatv
(
...
...
@@ -241,7 +269,13 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
os
<<
formatv
(
"std::string {0}(const OpDef& def_) {{
\n
"
,
formatMethImpl
(
"make_name"
)
);
os
<<
mlir
::
tblgen
::
tgfmt
(
hashable
->
getNameFunctionTemplate
(),
&
ctx
);
os
<<
formatv
(
" auto&& op_ = def_.cast_final_safe<{0}>();
\n
"
" static_cast<void>(op_);
\n
"
,
className
);
ctx
.
withSelf
(
"op_"
);
os
<<
mlir
::
tblgen
::
tgfmt
(
op
.
getNameFunctionTemplate
(),
&
ctx
);
os
<<
"}
\n
"
;
os
<<
"} // anonymous namespace
\n
"
;
...
...
@@ -577,6 +611,7 @@ static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
static
bool
gen_op_def_c_header
(
raw_ostream
&
os
,
RecordKeeper
&
keeper
)
{
for_each_operator
(
os
,
keeper
,
gen_op_def_c_header_single
);
for_each_operator
(
os
,
keeper
,
gen_to_string_trait_for_enum
);
return
false
;
}
...
...
imperative/tablegen/helper.h
浏览文件 @
ad87f78a
...
...
@@ -74,6 +74,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
std
::
vector
<
StringRef
>
getEnumMembers
()
const
{
return
getBaseRecord
()
->
getValueAsListOfStrings
(
"enumMembers"
);
}
bool
supportToString
()
const
{
return
getBaseRecord
()
->
getValueAsBit
(
"supportToString"
);
}
};
struct
MgbHashableAttrMixin
:
public
MgbAttrWrapperBase
{
...
...
@@ -170,6 +173,12 @@ public:
}
return
ret
;
}
std
::
string
getNameFunctionTemplate
()
const
{
if
(
auto
f
=
getDef
().
getValueAsOptionalString
(
"nameFunction"
))
{
return
f
.
getValue
().
str
();
}
return
formatv
(
" return
\"
{0}
\"
;
\n
"
,
getCppClassName
());
}
};
struct
MgbHashableOpMixin
:
public
MgbOpBase
{
...
...
@@ -241,30 +250,6 @@ private:
body
+=
" return props_;
\n
"
;
return
body
;
}
std
::
string
getModeName
()
const
{
std
::
string
body
=
formatv
(
" auto&& op_ = def_.cast_final_safe<{0}>();
\n
"
" static_cast<void>(op_);
\n
"
,
getCppClassName
()
);
for
(
auto
&&
it
:
getMgbAttributes
())
{
if
(
it
.
name
==
"mode"
)
{
auto
*
enumAttr
=
llvm
::
dyn_cast
<
MgbEnumAttrMixin
>
(
&
it
.
attr
);
body
+=
" switch (op_.mode){
\n
"
;
for
(
auto
&&
enumMember
:
enumAttr
->
getEnumMembers
())
{
body
+=
formatv
(
" case {0}::{1}::{2}:
\n
"
,
getCppClassName
(),
enumAttr
->
getEnumName
(),
enumMember
);
body
+=
formatv
(
" return
\"
{0}
\"
;
\n
"
,
enumMember
);
}
body
+=
formatv
(
" default: return
\"
{0}::Unknown
\"
;
\n
"
,
getCppClassName
());
body
+=
" }
\n
"
;
}
}
return
body
;
}
public:
static
bool
classof
(
const
Operator
*
op
)
{
return
op
->
getDef
().
isSubClassOf
(
"MgbHashableOpMixin"
);
...
...
@@ -288,12 +273,6 @@ public:
}
return
getDefaultPropsFunction
();
}
std
::
string
getNameFunctionTemplate
()
const
{
if
(
getDef
().
getValueAsBit
(
"usingModeName"
))
{
return
getModeName
();
}
return
formatv
(
" return
\"
{0}
\"
;
\n
"
,
getCppClassName
());
}
};
}
// namespace tblgen
...
...
src/core/include/megbrain/ir/base.td
浏览文件 @
ad87f78a
...
...
@@ -33,10 +33,11 @@ class MgbHashableAttrMixin {
string reprFunction = "std::to_string($0)";
}
class MgbEnumAttrMixin<string namespace, string name, list<string> members> {
class MgbEnumAttrMixin<string namespace, string name, list<string> members
, bit toString
> {
string parentNamespace = namespace;
string enumName = name;
list<string> enumMembers = members;
bit supportToString = toString;
}
class MgbAttrWrapper;
...
...
@@ -165,8 +166,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>:
}
// -- enum types
class MgbEnumAttr<string namespace, string enumName, list<string> members>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> {
class MgbEnumAttr<string namespace, string enumName, list<string> members
, bit toString=0
>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members
, toString
> {
let storageType = "::mlir::IntegerAttr";
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
...
...
@@ -242,7 +243,6 @@ class MgbPackedParamBase<string className, string accessor>:
class MgbHashableOpMixin {
string hashFunction = ?;
string cmpFunction = ?;
bit usingModeName = 0;
}
class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
...
...
@@ -251,6 +251,7 @@ class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=
dag extraArguments = (ins);
// TODO: remove it
code extraOpdefDecl = ?;
code nameFunction = ?;
let arguments = !con(
!foldl(inputs, params, args, param, !con(args, param.fields)),
...
...
src/core/include/megbrain/ir/ops.td
浏览文件 @
ad87f78a
...
...
@@ -21,7 +21,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
let inputs = (ins Variadic<AnyType>:$input);
let results = (outs AnyType);
let usingModeName = 1;
let nameFunction = [{
return to_string($_self.mode);
}];
}
def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
...
...
@@ -248,7 +250,9 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara
let extraArguments = (ins
MgbDTypeAttr:$dtype
);
let usingModeName = 1;
let nameFunction = [{
return to_string($_self.mode);
}];
}
def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录