提交 ad87f78a 编写于 作者: M Megvii Engine Team

chore(imperative): refine tblgen for generating op name

GitOrigin-RevId: f47ceae726aeb8b385901dd9d2964982da3fe447
上级 4b98c721
......@@ -11,6 +11,11 @@ import io
from gen_param_defs import member_defs, ParamDef, IndentWriterBase
# FIXME: move supportToString flag definition into the param def source file
ENUM_TO_STRING_SPECIAL_RULES = [
("Elemwise", "Mode"),
("ElemwiseMultiType", "Mode")
]
class ConverterWriter(IndentWriterBase):
_skip_current_param = False
......@@ -86,7 +91,10 @@ class ConverterWriter(IndentWriterBase):
def format(v):
return '\"{}\"'.format(str(v))
enum_def += ','.join(format(i) for i in e.members)
enum_def += "]>"
enum_def += "]"
if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)):
enum_def += ", 1" # whether generate ToStringTrait
enum_def += ">"
self._write("def {} : {};".format(td_class, enum_def))
if self._skip_current_param:
......
......@@ -12,6 +12,7 @@
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/param_defs.h"
......
......@@ -179,6 +179,34 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
);
}
static void gen_to_string_trait_for_enum(raw_ostream &os, MgbOp& op) {
for (auto &&i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
if (attr->supportToString()) {
std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){
case_body.push_back(formatv(
"case {0}::{1}: return \"{1}\";", ename, v));
});
os << formatv(R"(
template <>
struct ToStringTrait<{0}> {
std::string operator()({0} e) const {
switch (e) {
{1}
default:
return "{0}::Unknown";
}
}
};
)", ename, llvm::join(case_body, "\n"));
}
}
}
}
static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
auto&& className = op.getCppClassName();
os << formatv(
......@@ -241,7 +269,13 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
os << formatv(
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
);
os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx);
os << formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
className
);
ctx.withSelf("op_");
os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
os << "}\n";
os << "} // anonymous namespace\n";
......@@ -577,6 +611,7 @@ static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
for_each_operator(os, keeper, gen_op_def_c_header_single);
for_each_operator(os, keeper, gen_to_string_trait_for_enum);
return false;
}
......
......@@ -74,6 +74,9 @@ struct MgbEnumAttrMixin : public MgbAttrWrapperBase {
std::vector<StringRef> getEnumMembers() const {
return getBaseRecord()->getValueAsListOfStrings("enumMembers");
}
bool supportToString() const {
return getBaseRecord()->getValueAsBit("supportToString");
}
};
struct MgbHashableAttrMixin : public MgbAttrWrapperBase {
......@@ -170,6 +173,12 @@ public:
}
return ret;
}
std::string getNameFunctionTemplate() const {
if (auto f = getDef().getValueAsOptionalString("nameFunction")) {
return f.getValue().str();
}
return formatv(" return \"{0}\";\n", getCppClassName());
}
};
struct MgbHashableOpMixin : public MgbOpBase {
......@@ -241,30 +250,6 @@ private:
body += " return props_;\n";
return body;
}
std::string getModeName() const {
std::string body = formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
getCppClassName()
);
for (auto&& it : getMgbAttributes()) {
if (it.name == "mode") {
auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr);
body += " switch (op_.mode){\n";
for (auto&& enumMember: enumAttr->getEnumMembers()) {
body += formatv(
" case {0}::{1}::{2}:\n",
getCppClassName(), enumAttr->getEnumName(), enumMember
);
body += formatv(" return \"{0}\";\n", enumMember);
}
body += formatv(
" default: return \"{0}::Unknown\";\n", getCppClassName());
body += " }\n";
}
}
return body;
}
public:
static bool classof(const Operator* op) {
return op->getDef().isSubClassOf("MgbHashableOpMixin");
......@@ -288,12 +273,6 @@ public:
}
return getDefaultPropsFunction();
}
std::string getNameFunctionTemplate() const {
if (getDef().getValueAsBit("usingModeName")) {
return getModeName();
}
return formatv(" return \"{0}\";\n", getCppClassName());
}
};
} // namespace tblgen
......
......@@ -33,10 +33,11 @@ class MgbHashableAttrMixin {
string reprFunction = "std::to_string($0)";
}
class MgbEnumAttrMixin<string namespace, string name, list<string> members> {
class MgbEnumAttrMixin<string namespace, string name, list<string> members, bit toString> {
string parentNamespace = namespace;
string enumName = name;
list<string> enumMembers = members;
bit supportToString = toString;
}
class MgbAttrWrapper;
......@@ -165,8 +166,8 @@ class MgbTupleAttr<list<MgbAttrWrapper> args>:
}
// -- enum types
class MgbEnumAttr<string namespace, string enumName, list<string> members>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members> {
class MgbEnumAttr<string namespace, string enumName, list<string> members, bit toString=0>:
HashableAttr<namespace # "::" # enumName>, MgbEnumAttrMixin<namespace, enumName, members, toString> {
let storageType = "::mlir::IntegerAttr";
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
......@@ -242,7 +243,6 @@ class MgbPackedParamBase<string className, string accessor>:
class MgbHashableOpMixin {
string hashFunction = ?;
string cmpFunction = ?;
bit usingModeName = 0;
}
class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
......@@ -251,6 +251,7 @@ class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=
dag extraArguments = (ins);
// TODO: remove it
code extraOpdefDecl = ?;
code nameFunction = ?;
let arguments = !con(
!foldl(inputs, params, args, param, !con(args, param.fields)),
......
......@@ -21,7 +21,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
let inputs = (ins Variadic<AnyType>:$input);
let results = (outs AnyType);
let usingModeName = 1;
let nameFunction = [{
return to_string($_self.mode);
}];
}
def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
......@@ -248,7 +250,9 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara
let extraArguments = (ins
MgbDTypeAttr:$dtype
);
let usingModeName = 1;
let nameFunction = [{
return to_string($_self.mode);
}];
}
def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册