gen_tablegen.py 6.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import collections
import textwrap
import os
import hashlib
import struct
import io

from gen_param_defs import member_defs, ParamDef, IndentWriterBase

14 15 16 17 18
# FIXME: move supportToString flag definition into the param def source file
# (param name, enum name) pairs whose generated MgbEnumAttr record receives
# an extra trailing ", 1" argument, requesting a ToStringTrait for that enum.
ENUM_TO_STRING_SPECIAL_RULES = [
    ("Elemwise", "Mode"),
    ("ElemwiseMultiType", "Mode"),
]
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93

class ConverterWriter(IndentWriterBase):
    """Render megdnn param definitions as TableGen records.

    ``__call__`` wraps the output in an ``MGB_PARAM`` include guard and hands
    ``defs`` to the inherited ``_process``, which presumably dispatches to the
    ``_on_*`` hooks below (NOTE(review): traversal lives in IndentWriterBase).
    Legacy params are skipped, except that their enum ``def`` records are
    still emitted by ``_on_member_enum`` so that other enums can alias them.
    """

    # TableGen attribute class for each scalar C type whose default value is
    # emitted verbatim.
    _SCALAR_ATTRS = {
        'uint32_t': 'MgbUI32Attr',
        'uint64_t': 'MgbUI64Attr',
        'int32_t': 'MgbI32Attr',
        'float': 'MgbF32Attr',
        'double': 'MgbF64Attr',
        'bool': 'MgbBoolAttr',
    }

    # Per-param state, (re)set in _on_param_begin / _on_param_end.
    _skip_current_param = False  # True while visiting a legacy param
    _last_param = None           # param def currently being visited
    _current_tparams = None      # "<attr>:$<field>" operand strings
    _packed = None               # cleared once a field needs MgbParamBase
    _const = None                # names of const fields seen so far

    def __call__(self, fout, defs):
        """Write TableGen records for every param in ``defs`` to ``fout``."""
        super().__call__(fout)
        self._write("// %s", self._get_header())
        self._write("#ifndef MGB_PARAM")
        self._write("#define MGB_PARAM")
        self._process(defs)
        self._write("#endif // MGB_PARAM")

    def _ctype2attr(self, ctype, value):
        """Map a field's C type name to ``(attr class, default expression)``.

        A ``DTypeEnum`` field also marks the current param as not packed, so
        ``_on_param_end`` emits it with ``MgbParamBase`` instead of
        ``MgbPackedParamBase``.

        Raises:
            RuntimeError: if ``ctype`` is not a supported type.
        """
        attr = self._SCALAR_ATTRS.get(ctype)
        if attr is not None:
            return attr, value
        if ctype == 'DTypeEnum':
            self._packed = False
            return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value)
        raise RuntimeError("unknown ctype: {!r}".format(ctype))

    def _on_param_begin(self, p):
        """Start param ``p``: flag legacy params, else reset per-param state."""
        self._last_param = p
        if p.is_legacy:
            self._skip_current_param = True
            return
        self._packed = True
        self._current_tparams = []
        self._const = set()

    def _on_param_end(self, p):
        """Emit the record collected for ``p`` and clear per-param state.

        Packed params produce ``class <Name>ParamBase<string accessor>`` plus a
        ``def <Name>Param`` instantiating it with ``"param"``; non-packed ones
        produce a single ``def <Name>Param: MgbParamBase<...>``.
        """
        if self._skip_current_param:
            self._skip_current_param = False
            return
        if self._packed:
            self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1)
        else:
            self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1)
        self._write("let fields = (ins", indent=1)
        # One operand per line, continued at the current indentation level.
        self._write(",\n{}".format(self._cur_indent).join(self._current_tparams))
        self._write(");", indent=-1)
        self._write("}\n", indent=-1)
        if self._packed:
            self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name))
        self._current_tparams = None
        self._packed = None
        self._const = None

    def _wrapped_with_default_value(self, attr, default):
        """Wrap ``attr`` in ``MgbDefaultValuedAttr`` with C++ default ``default``."""
        return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default)

    @staticmethod
    def _enum_default_value(fullname, enum_name, src, default):
        """Return the C++ expression for an enum default.

        ``src`` is the enum whose members/combination rules apply (the enum
        itself, or the source enum of an alias); ``default`` is what the param
        def stores as the default.
        """
        if src.combined:
            return "static_cast<{}::{}>({})".format(
                fullname, enum_name, src.compose_combined_enum(default))
        return "{}::{}::{}".format(fullname, enum_name, src.members[default])

    def _on_member_enum(self, e):
        """Emit the ``def`` for enum ``e`` and, unless skipped, its operand."""
        p = self._last_param

        # Note: always generate llvm Record def for enum attribute even it was not
        # directly used by any operator, or other enum couldn't alias to this enum
        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        members = ','.join('"{}"'.format(str(m)) for m in e.members)
        enum_def = 'MgbEnumAttr<"{}", "{}", [{}], {}'.format(
            fullname, e.name, members, 1 if e.combined else 0)
        if (p.name, e.name) in ENUM_TO_STRING_SPECIAL_RULES:
            enum_def += ", 1"  # whether generate ToStringTrait
        enum_def += ">"
        self._write("def {} : {};".format(td_class, enum_def))

        if self._skip_current_param:
            return

        default_val = self._enum_default_value(fullname, e.name, e, e.default)
        wrapped = self._wrapped_with_default_value(td_class, default_val)
        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

    def _on_member_enum_alias(self, e):
        """Emit an ``MgbEnumAliasAttr`` def for ``e`` plus its operand."""
        p = self._last_param
        if self._skip_current_param:
            return

        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        base_td_class = "{}{}".format(e.src_class, e.src_name)
        enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format(fullname, e.name, base_td_class)
        self._write("def {} : {};".format(td_class, enum_def))

        # Member names / combination rules come from the aliased source enum.
        default_val = self._enum_default_value(
            fullname, e.name, e.src_enum, e.get_default())
        wrapped = self._wrapped_with_default_value(td_class, default_val)
        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

    def _on_member_field(self, f):
        """Append a scalar/DType field operand with its default value.

        A default naming one of this param's const fields (seen so far) is
        qualified as ``::megdnn::param::<Param>::<const>``.
        """
        if self._skip_current_param:
            return
        attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
        if value in self._const:
            value = '::megdnn::param::{}::{}'.format(self._last_param.name, value)
        wrapped = self._wrapped_with_default_value(attr, value)
        self._current_tparams.append("{}:${}".format(wrapped, f.name))

    def _on_const_field(self, f):
        """Record a const field name so later defaults referencing it get qualified."""
        # Legacy params never initialise ``_const`` (it is None), so skip them
        # like every other member hook does.
        if self._skip_current_param:
            return
        self._const.add(str(f.name))

def main():
    """Command-line entry point: ``<script> INPUT OUTPUT``.

    Executes INPUT as Python with ``pdef`` bound to ``ParamDef`` and ``Doc``
    bound to ``member_defs.Doc``. It then writes TableGen records for
    ``ParamDef.all_param_defs`` to OUTPUT. The SHA-256 hex digest of INPUT's
    text is handed to the writer via ``set_input_hash``.
    """
    parser = argparse.ArgumentParser(
        description='generate op param tablegen file')
    parser.add_argument('input')
    parser.add_argument('output')
    args = parser.parse_args()

    # Use UTF-8 explicitly (the same encoding used for hashing below) rather
    # than the platform's locale-dependent default.
    with open(args.input, encoding='utf-8') as fin:
        inputs = fin.read()
    # NOTE: INPUT is executed as Python code; only pass trusted param-def files.
    exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
    input_hash = hashlib.sha256(inputs.encode('utf-8')).hexdigest()

    writer = ConverterWriter()
    with open(args.output, 'w', encoding='utf-8') as fout:
        writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)


if __name__ == "__main__":
    main()