#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Generate a header (bin_reduce.h) that strips code unused by a network.

Inputs are either JSON network-info files (as written by
``megbrain.serialize_comp_graph_to_file()``) or midout trace files, which are
recognized by their ``MIDOUT_TRACE_MAGIC`` first line. The collected dtypes,
operator types, elemwise modes and graph hashes are written out as C
preprocessor ``#define`` lines.
"""
import sys
import re

if sys.version_info[0] != 3 or sys.version_info[1] < 5:
    print('This script requires Python version 3.5')
    sys.exit(1)

import argparse
import json
import os
import subprocess
import tempfile
from pathlib import Path

# First line of every trace file produced by the midout library; used by
# main() to tell midout traces apart from JSON network-info files.
MIDOUT_TRACE_MAGIC = 'midout_trace v1\n'


class HeaderGen:
    """Accumulate network traits from all inputs and emit them as macros.

    Feed every input through extend_netinfo() / extend_midout(), then call
    generate() once with an open, writable text file.
    """

    _dtypes = None
    _oprs = None
    _fout = None
    _elemwise_modes = None
    _has_netinfo = False
    _midout_files = None
    _file_without_hash = False

    def __init__(self):
        self._dtypes = set()
        self._oprs = set()
        self._elemwise_modes = set()
        self._graph_hashes = set()
        self._midout_files = []

    # Class-wide cache for get_megvii3_root().
    _megvii3_root_cache = None

    @classmethod
    def get_megvii3_root(cls):
        """Return the closest ancestor of this script's directory holding a
        ``WORKSPACE`` file; the result is cached on the class.

        Raises:
            RuntimeError: if no such ancestor exists.
        """
        if cls._megvii3_root_cache is not None:
            return cls._megvii3_root_cache
        wd = Path(__file__).resolve().parent
        while wd.parent != wd:
            workspace_file = wd / 'WORKSPACE'
            if workspace_file.is_file():
                cls._megvii3_root_cache = str(wd)
                return cls._megvii3_root_cache
            wd = wd.parent
        raise RuntimeError('This script is supposed to run in megvii3.')

    def extend_netinfo(self, data):
        """Merge one parsed JSON network-info dict.

        ``data`` must contain the iterables ``dtypes``, ``opr_types`` and
        ``elemwise_modes``; ``hash`` is optional. If any input lacks a hash,
        no graph-hash macro is emitted at all (see _write_hash).
        """
        self._has_netinfo = True
        if 'hash' not in data:
            self._file_without_hash = True
        else:
            self._graph_hashes.add(str(data['hash']))
        self._dtypes.update(data['dtypes'])
        self._oprs.update(data['opr_types'])
        self._elemwise_modes.update(data['elemwise_modes'])

    def extend_midout(self, fname):
        """Record a midout trace file to be converted by _write_midout()."""
        self._midout_files.append(fname)

    def generate(self, fout):
        """Write all macros to the text file ``fout``.

        The network-info macros are only emitted when at least one JSON input
        was seen.
        """
        self._fout = fout
        try:
            self._write_def('MGB_BINREDUCE_VERSION', '20190219')
            if self._has_netinfo:
                self._write_dtype()
                self._write_elemwise_modes()
                self._write_oprs()
                self._write_hash()
            self._write_midout()
        finally:
            # Do not keep a reference to the caller's file object, even when
            # one of the writers above raised.
            del self._fout

    def strip_opr_name_with_version(self, name):
        """Remove a trailing ``V<digits>`` suffix (``ConvolutionV1`` -> ``Convolution``)."""
        t = re.search(r'V\d+$', name)
        return name[:t.start()] if t else name

    def _write_oprs(self):
        """Emit MGB_OPR_REGISTRY_CALLER_SPECIALIZE for the used operators.

        Every used operator class (version suffix stripped, so ``FooV1`` and
        ``Foo`` share one entry) is forward-declared in ``namespace opr``, and
        ``OprRegistryCaller`` is specialized for it. Names are sorted so the
        generated header is identical across runs.
        """
        opr_names = sorted(
            {self.strip_opr_name_with_version(i) for i in self._oprs})
        # The leading '}' and the trailing reopened 'namespace serialization {'
        # suggest the macro is expanded inside ``namespace serialization``
        # (presumably nested in ``namespace mgb``) -- confirm at the use site.
        defs = ['}', 'namespace opr {']
        defs.extend('class {};'.format(name) for name in opr_names)
        defs.append('}')
        defs.append('namespace serialization {')
        defs.append(
            'template<class Opr, class Callee>\n'
            'struct OprRegistryCaller {\n'
            '};')
        # '{}' receives the operator name; '{{' / '}}' are literal braces.
        spec_template = (
            'template<class Callee>\n'
            'struct OprRegistryCaller<opr::{}, Callee>:\n'
            '        public OprRegistryCallerDefaultImpl<Callee> {{\n'
            '}};')
        defs.extend(spec_template.format(name) for name in opr_names)
        self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs)

    def _write_elemwise_modes(self):
        """Emit one enable macro per Elemwise mode known to MegDNN.

        The full mode list comes from running MegDNN's gen_param_defs.py in
        the repository root. For each mode M,
        ``_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_M(_cb)`` expands to ``_cb(M)`` if
        the network uses M, and to nothing otherwise.
        """
        with tempfile.NamedTemporaryFile() as ftmp:
            fpath = os.path.realpath(ftmp.name)
            subprocess.check_call(
                ['./brain/megbrain/dnn/scripts/gen_param_defs.py',
                 '--write-enum-items', 'Elemwise:Mode',
                 './brain/megbrain/dnn/scripts/opr_param_defs.py',
                 fpath],
                cwd=self.get_megvii3_root()
            )
            with open(fpath) as fin:
                # Skip blank lines; they would produce a macro named
                # "_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_".
                mode_list = [line.strip() for line in fin if line.strip()]
        for mode in mode_list:
            if mode in self._elemwise_modes:
                content = '_cb({})'.format(mode)
            else:
                content = ''
            self._write_def(
                '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(mode),
                content)
        self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
                        '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')

    def _write_dtype(self):
        """Disable float16 support when no network uses Float16."""
        if 'Float16' not in self._dtypes:
            # MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
            # support in the past; however `FLOT16' is really a typo. We plan to
            # change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
            # To prevent issues in the transition, we decide to define both
            # macros (`FLOT16' and `FLOAT16') here.
            #
            # In the future when the situation is settled and no one would ever
            # use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
            # safely deleted.
            self._write_def('MEGDNN_DISABLE_FLOT16', 1)
            self._write_def('MEGDNN_DISABLE_FLOAT16', 1)

    def _write_hash(self):
        """Emit MGB_BINREDUCE_GRAPH_HASHES as a comma-separated ULL list.

        This is skipped, with a warning on stderr, if any input lacked a hash.
        Hashes are sorted so the output is reproducible.
        """
        if self._file_without_hash:
            print('WARNING: network info has no graph hash. Using json file '
                  'generated by MegBrain >= 7.28.0 is recommended',
                  file=sys.stderr)
        else:
            defs = ','.join(
                '{}ULL'.format(h) for h in sorted(self._graph_hashes))
            self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs)

    def _write_def(self, name, val):
        """Write ``#define name val``; a list ``val`` becomes one macro line
        per element, with backslash continuations."""
        if isinstance(val, list):
            val = '\n'.join(val)
        val = str(val).strip().replace('\n', ' \\\n')
        self._fout.write('#define {} {}\n'.format(name, val))

    def _write_midout(self):
        """Append the output of midout's gen_header.py for all trace files."""
        if not self._midout_files:
            return
        gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout',
                           'gen_header.py')
        cvt = subprocess.run(
            [gen] + self._midout_files,
            stdout=subprocess.PIPE, check=True,
        ).stdout.decode('utf-8')
        self._fout.write('// midout \n')
        self._fout.write(cvt)


def main():
    """Command-line entry point: classify each input and write the header."""
    parser = argparse.ArgumentParser(
        description='generate header file for reducing binary size by '
        'stripping unused oprs in a particular network; output file would '
        'be written to bin_reduce.h',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        'inputs', nargs='+',
        help='input files that describe specific traits of the network; '
        'can be one of the following:'
        ' 1. json files generated by '
        'megbrain.serialize_comp_graph_to_file() in python; '
        ' 2. trace files generated by midout library')
    parser.add_argument('-o', '--output', help='output file',
                        default=os.path.join(HeaderGen.get_megvii3_root(),
                                             'utils', 'bin_reduce.h'))
    args = parser.parse_args()

    gen = HeaderGen()
    for i in args.inputs:
        print('==== processing {}'.format(i))
        with open(i) as fin:
            # A midout trace is identified by its magic first line; anything
            # else is parsed as JSON network info.
            if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
                gen.extend_midout(i)
            else:
                fin.seek(0)
                gen.extend_netinfo(json.load(fin))

    with open(args.output, 'w') as fout:
        gen.generate(fout)


if __name__ == '__main__':
    main()