提交 30ef28b0 编写于 作者: M Megvii Engine Team

build(imperative): generate stub files for core ops

GitOrigin-RevId: 7971df5f20cb317046647fecd11b1978b4b16242
上级 8fe5e74f
......@@ -1399,6 +1399,23 @@ if(TARGET _imperative_rt)
DEPENDS ${develop_depends}
VERBATIM)
add_dependencies(develop _imperative_rt)
# generate stub file for _imperative_rt
execute_process(
COMMAND ${PYTHON3_EXECUTABLE_WITHOUT_VERSION} -c
"import mypy.version; assert mypy.version.__version__ >= '0.982'"
RESULT_VARIABLE NOT_HAVING_MYPY_STUBGEN)
if(NOT ${NOT_HAVING_MYPY_STUBGEN})
add_custom_command(
TARGET develop
POST_BUILD
COMMAND
${PYTHON3_EXECUTABLE_WITHOUT_VERSION} -c "from mypy.stubgen import main; main()"
-p ${PACKAGE_NAME}.core.${MODULE_NAME} -o
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python
VERBATIM)
endif()
endif()
# Configure and install pkg-config. Note that unlike the Config.cmake modules, this is
......
......@@ -107,6 +107,9 @@ add_custom_command(
${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/include # clean develop
COMMAND ${CMAKE_COMMAND} -E remove -f
${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/lib # clean develop
COMMAND ${CMAKE_COMMAND} -E remove -f
${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/${MODULE_NAME} # clean
# develop
COMMAND
${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine
${CMAKE_CURRENT_BINARY_DIR}/python/megengine
......
# -*- coding: utf-8 -*-
from ..._imperative_rt import OpDef, ops
from ..._imperative_rt import OpDef
__all__ = ["OpDef"]
original_keys = set()
for k, v in ops.__dict__.items():
if isinstance(v, type) and issubclass(v, OpDef):
globals()[k] = v
__all__.append(k)
def backup_keys():
global original_keys
original_keys = set()
for k in globals().keys():
original_keys.add(k)
backup_keys()
from ..._imperative_rt.ops import * # isort:skip
def setup():
to_be_removed = set()
for k, v in globals().items():
is_original_key = k in original_keys
is_op = isinstance(v, type) and issubclass(v, OpDef)
if not is_op and not is_original_key:
to_be_removed.add(k)
for k in to_be_removed:
del globals()[k]
setup()
......@@ -7,7 +7,7 @@ from ..core.ops import builtin
from ..logger import get_logger
from ..utils.deprecation import deprecated
Strategy = builtin.ops.Convolution.Strategy
Strategy = builtin.Convolution.Strategy
if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") != None:
get_logger().warning(
......
black==19.10b0
isort==4.3.21
pylint==2.4.3
mypy==0.750
mypy==0.982
......@@ -3,5 +3,5 @@ cf864561de125ab559c0035158656682 ../../src/core/include/megbrain/ir/ops.td
9248d42a9b3e770693306992156f6015 generated/opdef.h.inl
5c7e7ac49d1338d70ac84ba309e6732b generated/opdef.cpp.inl
30b669eec36876a65717e0c68dd76c83 generated/opdef.py.inl
d10455217f5f01e3d2668e5689068920 generated/opdef.cpy.inl
4312de292a3d71f34a084bf43ea2ecec generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
#include "python_c_extension.h"
#include <cctype>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "../emitter.h"
#include "python_c_extension.h"
namespace mlir::tblgen {
namespace {
class TypeInfo;
std::pair<TypeInfo, int> parse_type(const std::string&, const int);
std::pair<std::vector<std::string>, int> parse_namespace(const std::string&, const int);
struct Unit {};
Unit unit;
struct ParseError {};
class TypeInfo {
public:
TypeInfo(std::string name) : name(name) {}
std::string to_python_type_string() {
std::stringstream ss;
ss << translate_type_name(name);
if (params.size() > 0) {
ss << "[" << params[0].to_python_type_string();
for (auto i = 1; i < params.size(); i++) {
ss << ", " << params[i].to_python_type_string();
}
ss << "]";
}
return ss.str();
}
std::string translate_type_name(const std::string& cppTypeName) {
auto res = translation.find(cppTypeName);
if (res != translation.end())
return res->second;
try {
auto segments = parse_namespace(cppTypeName, 0).first;
// special rules
if (segments.size() > 3 && segments[0] == "megdnn" &&
segments[1] == "param") {
segments.erase(segments.begin(), segments.begin() + 3);
} else if (
segments.size() == 2 && segments[0] == "megdnn" &&
segments[1] == "DType") {
segments.erase(segments.begin(), segments.begin() + 1);
segments[0] = "str";
} else if (
segments.size() == 2 && segments[0] == "mgb" &&
segments[1] == "CompNode") {
segments.erase(segments.begin(), segments.begin() + 1);
segments[0] = "str";
}
std::stringstream joined;
joined << segments[0];
for (auto i = 1; i < segments.size(); i++) {
joined << "." << segments[i];
}
return joined.str();
} catch (ParseError) {
return cppTypeName;
}
}
std::string name;
std::vector<TypeInfo> params;
private:
static const std::unordered_map<std::string, std::string> translation;
};
const std::unordered_map<std::string, std::string> TypeInfo::translation = {
{"bool", "bool"}, {"double", "float"}, {"float", "float"},
{"int32_t", "int"}, {"int8_t", "int"}, {"size_t", "int"},
{"std::string", "str"}, {"std::tuple", "tuple"}, {"std::vector", "list"},
{"uint32_t", "int"}, {"uint64_t", "int"},
};
// a parser takes:
// 1. a string to parse
// 2. location to parse from (index of character)
// returns:
// 1. parsing result (type T)
// 2. end location of substring which is consumed by parsing
// throws exception when failed to parse
template <typename T>
using Parser = std::function<std::pair<T, int>(const std::string&, const int)>;
std::pair<Unit, int> parse_blank(const std::string& text, const int begin) {
auto now = begin;
while (now < text.length() && isblank(text[now]))
now += 1;
return {unit, now};
}
Parser<Unit> parse_non_blank_char(char ch) {
return [=](const std::string& text, const int begin) -> std::pair<Unit, int> {
auto blankEnd = parse_blank(text, begin).second;
if (blankEnd >= text.length() || text[blankEnd] != ch)
throw ParseError{};
return {unit, blankEnd + 1};
};
}
Parser<std::string> parse_allowed_chars(std::function<bool(char)> allow) {
return [=](const std::string& text,
const int begin) -> std::pair<std::string, int> {
auto now = begin;
while (now < text.length() && allow(text[now]))
now += 1;
return {text.substr(begin, now - begin), now};
};
}
template <typename T>
Parser<std::tuple<T>> parse_seq(Parser<T> only) {
return [=](const std::string& text,
const int begin) -> std::pair<std::tuple<T>, int> {
auto res = only(text, begin);
return {{res.first}, res.second};
};
}
template <typename Head, typename... Tail>
Parser<std::tuple<Head, Tail...>> parse_seq(Parser<Head> head, Parser<Tail>... tail) {
return [=](const std::string& text,
const int begin) -> std::pair<std::tuple<Head, Tail...>, int> {
std::pair<Head, int> headRes = head(text, begin);
std::pair<std::tuple<Tail...>, int> tailRes =
parse_seq(tail...)(text, headRes.second);
return {std::tuple_cat(std::tuple<Head>(headRes.first), tailRes.first),
tailRes.second};
};
}
template <typename T>
Parser<std::vector<T>> parse_many_at_least0(Parser<T> one) {
return [=](const std::string& text,
const int begin) -> std::pair<std::vector<T>, int> {
std::vector<T> ret;
auto now = begin;
try {
while (true) {
auto oneRes = one(text, now);
ret.emplace_back(oneRes.first);
now = oneRes.second;
}
} catch (ParseError) {
}
return {ret, now};
};
}
template <typename C>
Parser<std::vector<C>> parse_sep_by_at_least1(
Parser<Unit> separator, Parser<C> component) {
return [=](const std::string& text,
const int begin) -> std::pair<std::vector<C>, int> {
std::vector<C> ret;
auto headRes = component(text, begin);
ret.emplace_back(headRes.first);
auto tailRes = parse_many_at_least0(parse_seq(separator, component))(
text, headRes.second);
for (const auto& elem : tailRes.first) {
ret.emplace_back(std::get<1>(elem));
}
return {ret, tailRes.second};
};
}
std::pair<std::string, int> parse_identifier(const std::string& text, const int begin) {
auto blankEnd = parse_blank(text, begin).second;
auto indentRes = parse_allowed_chars(
[](char ch) { return std::isalnum(ch) || ch == '_'; })(text, blankEnd);
if (indentRes.first.empty())
throw ParseError{};
return indentRes;
};
std::pair<std::string, int> parse_qualified(const std::string& text, const int begin) {
auto blankEnd = parse_blank(text, begin).second;
auto indentRes = parse_allowed_chars([](char ch) {
return std::isalnum(ch) || ch == '_' || ch == ':';
})(text, blankEnd);
if (indentRes.first.empty())
throw ParseError{};
return indentRes;
};
std::pair<std::vector<std::string>, int> parse_namespace(
const std::string& text, const int begin) {
auto res = parse_many_at_least0(parse_seq(
parse_non_blank_char(':'), parse_non_blank_char(':'),
Parser<std::string>(parse_identifier)))(text, begin);
std::vector<std::string> ret;
for (const auto& elem : res.first) {
ret.emplace_back(std::get<2>(elem));
}
return {ret, res.second};
}
std::pair<TypeInfo, int> parse_leaf_type(const std::string& text, const int begin) {
auto ret = parse_qualified(text, begin);
return {TypeInfo(ret.first), ret.second};
};
std::pair<TypeInfo, int> parse_node_type(const std::string& text, const int begin) {
auto nameRes = parse_qualified(text, begin);
auto ret = TypeInfo(nameRes.first);
auto now = parse_non_blank_char('<')(text, nameRes.second).second;
auto argsRes = parse_sep_by_at_least1(
parse_non_blank_char(','), Parser<TypeInfo>(parse_type))(text, now);
ret.params = argsRes.first;
now = parse_non_blank_char('>')(text, argsRes.second).second;
return {ret, now};
};
std::pair<TypeInfo, int> parse_type(const std::string& text, const int begin) {
try {
return parse_node_type(text, begin);
} catch (ParseError) {
}
return parse_leaf_type(text, begin);
};
std::string cpp_type_to_python_type(const std::string& input) {
auto res = parse_type(input, 0);
return res.first.to_python_type_string();
}
struct Initproc {
std::string func;
Initproc(std::string&& s) : func(std::move(s)) {}
......@@ -25,6 +259,10 @@ private:
void emit_py_init();
void emit_py_getsetters();
void emit_py_methods();
void emit_py_init_proxy();
void emit_py_init_methoddef(
const std::unordered_map<std::string, std::vector<std::string>>&
enum_attr_members);
Initproc emit_initproc();
MgbOp& op;
......@@ -248,10 +486,18 @@ void $0(PyTypeObject& py_type) {
}
Initproc OpDefEmitter::emit() {
std::unordered_map<std::string, std::vector<std::string>> enum_attr_members;
for (auto&& i : op.getMgbAttributes()) {
if (auto attr = llvm::dyn_cast<MgbEnumAttr>(&i.attr)) {
subclasses.push_back(
EnumAttrEmitter(op.getCppClassName(), attr, os, env()).emit());
auto retType = cpp_type_to_python_type(std::string(attr->getReturnType()));
enum_attr_members[retType] = std::vector<std::string>();
for (const auto& member : attr->getEnumMembers()) {
enum_attr_members[retType].emplace_back(member);
}
}
}
......@@ -259,6 +505,8 @@ Initproc OpDefEmitter::emit() {
emit_py_init();
emit_py_getsetters();
emit_py_methods();
emit_py_init_proxy();
emit_py_init_methoddef(enum_attr_members);
return emit_initproc();
}
......@@ -318,6 +566,8 @@ PyOpDefBegin($_self) // {
static PyMethodDef tp_methods[];
$0
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
static PyObject* py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds);
static PyMethodDef py_init_methoddef;
// };
PyOpDefEnd($_self)
)",
......@@ -438,6 +688,55 @@ void OpDefEmitter::emit_py_methods() {
&ctx, llvm::join(method_items, "\n "));
}
void OpDefEmitter::emit_py_init_proxy() {
os << tgfmt(
R"(
PyObject *PyOp($_self)::py_init_proxy(PyObject *self, PyObject *args, PyObject *kwds) {
if (PyOp($_self)::py_init(self, args, kwds) < 0) {
return NULL;
}
Py_RETURN_NONE;
}
)",
&ctx);
}
void OpDefEmitter::emit_py_init_methoddef(
const std::unordered_map<std::string, std::vector<std::string>>&
enum_attr_members) {
std::string docstring = "__init__(self";
for (const auto& attr : op.getMgbAttributes()) {
if (attr.name == "workspace_limit")
continue;
auto pyType = cpp_type_to_python_type(std::string(attr.attr.getReturnType()));
auto findRes = enum_attr_members.find(pyType);
if (findRes != enum_attr_members.end()) {
pyType = formatv("Union[str, {0}]", pyType);
// TODO stubgen cannot handle Literal strings for now
// auto members = findRes->second;
// std::string enumTypeString = "Literal[";
// enumTypeString += formatv("'{0}'", lowercase(members[0]));
// for (auto i = 1; i < members.size(); i++) {
// enumTypeString += formatv(", '{0}'", lowercase(members[i]));
// }
// enumTypeString += "]";
// pyType = enumTypeString;
}
docstring += formatv(", {0}: {1} = ...", attr.name, pyType);
}
docstring += ") -> None\\n";
os << tgfmt(
R"(
PyMethodDef PyOp($_self)::py_init_methoddef = {
"__init__",
(PyCFunction)PyOp($_self)::py_init_proxy,
METH_VARARGS | METH_KEYWORDS,
"$0"
};
)",
&ctx, docstring);
}
Initproc OpDefEmitter::emit_initproc() {
std::string initproc = formatv("_init_py_{0}", op.getCppClassName());
std::string subclass_init_call;
......@@ -460,6 +759,10 @@ void $0(py::module m) {
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
py_type.tp_dict = PyDict_New();
PyObject* descr = PyDescr_NewMethod(&PyOpType($_self), &PyOp($_self)::py_init_methoddef);
PyDict_SetItemString(py_type.tp_dict, "__init__", descr);
mgb_assert(PyType_Ready(&py_type) >= 0);
$1
PyType_Modified(&py_type);
......@@ -486,4 +789,4 @@ bool gen_op_def_python_c_extension(raw_ostream& os, llvm::RecordKeeper& keeper)
os << "\n";
return false;
}
} // namespace mlir::tblgen
\ No newline at end of file
} // namespace mlir::tblgen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册