/** * \file imperative/tablegen/targets/pybind11.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "./pybind11.h" #include "../emitter.h" namespace mlir::tblgen { namespace { class OpDefEmitter final: public EmitterBase { public: OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_): EmitterBase(os_, env_), op(op_) {} void emit(); private: MgbOp& op; }; void OpDefEmitter::emit() { auto className = op.getCppClassName(); os << formatv( "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", className ); for (auto&& i : op.getMgbAttributes()) { if (auto attr = llvm::dyn_cast(&i.attr)) { unsigned int enumID; if (auto alias = llvm::dyn_cast(attr)) { auto&& aliasBase = alias->getAliasBase(); enumID = llvm::cast(aliasBase) .getBaseRecord()->getID(); } else { enumID = attr->getBaseRecord()->getID(); } auto&& enumAlias = env().enumAlias; auto&& iter = enumAlias.find(enumID); if (iter == enumAlias.end()) { os << formatv( "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", className, attr->getEnumName() ); std::vector body; for (auto&& i: attr->getEnumMembers()) { size_t d1 = i.find(' '); size_t d2 = i.find('='); size_t d = d1 <= d2 ? d1 : d2; os << formatv("\n .value(\"{2}\", {0}::{1}::{2})", className, attr->getEnumName(), i.substr(0, d)); body.push_back(formatv( "if (str == \"{2}\") return {0}::{1}::{2};", className, attr->getEnumName(), i.substr(0, d))); } if (attr->getEnumCombinedFlag()) { //! define operator | os << formatv( "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ " "\n return static_cast<{0}::{1}>(uint32_t(s0) | uint32_t(s1));" "\n })", className, attr->getEnumName()); //! define operator & os << formatv( "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{" "\n return static_cast<{0}::{1}>(uint32_t(s0) & uint32_t(s1));" "\n })", className, attr->getEnumName()); } os << formatv( "\n .def(py::init([](const std::string& in) {" "\n auto&& str = normalize_enum(in);" "\n {0}" "\n throw py::cast_error(\"invalid enum value \" + in);" "\n }));\n", llvm::join(body, "\n ") ); os << formatv( "py::implicitly_convertible();\n\n", className, attr->getEnumName() ); enumAlias.emplace(enumID, std::make_pair(className, attr->getEnumName())); } else { os << formatv( "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", className, attr->getEnumName(), iter->second.first, iter->second.second ); } } } // generate op class binding os << formatv("{0}Inst", className); bool hasDefaultCtor = op.getMgbAttributes().empty(); if (!hasDefaultCtor) { os << "\n .def(py::init<"; std::vector targs; for (auto &&i : op.getMgbAttributes()) { targs.push_back(i.attr.getReturnType()); } os << llvm::join(targs, ", "); os << ", std::string>()"; for (auto &&i : op.getMgbAttributes()) { os << formatv(", py::arg(\"{0}\")", i.name); auto defaultValue = i.attr.getDefaultValue(); if (!defaultValue.empty()) { os << formatv(" = {0}", defaultValue); } else { hasDefaultCtor = true; } } os << ", py::arg(\"scope\") = {})"; } if (hasDefaultCtor) { os << "\n .def(py::init<>())"; } for (auto &&i : op.getMgbAttributes()) { os << formatv( "\n .def_readwrite(\"{0}\", &{1}::{0})", i.name, className ); } os << ";\n\n"; } } // namespace bool gen_op_def_pybind11(raw_ostream &os, llvm::RecordKeeper &keeper) { Environment env; using namespace std::placeholders; foreach_operator(keeper, [&](MgbOp& op) { OpDefEmitter(op, os, env).emit(); }); return false; } } // namespace mlir::tblgen