gen_header.py 2.5 KB
Newer Older
J
jiakai 已提交
1 2
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
王博文 已提交
3 4 5 6 7 8 9
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
J
jiakai 已提交
10 11

import argparse
J
jiakai 已提交
12
import sys
J
jiakai 已提交
13
import subprocess
J
jiakai 已提交
14

J
jiakai 已提交
15
# Signature that every input trace file must start with; main() reads
# len(MAGIC) characters from each input and rejects the file on mismatch.
MAGIC = 'midout_trace v1\n'
J
jiakai 已提交
16 17 18 19 20 21 22 23 24 25 26

class MidoutHeaderGen:
    """Collect demangled midout region names and emit a C++ header for them.

    Names are registered one at a time with :meth:`add_item`; :meth:`write`
    then prints a single ``MIDOUT_GENERATED`` macro definition that
    forward-declares every tag class seen and specializes every region seen
    with ``enable = true``.
    """

    # Class-level placeholders; every instance gets its own sets in __init__.
    _tag_names = None
    _region_names = None

    def __init__(self):
        self._tag_names = set()
        self._region_names = set()

    def add_item(self, name: str) -> None:
        """Register one demangled region name.

        Args:
            name: a string starting with ``midout::Region<midout::tags::``.
                The text between that prefix and the first following comma
                is recorded as the tag name; the full string is recorded as
                the region name. Duplicates are ignored (both are sets).

        Raises:
            AssertionError: if ``name`` does not start with the prefix.
        """
        prefix = 'midout::Region<midout::tags::'
        # Raised explicitly (instead of ``assert``) so that the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not name.startswith(prefix):
            raise AssertionError('bad name: {!r}'.format(name))
        # If no comma follows the prefix, find() returns -1 and the slice
        # below drops only the last character of the name.
        # NOTE(review): presumably that character is the closing '>' of a
        # region without extra template arguments -- confirm.
        comma = name.find(',', len(prefix))
        self._tag_names.add(name[len(prefix):comma])
        self._region_names.add(name)

    def write(self, fout) -> None:
        """Print the generated header to ``fout``.

        Every line except the last ends with a backslash, so the whole
        output is one ``#define MIDOUT_GENERATED`` macro definition. Tags and
        regions are both emitted in sorted order so that the same set of
        items always produces byte-identical output.

        Args:
            fout: writable text file object passed to ``print(file=...)``.
        """
        print('#define MIDOUT_GENERATED \\', file=fout)
        print('namespace midout { namespace tags { \\', file=fout)
        for tag in sorted(self._tag_names):
            print('class {}; \\'.format(tag), file=fout)
        # Closes ``namespace tags``; ``namespace midout`` stays open for the
        # specializations below.
        print('} \\', file=fout)

        # Sorted: iterating the set directly gives an order that can differ
        # between runs because of string hash randomization.
        for region in sorted(self._region_names):
            # The specializations are emitted inside ``namespace midout``,
            # so the qualifier is removed everywhere in the name.
            spec = region.replace('midout::', '')
            # NOTE(review): '__ndk1::' looks like the Android NDK libc++
            # inline namespace leaking into demangled names -- confirm.
            spec = spec.replace('__ndk1::', '')
            print('template<> \\', file=fout)
            print('struct {} {{ static constexpr bool enable = true; }}; \\'.
                  format(spec), file=fout)

        print('}', file=fout)


def main():
    """Command-line entry point: turn midout trace files into a header.

    Each input file must start with ``MAGIC``. The remainder of the file is
    piped through ``c++filt -t`` and every non-empty output line is
    registered with a ``MidoutHeaderGen``. The resulting header is written
    to the ``--output`` path, or to stdout when no output path is given.

    Raises:
        AssertionError: if an input file does not start with ``MAGIC``, or
            if a demangled line is rejected by ``MidoutHeaderGen.add_item``.
        subprocess.CalledProcessError: if ``c++filt`` exits with a non-zero
            status.
    """
    parser = argparse.ArgumentParser(
        description='generate header file from midout traces',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-o', '--output',
                        help='output header file; default to stdout')
    parser.add_argument('inputs', nargs='+', help='input trace files')
    args = parser.parse_args()

    gen = MidoutHeaderGen()
    for path in args.inputs:
        with open(path) as fin:
            # Raised explicitly (instead of ``assert``) so that the check
            # still runs under ``python -O``; the file name is included
            # because several inputs may be given.
            if fin.read(len(MAGIC)) != MAGIC:
                raise AssertionError('bad trace file: {!r}'.format(path))
            # Read the rest verbatim. Joining the lines with '\n' would
            # double every newline, since each line already ends with one.
            mangled = fin.read()
        demangled = subprocess.check_output(
            ['c++filt', '-t'], input=mangled.encode('utf-8'))
        for line in demangled.decode('utf-8').split('\n'):
            line = line.strip()
            if line:
                gen.add_item(line)

    if not args.output:
        gen.write(sys.stdout)
    else:
        with open(args.output, 'w') as fout:
            gen.write(fout)
J
jiakai 已提交
73 74 75

# Run the command-line interface only when executed as a script, not when
# this module is imported.
if __name__ == '__main__':
    main()