#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Generate per-mode elemwise multi-type kernel implementation files.

For every arity in a mode table, every (src_ctype, dst_ctype) pair and every
mode of that arity, one source file is written that defines the ``KERN_IMPL_*``
macros and then includes ``../kern_impl.inl``.
"""

import os
import argparse
import itertools

from gen_elemwise_multi_type_utils import (
    SUPPORT_DTYPES,
    MODES,
    SUPPORT_QINT32_DTYPES,
    QINT32_MODES,
)

# Supported ``--type`` values and the file extension used for their sources.
_KERNEL_EXTENSIONS = {'cuda': 'cu'}


def generate(modes, support_dtypes, output, cpp_ext):
    """Write one kernel-impl file per (mode, src_ctype, dst_ctype) combination.

    Args:
        modes: mapping whose keys are arities and whose values are iterables
            of mode names for that arity (presumably the MODES /
            QINT32_MODES tables -- verify against gen_elemwise_multi_type_utils).
        support_dtypes: iterable of sequences whose first two items are the
            source and destination ctype names.
        output: directory the files are written into; must already exist.
        cpp_ext: extension (without the dot) for generated files, e.g. ``'cu'``.
    """
    for anum, ctype in itertools.product(modes.keys(), support_dtypes):
        print('{} : {}'.format(anum, ctype))
        src_ctype, dst_ctype = ctype[0], ctype[1]
        for mode in modes[anum]:
            fname = os.path.join(
                output,
                '{}_{}_{}.{}'.format(mode, src_ctype, dst_ctype, cpp_ext))
            lines = [
                '// generated by gen_elemwise_multi_type_kern_impls.py',
                '#define KERN_IMPL_MODE(cb) {}'.format(
                    'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode)),
                '#define KERN_IMPL_ARITY {}'.format(anum),
                '#define KERN_IMPL_STYPE {}'.format(src_ctype),
                '#define KERN_IMPL_DTYPE {}'.format(dst_ctype),
                '#include "../kern_impl.inl"',
            ]
            # One buffered write; each line is newline-terminated, matching
            # what print(line, file=fout) would emit.
            with open(fname, 'w', encoding='utf-8') as fout:
                fout.write('\n'.join(lines) + '\n')
            print('generated {}'.format(fname))


def main():
    """Parse the command line and generate all kernel-impl files."""
    parser = argparse.ArgumentParser(
        description='generate elemwise impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--type', type=str,
                        choices=sorted(_KERNEL_EXTENSIONS), default='cuda',
                        help='generate cuda kernel file')
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of isdir() + makedirs();
    # it still raises if `output` exists but is not a directory.
    os.makedirs(args.output, exist_ok=True)

    # A table lookup instead of `assert`: asserts vanish under `python -O`,
    # which previously could leave cpp_ext unbound.
    cpp_ext = _KERNEL_EXTENSIONS[args.type]
    generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext)
    generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext)
    # Touch the output directory's timestamp; presumably the build system
    # keys on the directory mtime -- confirm before removing.
    os.utime(args.output)


if __name__ == '__main__':
    main()