gen_elemwise_multi_type_kern_impls.py 1.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import argparse
import itertools
from gen_elemwise_multi_type_utils import SUPPORT_DTYPES, MODES, SUPPORT_QINT32_DTYPES, QINT32_MODES

def generate(modes, support_dtypes, output, cpp_ext):
    """Write one kernel-implementation source file per (mode, dtype pair).

    For every arity key in ``modes`` and every ``(src_ctype, dst_ctype)``
    entry in ``support_dtypes``, each mode listed under that arity gets a
    file ``<mode>_<src_ctype>_<dst_ctype>.<cpp_ext>`` inside ``output``.
    Each file only defines the ``KERN_IMPL_*`` macros and then includes
    ``../kern_impl.inl``.

    Args:
        modes: mapping from arity (written into ``KERN_IMPL_ARITY``) to an
            iterable of mode names.
        support_dtypes: iterable of sequences whose first two items are the
            source and destination ctype names.
        output: directory the files are written into; it must already exist.
        cpp_ext: file extension without the leading dot (e.g. ``'cu'``).
    """
    for anum, ctype in itertools.product(modes.keys(), support_dtypes):
        print('{} : {}'.format(anum, ctype))
        src_ctype = ctype[0]
        dst_ctype = ctype[1]
        for mode in modes[anum]:
            formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode)
            fname = os.path.join(
                output,
                '{}_{}_{}.{}'.format(mode, src_ctype, dst_ctype, cpp_ext))
            lines = [
                '// generated by gen_elemwise_multi_type_kern_impls.py',
                '#define KERN_IMPL_MODE(cb) {}'.format(formode),
                '#define KERN_IMPL_ARITY {}'.format(anum),
                '#define KERN_IMPL_STYPE {}'.format(src_ctype),
                '#define KERN_IMPL_DTYPE {}'.format(dst_ctype),
                '#include "../kern_impl.inl"',
            ]
            # Explicit encoding keeps the generated sources independent of
            # the build machine's locale; one write replaces per-line prints.
            with open(fname, 'w', encoding='utf-8') as fout:
                fout.write('\n'.join(lines) + '\n')

            print('generated {}'.format(fname))


def main():
    """Parse command-line arguments and generate every kernel impl file.

    Creates the output directory if it does not exist and generates files
    from two tables: (MODES, SUPPORT_DTYPES) and (QINT32_MODES,
    SUPPORT_QINT32_DTYPES). It then updates the output directory's
    timestamps.
    """
    # Single source of truth for supported backends and their source-file
    # extension; also feeds argparse `choices` so the two cannot drift.
    ext_by_type = {'cuda': 'cu'}

    parser = argparse.ArgumentParser(
        description='generate elemwise impl files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--type', type=str, choices=sorted(ext_by_type),
                        default='cuda', help='generate cuda kernel file')
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.output, exist_ok=True)

    # argparse already rejects unknown --type values, so this lookup cannot
    # fail; unlike the former `assert`, it is not stripped under `python -O`.
    cpp_ext = ext_by_type[args.type]

    generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext)
    generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext)
    # Touch the output directory's timestamps; presumably so that build
    # tooling notices regeneration — NOTE(review): confirm the consumer.
    os.utime(args.output)

if __name__ == '__main__':
    # Only run the generator when invoked as a script, never on import.
    main()