# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import argparse import collections import textwrap import os import hashlib import struct class member_defs: """contain classes to define members of an opr param""" Dtype = collections.namedtuple('Dtype', ['cname', 'pycvt', 'pyfmt', 'cppjson', 'cname_attr']) Dtype.__new__.__defaults__ = ('', ) uint32 = Dtype('uint32_t', 'int', 'I', 'NumberInt') uint64 = Dtype('uint64_t', 'int', 'Q', 'NumberInt', 'alignas(sizeof(uint64_t)) ') int32 = Dtype('int32_t', 'int', 'i', 'NumberInt') float32 = Dtype('float', 'float', 'f', 'Number') float64 = Dtype('double', 'float', 'd', 'Number') dtype = Dtype('DTypeEnum', '_as_dtype_num', 'I', 'Number') bool = Dtype('bool', 'bool', '?', 'Bool') class Base: pass class Doc: """wrap an identifier to associate document note: if the doc starts with a linebreak, it would not be reformatted. 
# NOTE(review): this file arrived whitespace-collapsed. All of the generator's
#   statements are packed onto a handful of physical lines with no
#   indentation. Taken literally, the first line is one big comment, and the
#   "# TODO: use ctor with given default value" inside PyWriter._on_param_end
#   comments out the rest of its line. The script cannot run as shown.
#   Restore the original line breaks and indentation from version control
#   before editing.
# NOTE(review): text is also missing wherever a "<...>" span stood (the
#   pattern looks like HTML-tag stripping):
#   * ParamDef.__init__ breaks off at "struct.pack('" and the text resumes at
#     " 0: self._indent()". Everything in between is gone: the rest of
#     ParamDef and the writer base class(es). IndentWriterBase (base class of
#     PyWriter and CPPWriter) is not defined in the remaining text. Neither
#     are the members the writers rely on: _write, _indent, _process,
#     _get_header, set_input_hash, _cur_class. Do not hand-reconstruct them.
#   * CPPWriter emits '#include ' with no header name. It is presumably
#     "<cstring>", since the generated ctor calls memset -- TODO confirm.
#   * PyWriter emits 'py::class_(m, ...)' without a template argument. It is
#     presumably "<opdef::{0}>" -- TODO confirm.
#
# Purpose: generate op-def parameter code from a Python description file.
#   main() reads the input file and exec()s it with "pdef" bound to ParamDef
#   and "Doc" bound to member_defs.Doc. Each ParamDef appends itself to
#   ParamDef.all_param_defs. main() then passes the SHA-256 hex digest of the
#   input text to writer.set_input_hash and writes one of two outputs:
#   * "-t c++" (default), CPPWriter: a header (#pragma once) with one
#     "struct <Name>Def" per param, inside namespace mgb::imperative::opdef.
#     - Enum members become "using" aliases of
#       megdnn::param::<cur_class>::<Enum>.
#     - Data members are wrapped in "union { struct { ... }; }".
#     - The generated ctor memset()s *this to 0 before assigning its
#       arguments; the rationale is in CPPWriter._on_param_end's docstring.
#     - param() returns a megdnn::param::<cur_class> built from the ctor
#       arguments, in order.
#   * "-t py", PyWriter: pybind11 registration code.
#     - One py::enum_ per enum, re-exported through m.attr when the same
#       source enum was already bound.
#     - A class with def_readwrite for data/enum members and
#       def_property_readonly_static for constants.
#   The member kinds a ParamDef can hold live in member_defs:
#   * Field and Const.
#   * Enum: registered in Enum.all_enums under (param_name, name). Its data
#     field name defaults to the enum name with a lower-cased first letter.
#   * EnumAlias: points at another param's Enum and falls back to that
#     enum's default index.
#
# NOTE(review): defects visible in the remaining code (left unchanged here):
#   * member_defs.Enum.__init__ resolves a str "default" against name_field
#     instead of members. name_field is a str, or None until get_name_field
#     fills it in. So "in" is a substring test, .index() yields a character
#     offset, and a None name_field raises TypeError. The intent looks like a
#     lookup in members (Doc.__eq__ already accepts a plain str).
#   * member_defs.Doc.__eq__ uses the bare name Doc. Class-body names are not
#     visible inside methods, no module-level Doc is visible in this file,
#     and exec()'s globals dict does not change this module's globals. So
#     comparing a Doc with anything other than a str would raise NameError.
#     Elsewhere the file spells it member_defs.Doc.
#   * CPPWriter._on_const_field emits "static const <T> X = v;" for every
#     dtype whose cname lacks "int", including float/double. Standard C++
#     requires constexpr for in-class initialisers of non-integral static
#     members. Check whether any description declares float consts.
#   * exec() runs the description file as arbitrary Python. Only feed this
#     script trusted, in-repo definition files.
""" __slots__ = ['id', 'doc'] def __init__(self, id_, doc): assert isinstance(id_, str) and isinstance(doc, str), (id_, doc) self.id = id_ self.doc = doc @property def no_reformat(self): """whether reformat is disallowed for this doc string""" return self.doc.startswith('\n') @property def raw_lines(self): """the doc lines when ``no_reformat`` is true""" ret = self.doc.split('\n') assert not ret[0] return ret[1:] @classmethod def make(cls, v): """make doc object from str or doc""" if isinstance(v, cls): return v assert isinstance(v, str) return cls(v, '') def __str__(self): return self.id def __eq__(self, rhs): if isinstance(rhs, str): return self.id == rhs return (isinstance(rhs, Doc) and (self.id, self.doc) == (rhs.id, rhs.doc)) class Enum(Base): """define an enum; the result would contain both an enum class def and its corresponding data field :param default: index of default member value :attr name_field: name of the data field of this enum in the param struct :attr member_alias: list of (member, alias) pairs """ __slots__ = ['name', 'name_field', 'members', 'default', 'member_alias'] all_enums = {} """(param_name, name) => enum""" def __init__(self, param_name, name, name_field, members, default, member_alias): name = member_defs.Doc.make(name) assert name.id[0].isupper() members = tuple(map(member_defs.Doc.make, members)) if isinstance(default, str): if default not in name_field: raise ValueError( "Default value '{}' does not exist.".format(default)) default = name_field.index(default) assert isinstance(default, int) self.name = name self.name_field = self.get_name_field(name.id, name_field) self.members = members self.default = default self.all_enums[(param_name, name.id)] = self assert isinstance(member_alias, list) self.member_alias = member_alias @classmethod def get_name_field(cls, name, name_field): if name_field is None: name_field = name[0].lower() + name[1:] assert isinstance(name_field, str) return name_field class Field(Base): """define a normal data 
field""" __slots__ = ['name', 'dtype', 'default'] def __init__(self, name, dtype, default): assert isinstance(dtype, member_defs.Dtype) self.name = member_defs.Doc.make(name) self.dtype = dtype self.default = default class Const(Base): """define a const data field""" __slots__ = ['name', 'dtype', 'default'] def __init__(self, name, dtype, default): assert isinstance(dtype, member_defs.Dtype) self.name = member_defs.Doc.make(name) self.dtype = dtype self.default = default class EnumAlias(Base): """alias of enum type from another param""" __slots__ = ['name', 'name_field', 'src_class', 'src_name', 'default'] def __init__(self, name, name_field, src_class, src_name, default): self.name = name self.name_field = member_defs.Enum.get_name_field(name, name_field) self.src_class = src_class if src_name is None: src_name = name self.src_name = src_name self.default = default @property def src_enum(self): """source Enum class""" return member_defs.Enum.all_enums[(self.src_class, self.src_name)] def get_default(self): """get default index; fallback to src index if default is not set""" if self.default is None: return self.src_enum.default return self.default class ParamDef: """""" __all_tags = set() all_param_defs = [] __slots__ = ['name', 'members', 'tag', 'is_legacy'] def __init__(self, name, doc='', *, version=0, is_legacy=False): self.members = [] self.all_param_defs.append(self) h = hashlib.sha256(name.encode('utf-8')) if version: h.update(struct.pack(' 0: self._indent() class PyWriter(IndentWriterBase): _static_members = None _non_static_members = None _enums = None _enum_map = None def __call__(self, fout, defs): super().__call__(fout) self._enum_map = {} self._write('// %s', self._get_header()) self._write('#include "megbrain/imperative/opdef/all.h"') self._write('') self._write('using namespace mgb::imperative;') self._write('') self._process(defs) def _on_param_begin(self, p): self._enums = [] self._non_static_members = [] self._static_members = [] def 
_reg_enum_single(self, cur_def, e): alias = None if isinstance(e, member_defs.Enum): src = e else: assert isinstance(e, member_defs.EnumAlias) src = e.src_enum alias = e src_py_name = self._enum_map.get(src, None) if src_py_name is not None: py_name = '{}{}Enum'.format(cur_def, src.name if alias is None else alias.name) self._write('m.attr("{}") = m.attr("{}");\n'.format(py_name, src_py_name)) return if alias is None: enum_name = str(src.name) else: enum_name = str(alias.name) c_name = 'opdef::{}::{}'.format(cur_def, enum_name) py_name = '{}{}Enum'.format(cur_def, enum_name) self._write('py::enum_<{}>(m, "{}")'.format(c_name, py_name), indent=1) for i in src.members: self._write('.value("{0}", {1}::{0})'.format(i, c_name)) self._write(';\n', indent=-1) self._enum_map[src] = py_name def _on_param_end(self, p): cur_def = '{}Def'.format(p.name) for e in self._enums: self._reg_enum_single(cur_def, e) self._write('py::class_(m, "{0}")'.format(cur_def), indent=1) # TODO: use ctor with given default value self._write('.def(py::init<>())') for i in self._static_members: assert isinstance(i, member_defs.Const) self._write('.def_property_readonly_static("{0}", []() {{ return opdef::{1}::{0}; }})'.format(i.name, cur_def)) for i in self._non_static_members: fname = None if isinstance(i, member_defs.Field): fname = i.name else: assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias)) fname = i.name_field self._write('.def_readwrite("{0}", &opdef::{1}::{0})'.format(fname, cur_def)) self._write(';\n', indent=-1) def _on_member_enum(self, e,): self._enums.append(e) self._non_static_members.append(e) def _on_member_enum_alias(self, e): self._enums.append(e) self._non_static_members.append(e) def _on_member_field(self, f): self._non_static_members.append(f) def _on_const_field(self, f): self._static_members.append(f) class CPPWriter(IndentWriterBase): _param_namespace = 'opdef' _ctor_args = None """list of (text in func param, var name); func param name must be var name 
appended by an underscore""" _non_static_members = None def __call__(self, fout, defs): super().__call__(fout) self._write('// %s', self._get_header()) self._write('#pragma once') self._write('#include "megdnn.h"') # which defined in megbrain/tools/param_defs/mgb_opr_param_defs.py self._write('#include "megbrain/opr/param_defs.h"') self._write('#include ') self._write('namespace mgb {') self._write('namespace imperative {') self._write('namespace %s {', self._param_namespace) self._write('namespace {') self._write('#include "megdnn/dtype.h"') self._write('using DTypeEnum = megdnn::DTypeEnum;') self._write('} // anonymous namespace') self._process(defs) self._write('} // namespace %s', self._param_namespace) self._write('} // namespace imperative') self._write('} // namespace mgb') self._write('// vim: syntax=cpp.doxygen') def _on_param_begin(self, p): self._write('struct %sDef {', p.name, indent=1) self._ctor_args = [] self._non_static_members = [] def _add_ctor_args(self, typename, default, varname): self._ctor_args.append(( '{} {}_={}'.format(typename, varname, default), varname)) def _on_param_end(self, p): ''' MegDNN param structures are not packed and we need to initialize the structure paddings to zero or it would break MegBrain hash system. We do memset(0) in default ctor and use a trick, wrapping non-static members in a anonymous union which would copy the object representation in its default copy/move ctor, for copy/move ctor. > The implicitly-defined copy/move constructor for a non-union class X performs > a memberwise copy/move of its bases and members. [class.copy.ctor 14] > The implicitly-defined copy/move constructor for a union X copies the object > representation (6.9) of X. 
[class.copy.ctor 15] ''' if self._non_static_members: self._write('union { struct {') for i in self._non_static_members: if isinstance(i, member_defs.Field): self._write('%s%s %s;', i.dtype.cname_attr, i.dtype.cname, i.name) else: assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias)) self._write('%s %s;', i.name, i.name_field) self._write('}; };') param_list = [] if self._ctor_args: pdefs, varnames = zip(*self._ctor_args) self._write('%sDef(%s) {', p.name, ', '.join(pdefs), indent=1) self._write('memset(this, 0, sizeof(*this));') for var in varnames: self._write('this->%s = %s_;', var, var) param_list.append(str(var)) self._write('}', indent=-1) self._write('megdnn::param::%s param() {', self._cur_class, indent=1) self._write('return {%s};', ','.join(param_list)) self._write('}', indent=-1) self._write('};\n', indent=-1) def __on_member_enum(self, e, default_value): self._write('using %s = megdnn::param::%s::%s;', e.name, self._cur_class, e.name) self._non_static_members.append(e) self._add_ctor_args(e.name, default_value, e.name_field) def _on_member_enum(self, e,): self.__on_member_enum(e, '{}::{}'.format(e.name, e.members[e.default])) def _on_member_enum_alias(self, e): self.__on_member_enum(e, '{}::{}'.format(e.name, e.src_enum.members[e.get_default()])) def _on_member_field(self, f): self._non_static_members.append(f) self._add_ctor_args(f.dtype.cname, f.default, f.name) def _on_const_field(self, f): if 'int' in f.dtype.cname: self._write('static constexpr %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default) else: self._write('static const %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default) def main(): parser = argparse.ArgumentParser( 'generate opr param defs from description file') parser.add_argument('-t', '--type', choices=['c++', 'py'], default='c++', help='output type') parser.add_argument('input') parser.add_argument('output') args = parser.parse_args() with open(args.input) as fin: inputs = fin.read() 
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) input_hash = hashlib.sha256() input_hash.update(inputs.encode(encoding='UTF-8')) input_hash = input_hash.hexdigest() if args.type == 'py': writer = PyWriter() else: writer = CPPWriter() with open(args.output, 'w') as fout: writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) if __name__ == '__main__': main()