/**
 * \file imperative/tablegen/targets/cpp_class.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./cpp_class.h"
#include "../emitter.h"

namespace mlir::tblgen {
namespace {

/**
 * Pick the C++ type spelling used for an attribute in the generated class.
 *
 * \param attr_ attribute taken from an operator definition
 * \return the enum name when the attribute is an enum attribute, otherwise
 *         the underlying type reported by the attribute wrapper
 */
llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
    // Note: we have already registered the corresponding attr wrappers
    // for following basic ctypes so we needn't handle them here
    /*
    auto&& attr_type_name = attr.getAttrDefName();
    if (attr_type_name == "UI32Attr") {
        return "uint32_t";
    }
    if (attr_type_name == "UI64Attr") {
        return "uint64_t";
    }
    if (attr_type_name == "I32Attr") {
        return "int32_t";
    }
    if (attr_type_name == "F32Attr") {
        return "float";
    }
    if (attr_type_name == "F64Attr") {
        return "double";
    }
    if (attr_type_name == "StrAttr") {
        return "std::string";
    }
    if (attr_type_name == "BoolAttr") {
        return "bool";
    }*/

    // NOTE(review): the cast targets below (MgbAttrWrapper / MgbEnumAttr) were
    // reconstructed from the members used on the results -- confirm against
    // the helper declarations pulled in by ../emitter.h.
    auto&& attr = llvm::cast<MgbAttrWrapper>(attr_);
    if (auto e = llvm::dyn_cast<MgbEnumAttr>(&attr)) {
        return e->getEnumName();
    }
    return attr.getUnderlyingType();
}

/**
 * Writes the generated C++ for a single operator to the output stream:
 * the class declaration, ToStringTrait specializations for its enums and
 * the out-of-class implementation/registration code.
 */
class OpDefEmitter final : public EmitterBase {
public:
    OpDefEmitter(MgbOp& op_, raw_ostream& os_) : EmitterBase(os_), op(op_) {}

    void emit_header();
    void emit_tpl_spl();
    void emit_body();

private:
    MgbOp& op;
};

/**
 * Emit the class declaration of the operator: enum aliases, one data member
 * per attribute, constructors, packed-param accessors and any extra
 * declaration text attached to the operator.
 */
void OpDefEmitter::emit_header() {
    os << formatv(
        "class {0} : public OpDefImplBase<{0}> {{\n"
        " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
        "public:\n",
        op.getCppClassName()
    );

    // handle enum alias
    // NOTE(review): dyn_cast target reconstructed as MgbEnumAttr -- confirm.
    for (auto&& i : op.getMgbAttributes()) {
        if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
            os << formatv(
                " using {0} = {1};\n",
                attr->getEnumName(), attr->getUnderlyingType()
            );
        }
    }

    // one data member per attribute, with " = <default>" appended when the
    // attribute carries a non-empty default value
    for (auto&& i : op.getMgbAttributes()) {
        auto defaultValue = i.attr.getDefaultValue().str();
        if (!defaultValue.empty()) {
            defaultValue = formatv(" = {0}", defaultValue);
        }
        os << formatv(
            " {0} {1}{2};\n",
            attr_to_ctype(i.attr), i.name, defaultValue
        );
    }

    // emits `<ClassName>(<params>)<member-init-list><body>`
    auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) {
        os << formatv(
            " {0}({1}){2}{3}\n",
            op.getCppClassName(), paramList, memInitList, body
        );
    };

    // default constructor
    gen_ctor("", "", " = default;");

    // constructor taking every attribute by value plus an optional scope
    if (!op.getMgbAttributes().empty()) {
        std::vector<std::string> paramList, initList;
        for (auto&& i : op.getMgbAttributes()) {
            paramList.push_back(formatv(
                "{0} {1}_", attr_to_ctype(i.attr), i.name
            ));
            initList.push_back(formatv(
                "{0}({0}_)", i.name
            ));
        }
        paramList.push_back("std::string scope_ = {}");
        gen_ctor(llvm::join(paramList, ", "),
                 ": " + llvm::join(initList, ", "),
                 " { set_scope(scope_); }");
    }

    auto packedParams = op.getPackedParams();
    if (!packedParams.empty()) {
        // constructor taking the packed params followed by the extra arguments
        std::vector<std::string> paramList, initList;
        for (auto&& p : packedParams) {
            auto&& paramFields = p.getFields();
            auto&& paramType = p.getFullName();
            // parameters are numbered by their position in the parameter list
            auto&& paramName = formatv("packed_param_{0}", paramList.size());
            // a packed param without fields initializes nothing, so its
            // parameter is left unnamed
            paramList.push_back(
                paramFields.empty() ? paramType.str()
                    : formatv("{0} {1}", paramType, paramName)
            );
            for (auto&& i : paramFields) {
                initList.push_back(formatv(
                    "{0}({1}.{0})", i.name, paramName
                ));
            }
        }
        for (auto&& i : op.getExtraArguments()) {
            paramList.push_back(formatv(
                "{0} {1}_", attr_to_ctype(i.attr), i.name
            ));
            initList.push_back(formatv(
                "{0}({0}_)", i.name
            ));
        }
        gen_ctor(llvm::join(paramList, ", "),
                 initList.empty() ? "" : ": " + llvm::join(initList, ", "),
                 " {}");
    }

    // accessors that rebuild a packed param from the individual members;
    // iterating an empty list is a no-op, so no emptiness check is needed
    for (auto&& p : packedParams) {
        auto accessor = p.getAccessor();
        if (accessor.empty()) {
            continue;
        }
        os << formatv(
            " {0} {1}() const {{\n",
            p.getFullName(), accessor
        );
        // owning copies of the names, so nothing refers into the field list
        // once the loop has finished
        std::vector<std::string> fields;
        for (auto&& i : p.getFields()) {
            fields.push_back(llvm::StringRef(i.name).str());
        }
        os << formatv(
            " return {{{0}};\n",
            llvm::join(fields, ", ")
        );
        os << " }\n";
    }

    if (auto decl = op.getExtraOpdefDecl()) {
        os << decl.getValue();
    }

    // no placeholders here, so write the text directly instead of via formatv
    os << "};\n\n";
}

/**
 * Emit a ToStringTrait specialization for every enum attribute of the
 * operator that reports supportToString().
 */
void OpDefEmitter::emit_tpl_spl() {
    for (auto&& i : op.getMgbAttributes()) {
        // NOTE(review): dyn_cast target reconstructed as MgbEnumAttr -- confirm.
        auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr);
        if (!attr || !attr->supportToString()) {
            continue;
        }

        std::vector<std::string> case_body;
        std::string ename = formatv("{0}::{1}",
            op.getCppClassName(), attr->getEnumName());
        for (auto&& v : attr->getEnumMembers()) {
            // only the part of the member text before the first ' ' or '='
            // is used as the enumerator name (the whole text if neither occurs)
            size_t d = v.find_first_of(" =");
            case_body.push_back(
                formatv("case {0}::{1}: return \"{1}\";", ename, v.substr(0, d)));
        }
        os << formatv(R"( template <> struct ToStringTrait<{0}> { std::string operator()({0} e) const { switch (e) { {1} default: return "{0}::Unknown"; } } }; )",
            ename, llvm::join(case_body, "\n"));
    }
}

/**
 * Emit the out-of-class part for the operator: the MGB_DYN_TYPE_OBJ_FINAL_IMPL
 * line and, when the operator is hashable, the hash / is_same_st / props /
 * make_name helper functions together with their OP_TRAIT_REG registration.
 */
void OpDefEmitter::emit_body() {
    auto&& className = op.getCppClassName();
    os << formatv(
        "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className
    );

    // name of the generated free function implementing trait method `meth`
    auto formatMethImpl = [&](auto&& meth) {
        return formatv(
            "{0}_{1}_impl", className, meth
        );
    };

    // opens a generated function taking one `const OpDef& def_` and binds it
    // to `op_`; the emitted static_cast<void> keeps `op_` "used" even when
    // the function template does not mention it
    auto emitUnaryOpen = [&](const char* returnType, const char* meth) {
        os << formatv(
            "{0} {1}(const OpDef& def_) {{\n",
            returnType, formatMethImpl(meth)
        );
        os << formatv(
            " auto&& op_ = def_.cast_final_safe<{0}>();\n"
            " static_cast<void>(op_);\n",
            className
        );
    };

    std::vector<std::string> methods;
    // NOTE(review): dyn_cast target reconstructed as MgbHashableOp -- confirm.
    if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
        os << "namespace {\n";

        // `$_self` in the function templates refers to the casted operator
        mlir::tblgen::FmtContext ctx;
        ctx.withSelf("op_");

        // generate hash()
        emitUnaryOpen("size_t", "hash");
        os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
        os << "}\n";

        // generate is_same_st()
        os << formatv(
            "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
            formatMethImpl("is_same_st")
        );
        os << formatv(
            " auto &&a_ = lhs_.cast_final_safe<{0}>(),\n"
            " &&b_ = rhs_.cast_final_safe<{0}>();\n"
            " static_cast<void>(a_);\n"
            " static_cast<void>(b_);\n",
            className
        );
        os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
        os << "}\n";

        // generate props()
        // NOTE(review): the emitted return type was reconstructed as a vector
        // of (const char*, std::string) pairs -- confirm against the trait.
        emitUnaryOpen("std::vector<std::pair<const char*, std::string>>", "props");
        os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
        os << "}\n";

        // generate make_name()
        emitUnaryOpen("std::string", "make_name");
        os << mlir::tblgen::tgfmt(op.getNameFunctionTemplate(), &ctx);
        os << "}\n";

        os << "} // anonymous namespace\n";

        methods.push_back("hash");
        methods.push_back("is_same_st");
        methods.push_back("props");
        methods.push_back("make_name");
    }

    // register every generated helper on the operator's trait
    if (!methods.empty()) {
        os << formatv(
            "OP_TRAIT_REG({0}, {0})", className
        );
        for (auto&& i : methods) {
            os << formatv(
                "\n .{0}({1})", i, formatMethImpl(i)
            );
        }
        os << ";\n\n";
    }
}

} // namespace

/**
 * Emit the class declaration and ToStringTrait specializations of every
 * operator found in \p keeper to \p os.
 *
 * \return always false (presumably the TableGen "no error" convention --
 *         confirm against the caller)
 */
bool gen_op_def_c_header(raw_ostream &os, llvm::RecordKeeper &keeper) {
    foreach_operator(keeper, [&](MgbOp& op) {
        OpDefEmitter emitter(op, os);
        emitter.emit_header();
        emitter.emit_tpl_spl();
    });
    return false;
}

/**
 * Emit the implementation/registration code of every operator found in
 * \p keeper to \p os.
 *
 * \return always false (presumably the TableGen "no error" convention --
 *         confirm against the caller)
 */
bool gen_op_def_c_body(raw_ostream &os, llvm::RecordKeeper &keeper) {
    foreach_operator(keeper, [&](MgbOp& op) {
        OpDefEmitter emitter(op, os);
        emitter.emit_body();
    });
    return false;
}

} // namespace mlir::tblgen