gen_tablegen.py 5.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import collections
import textwrap
import os
import hashlib
import struct
import io

from gen_param_defs import member_defs, ParamDef, IndentWriterBase

14 15 16 17 18
# FIXME: move supportToString flag definition into the param def source file
ENUM_TO_STRING_SPECIAL_RULES = [
    ("Elemwise", "Mode"),
    ("ElemwiseMultiType", "Mode")
]
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93

class ConverterWriter(IndentWriterBase):
    _skip_current_param = False
    _last_param = None
    _current_tparams = None
    _packed = None
    _const = None

    def __call__(self, fout, defs):
        super().__call__(fout)
        self._write("// %s", self._get_header())
        self._write("#ifndef MGB_PARAM")
        self._write("#define MGB_PARAM")
        self._process(defs)
        self._write("#endif // MGB_PARAM")

    def _ctype2attr(self, ctype, value):
        if ctype == 'uint32_t':
            return 'MgbUI32Attr', value
        if ctype == 'uint64_t':
            return 'MgbUI64Attr', value
        if ctype == 'int32_t':
            return 'MgbI32Attr', value
        if ctype == 'float':
            return 'MgbF32Attr', value
        if ctype == 'double':
            return 'MgbF64Attr', value
        if ctype == 'bool':
            return 'MgbBoolAttr', value
        if ctype == 'DTypeEnum':
            self._packed = False
            return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value)
        raise RuntimeError("unknown ctype")

    def _on_param_begin(self, p):
        self._last_param = p
        if p.is_legacy:
            self._skip_current_param = True
            return
        self._packed = True
        self._current_tparams = []
        self._const = set()

    def _on_param_end(self, p):
        if self._skip_current_param:
            self._skip_current_param = False
            return
        if self._packed:
            self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1)
        else:
            self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1)
        self._write("let fields = (ins", indent=1)
        self._write(",\n{}".format(self._cur_indent).join(self._current_tparams))
        self._write(");", indent=-1)
        self._write("}\n", indent=-1)
        if self._packed:
            self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name))
        self._current_tparams = None
        self._packed = None
        self._const = None

    def _wrapped_with_default_value(self, attr, default):
        return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default)

    def _on_member_enum(self, e):
        p = self._last_param

        # Note: always generate llvm Record def for enum attribute even it was not
        # directly used by any operator, or other enum couldn't alias to this enum
        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name)
        def format(v):
            return '\"{}\"'.format(str(v))
        enum_def += ','.join(format(i) for i in e.members)
94 95 96 97
        enum_def += "]"
        if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)):
            enum_def += ", 1" # whether generate ToStringTrait
        enum_def += ">"
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
        self._write("def {} : {};".format(td_class, enum_def))

        if self._skip_current_param:
            return

        # wrapped with default value
        default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default)
        wrapped = self._wrapped_with_default_value(td_class, default_val)

        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

    def _on_member_enum_alias(self, e):
        p = self._last_param
        if self._skip_current_param:
            return

        # write enum attr def
        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        base_td_class = "{}{}".format(e.src_class, e.src_name)
        enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class)
        self._write("def {} : {};".format(td_class, enum_def))

        # wrapped with default value
        default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default())
        wrapped = self._wrapped_with_default_value(td_class, default_val)

        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))


    def _on_member_field(self, f):
        if self._skip_current_param:
            return
        attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
        if str(value) in self._const:
            value = '::megdnn::param::{}::{}'.format(self._last_param.name, value)
        wrapped = self._wrapped_with_default_value(attr, value)
        self._current_tparams.append("{}:${}".format(wrapped, f.name))

    def _on_const_field(self, f):
        self._const.add(str(f.name))

def main():
    parser = argparse.ArgumentParser('generate op param tablegen file')
    parser.add_argument('input')
    parser.add_argument('output')
    args = parser.parse_args()

    with open(args.input) as fin:
        inputs = fin.read()
        exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
        input_hash = hashlib.sha256()
        input_hash.update(inputs.encode(encoding='UTF-8'))
        input_hash = input_hash.hexdigest()

    writer = ConverterWriter()
    with open(args.output, 'w') as fout:
        writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)

if __name__ == "__main__":
    main()