Commit 421bcfd3 — Author: Megvii Engine Team

style(mgb/tools): add format for tools, dnn and ci

GitOrigin-RevId: 5684e5ea43385ae7f3802c5c2e779362a030d643
Parent: 99309fa3
......@@ -8,7 +8,7 @@
import enum
import os.path
import shutil
from typing import Tuple, List
from typing import List, Tuple
from library import *
......
......@@ -5,14 +5,13 @@
#
import enum
import os.path
import shutil
import functools
import operator
import os.path
import shutil
from library import *
###################################################################################################
#
# Data structure modeling a GEMM operation
......
from generator import (
GenerateGemmOperations,
GenerateGemvOperations,
from generator import ( # isort: skip; isort: skip
GenerateConv2dOperations,
GenerateDeconvOperations,
GenerateDwconv2dFpropOperations,
GenerateDwconv2dDgradOperations,
GenerateDwconv2dFpropOperations,
GenerateDwconv2dWgradOperations,
GenerateGemmOperations,
GenerateGemvOperations,
)
......@@ -35,12 +35,14 @@ def write_op_list(f, gen_op, gen_type):
if gen_op != "gemv":
f.write(' "all_%s_%s_operations.cu",\n' % (gen_op, gen_type))
# Write down a list of merged filenames
def write_merge_file_name(f, gen_op, gen_type, split_number):
    """Write the quoted filenames of the merged generated sources to *f*.

    Emits one ``"<op>_<type>_<i>.cu",`` entry per split file, plus the
    ``"all_<op>_<type>_operations.cu",`` entry for every op except ``gemv``.
    (Diff-residue fix: the interleaved old/new lines wrote each entry twice.)
    """
    for i in range(0, split_number):
        f.write(' "{}_{}_{}.cu",\n'.format(gen_op, gen_type, i))
    if gen_op != "gemv":
        f.write(' "all_{}_{}_operations.cu",\n'.format(gen_op, gen_type))
if __name__ == "__main__":
with open("list.bzl", "w") as f:
......
......@@ -4,11 +4,11 @@
# \brief Generates the CUTLASS Library's instances
#
import enum
import re
###################################################################################################
import enum
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
# as the default 3.5.2 on Ubuntu 16.04.
......
......@@ -8,9 +8,9 @@ import enum
import os.path
import shutil
from library import *
from gemm_operation import *
from conv2d_operation import *
from gemm_operation import *
from library import *
###################################################################################################
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import os
from gen_elemwise_utils import DTYPES
def main():
    """CLI entry point: emit one cond_take kernel-instantiation .cu file per dtype.

    Reads DTYPES from the sibling gen_elemwise_utils module; each generated file
    includes ../kern.inl and instantiates inst_genidx/inst_copy for that dtype.
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["cuda"],
        default="cuda",
        help="generate cuda cond take kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    if not os.path.isdir(args.output):
        os.makedirs(args.output)

    # Only the CUDA backend is supported by this generator.
    assert args.type == "cuda"
    cpp_ext = "cu"

    for dtype in DTYPES.keys():
        fname = "{}.{}".format(dtype, cpp_ext)
        fname = os.path.join(args.output, fname)
        with open(fname, "w") as fout:
            w = lambda s: print(s, file=fout)
            w("// generated by gen_cond_take_kern_impls.py")
            w('#include "../kern.inl"')
            w("")
            # float16 kernels are optional: guard them behind the macro.
            if dtype == "dt_float16" or dtype == "dt_bfloat16":
                w("#if !MEGDNN_DISABLE_FLOAT16")
            w("namespace megdnn {")
            w("namespace cuda {")
            w("namespace cond_take {")
            w("")
            w("inst_genidx(::megdnn::dtype::{})".format(DTYPES[dtype][0]))
            w("#undef inst_genidx")
            w("")
            w("inst_copy(::megdnn::dtype::{})".format(DTYPES[dtype][0]))
            w("#undef inst_copy")
            w("#undef inst_copy_")
            w("")
            w("} // cond_take")
            w("} // cuda")
            w("} // megdnn")
            if dtype == "dt_float16" or dtype == "dt_bfloat16":
                w("#endif")
        print("generated {}".format(fname))
    # Touch the directory so build systems see the outputs as refreshed.
    os.utime(args.output)


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import itertools
import os
# Kernel-name prefixes per instruction family; the bool marks kernels whose
# instantiation needs an extra workspace pointer.
PREFIXES = {
    "dp4a": [
        ("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True),
        ("batch_conv_bias_int8_gemm_ncdiv4hw4", False),
        ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False),
    ]
}
# Nonlinearity id -> (enum member name, filename suffix).
ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")}
# Bias mode id -> (bias visitor class, filename suffix).
BIASES = {
    1: ("PerElementBiasVisitor", "_per_elem"),
    2: ("PerChannelBiasVisitor", "_per_chan"),
}
SUFFIXES = {"dp4a": [""], "imma": [""]}
def main():
parser = argparse.ArgumentParser(
description='generate cuda batch conv bias (dp4a/imma) kern impl files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--type', type=str, choices=['dp4a',
'imma'],
default='dp4a', help='generate cuda conv bias kernel file')
parser.add_argument('output', help='output directory')
description="generate cuda batch conv bias (dp4a/imma) kern impl files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--type",
type=str,
choices=["dp4a", "imma"],
default="dp4a",
help="generate cuda conv bias kernel file",
)
parser.add_argument("output", help="output directory")
args = parser.parse_args()
if not os.path.isdir(args.output):
os.makedirs(args.output)
inst = '''
inst = """
template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS,
IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>(
const int8_t* d_src,
......@@ -41,7 +51,7 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS,
const ConvParam& param,
float alpha,
float beta,
cudaStream_t stream);'''
cudaStream_t stream);"""
for prefix in PREFIXES[args.type]:
for suffix in SUFFIXES[args.type]:
......@@ -52,17 +62,23 @@ template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS,
fname = os.path.join(args.output, fname)
with open(fname, "w") as fout:
w = lambda s: print(s, file=fout)
w('// generated by gen_batch_cuda_conv_bias_kern_impls.py')
cur_inst = inst.replace("PREFIX", prefix[0]).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0])
w("// generated by gen_batch_cuda_conv_bias_kern_impls.py")
cur_inst = (
inst.replace("PREFIX", prefix[0])
.replace("SUFFIX", suffix)
.replace("BIAS", bias[0])
.replace("ACTIVATION", act[0])
)
if has_workspace:
cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ")
else:
cur_inst = cur_inst.replace("WORKSPACE", "")
cur_inst = cur_inst.replace("WORKSPACE", "")
w('#include "../{}{}.cuinl"'.format(prefix[0], suffix))
w(cur_inst)
print('generated {}'.format(fname))
print("generated {}".format(fname))
os.utime(args.output)
if __name__ == '__main__':
if __name__ == "__main__":
main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import itertools
import os
# Kernel-name prefix per instruction family.
PREFIXES = {
    "dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4",
    "imma": "conv_bias_int8_implicit_gemm",
}
# Nonlinearity id -> (enum member name, filename suffix).
ACTIVATIONS = {1: ("IDENTITY", "_id"), 2: ("RELU", "_relu"), 3: ("H_SWISH", "_hswish")}
# Bias mode id -> (bias visitor class, filename suffix).
BIASES = {
    1: ("PerElementBiasVisitor", "_per_elem"),
    2: ("PerChannelBiasVisitor", "_per_chan"),
}
# Kernel variants generated for each instruction family.
SUFFIXES = {
    "dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"],
    "imma": [
        "_imma16x16x16_cdiv4hwn4",
        "_imma8x32x16_cdiv4hwn4",
        "_imma32x8x16_cdiv4hwn4",
        "_imma16x16x16_cdiv4hwn4_reorder_filter",
        "_imma8x32x16_cdiv4hwn4_reorder_filter",
        "_imma32x8x16_cdiv4hwn4_reorder_filter",
        "_imma16x16x16_cdiv4hwn4_unroll_width",
        "_imma8x32x16_cdiv4hwn4_unroll_width",
        "_imma32x8x16_cdiv4hwn4_unroll_width",
    ],
}
def main():
parser = argparse.ArgumentParser(
description='generate cuda conv bias (dp4a/imma) kern impl files',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--type', type=str, choices=['dp4a',
'imma'],
default='dp4a', help='generate cuda conv bias kernel file')
parser.add_argument('output', help='output directory')
description="generate cuda conv bias (dp4a/imma) kern impl files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--type",
type=str,
choices=["dp4a", "imma"],
default="dp4a",
help="generate cuda conv bias kernel file",
)
parser.add_argument("output", help="output directory")
args = parser.parse_args()
if not os.path.isdir(args.output):
os.makedirs(args.output)
inst = '''
inst = """
template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS,
IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>(
const int8_t* d_src,
......@@ -43,7 +61,7 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS,
const ConvParam& param,
float alpha,
float beta,
cudaStream_t stream);'''
cudaStream_t stream);"""
for suffix in SUFFIXES[args.type]:
for _, act in ACTIVATIONS.items():
......@@ -53,13 +71,19 @@ template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS,
fname = os.path.join(args.output, fname)
with open(fname, "w") as fout:
w = lambda s: print(s, file=fout)
w('// generated by gen_cuda_conv_bias_kern_impls.py')
cur_inst = inst.replace("PREFIX", prefix).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0])
w("// generated by gen_cuda_conv_bias_kern_impls.py")
cur_inst = (
inst.replace("PREFIX", prefix)
.replace("SUFFIX", suffix)
.replace("BIAS", bias[0])
.replace("ACTIVATION", act[0])
)
w('#include "../{}{}.cuinl"'.format(prefix, suffix))
w(cur_inst)
print('generated {}'.format(fname))
print("generated {}".format(fname))
os.utime(args.output)
if __name__ == '__main__':
if __name__ == "__main__":
main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import os
from gen_elemwise_utils import ARITIES, MODES
def main():
    """CLI entry point: write the MEGDNN_FOREACH_ELEMWISE_MODE_* macro table.

    Uses ARITIES and MODES from the sibling gen_elemwise_utils module; output
    goes to the single file named on the command line (each_mode.inl).
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise each mode",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    with open(args.output, "w") as fout:
        w = lambda s: print(s, file=fout)
        w("// generated by gen_elemwise_each_mode.py")
        # Sort for a deterministic macro order across runs.
        keys = list(MODES.keys())
        keys.sort()
        for (anum, ctype) in keys:
            w(
                "#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\".format(
                    ARITIES[anum], ctype
                )
            )
            for mode in MODES[(anum, ctype)]:
                w(" MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\".format(mode))
            w("")
    print("generated each_mode.inl")
    os.utime(args.output)


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import itertools
import os
from gen_elemwise_utils import ARITIES, DTYPES, MODES
def main():
    """CLI entry point: emit one elemwise kernel-impl file per (mode, ctype).

    ARITIES/DTYPES/MODES come from the sibling gen_elemwise_utils module; the
    file extension follows the selected backend (cuda/hip/cpp).
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["cuda", "hip", "cpp"],
        default="cpp",
        help="generate cuda/hip kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    if not os.path.isdir(args.output):
        os.makedirs(args.output)

    if args.type == "cuda":
        cpp_ext = "cu"
    elif args.type == "hip":
        cpp_ext = "cpp.hip"
    else:
        assert args.type == "cpp"
        cpp_ext = "cpp"

    for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()):
        # MODES is keyed by (arity, dtype category), e.g. (2, "FLOAT").
        for mode in MODES[(anum, DTYPES[ctype][1])]:
            formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode)
            fname = "{}_{}.{}".format(mode, ctype, cpp_ext)
            fname = os.path.join(args.output, fname)
            with open(fname, "w") as fout:
                w = lambda s: print(s, file=fout)
                w("// generated by gen_elemwise_kern_impls.py")
                # float16 kernels are optional: guard them behind the macro.
                if ctype == "dt_float16" or ctype == "dt_bfloat16":
                    w("#if !MEGDNN_DISABLE_FLOAT16")
                w("#define KERN_IMPL_MODE(cb) {}".format(formode))
                w("#define KERN_IMPL_ARITY {}".format(anum))
                w("#define KERN_IMPL_CTYPE {}".format(ctype))
                w('#include "../kern_impl.inl"')
                if ctype == "dt_float16" or ctype == "dt_bfloat16":
                    w("#endif")
            print("generated {}".format(fname))
    # Touch the directory so build systems see the outputs as refreshed.
    os.utime(args.output)


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import itertools
from gen_elemwise_multi_type_utils import SUPPORT_DTYPES, MODES, SUPPORT_QINT32_DTYPES, QINT32_MODES
import os
from gen_elemwise_multi_type_utils import ( # isort: skip; isort: skip
MODES,
QINT32_MODES,
SUPPORT_DTYPES,
SUPPORT_QINT32_DTYPES,
)
def generate(modes, support_dtypes, output, cpp_ext):
    """Write one multi-type elemwise kernel-impl file per (mode, src, dst).

    modes: dict mapping arity -> list of mode names.
    support_dtypes: iterable of (src_ctype, dst_ctype) pairs.
    output: existing output directory.
    cpp_ext: file extension for the generated sources (e.g. "cu").
    """
    for anum, ctype in itertools.product(modes.keys(), support_dtypes):
        print("{} : {}".format(anum, ctype))
        src_ctype = ctype[0]
        dst_ctype = ctype[1]
        for mode in modes[anum]:
            formode = "MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)".format(mode)
            fname = "{}_{}_{}.{}".format(mode, src_ctype, dst_ctype, cpp_ext)
            fname = os.path.join(output, fname)
            with open(fname, "w") as fout:
                w = lambda s: print(s, file=fout)
                w("// generated by gen_elemwise_multi_type_kern_impls.py")
                w("#define KERN_IMPL_MODE(cb) {}".format(formode))
                w("#define KERN_IMPL_ARITY {}".format(anum))
                w("#define KERN_IMPL_STYPE {}".format(src_ctype))
                w("#define KERN_IMPL_DTYPE {}".format(dst_ctype))
                w('#include "../kern_impl.inl"')
            print("generated {}".format(fname))
def main():
    """CLI entry point: generate CUDA multi-type elemwise kernel impl files.

    Delegates to the sibling generate() using MODES/SUPPORT_DTYPES and
    QINT32_MODES/SUPPORT_QINT32_DTYPES from gen_elemwise_multi_type_utils.
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["cuda"],
        default="cuda",
        help="generate cuda kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    if not os.path.isdir(args.output):
        os.makedirs(args.output)

    # Only the CUDA backend is supported by this generator.
    assert args.type == "cuda"
    if args.type == "cuda":
        cpp_ext = "cu"

    generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext)
    generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext)
    # Touch the directory so build systems see the outputs as refreshed.
    os.utime(args.output)


if __name__ == "__main__":
    main()
# CUDA currently does not support quint8, so it is ignored here.
SUPPORT_DTYPES = [("dt_qint8", "dt_qint8")]
# (src_dtype, dst_dtype) pairs with qint32 on one side.
SUPPORT_QINT32_DTYPES = [
    ("dt_qint32", "dt_qint8"),
    ("dt_qint8", "dt_qint32"),
    ("dt_qint4", "dt_qint32"),
    ("dt_quint4", "dt_qint32"),
]
SUPPORT_DTYPES_Q4 = [("dt_qint4", "dt_qint4"), ("dt_quint4", "dt_quint4")]
SUPPORT_QINT32_DTYPES_Q4 = [("dt_qint32", "dt_qint4"), ("dt_qint32", "dt_quint4")]
SUPPORT_ARRITY2_DTYPES = [
    "dt_int32",
    "dt_uint8",
    "dt_int8",
    "dt_int16",
    "dt_bool",
    "dt_float32",
    "dt_float16",
    "dt_bfloat16",
]
SUPPORT_ARRITY1_DTYPES = ["dt_float32", "dt_float16", "dt_bfloat16"]
# arity -> list of elemwise mode names to instantiate.
MODES = {
    1: [
        "RELU",
        "ABS",
        "NEGATE",
        "ACOS",
        "ASIN",
        "CEIL",
        "COS",
        "EXP",
        "EXPM1",
        "FLOOR",
        "LOG",
        "LOG1P",
        "SIGMOID",
        "SIN",
        "TANH",
        "FAST_TANH",
        "ROUND",
        "ERF",
        "ERFINV",
        "ERFC",
        "ERFCINV",
        "H_SWISH",
        "SILU",
        "GELU",
    ],
    2: [
        "ABS_GRAD",
        "ADD",
        "FLOOR_DIV",
        "MAX",
        "MIN",
        "MOD",
        "MUL",
        "SIGMOID_GRAD",
        "SUB",
        "SWITCH_GT0",
        "TANH_GRAD",
        "LT",
        "LEQ",
        "EQ",
        "FUSE_ADD_RELU",
        "TRUE_DIV",
        "POW",
        "LOG_SUM_EXP",
        "FUSE_ADD_TANH",
        "FAST_TANH_GRAD",
        "FUSE_ADD_SIGMOID",
        "ATAN2",
        "H_SWISH_GRAD",
        "FUSE_ADD_H_SWISH",
        "SILU_GRAD",
        "GELU_GRAD",
    ],
    3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"],
}
# Mode subset available for 4-bit quantized dtypes.
QINT4_MODES = {
    1: [
        "RELU",
        "ABS",
        "NEGATE",
        "CEIL",
        "FLOOR",
        "SIGMOID",
        "TANH",
        "FAST_TANH",
        "ROUND",
        "H_SWISH",
    ],
    2: [
        "ADD",
        "MAX",
        "MIN",
        "MUL",
        "SUB",
        "SWITCH_GT0",
        "LT",
        "LEQ",
        "EQ",
        "FUSE_ADD_RELU",
        "FUSE_ADD_TANH",
        "FUSE_ADD_SIGMOID",
        "FUSE_ADD_H_SWISH",
    ],
    3: ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"],
}
# Mode subset available when qint32 is involved.
QINT32_MODES = {
    1: ["RELU", "SIGMOID", "TANH", "FAST_TANH", "H_SWISH"],
    2: [
        "ADD",
        "FUSE_ADD_RELU",
        "FUSE_ADD_SIGMOID",
        "FUSE_ADD_TANH",
        "FUSE_ADD_H_SWISH",
    ],
}
ARRITY1_BOOL_MODES = {
    1: ["ISINF", "ISNAN"],
}
ARRITY2_BOOL_MODES = {
    2: ["EQ", "LEQ", "NEQ", "LT"],
}
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
import os
from gen_elemwise_utils import DTYPES
def main():
    """CLI entry point: emit one elemwise "special" kernel file per dtype.

    DTYPES comes from the sibling gen_elemwise_utils module; each generated
    file includes ../special_kerns.inl and instantiates INST for the dtype.
    """
    parser = argparse.ArgumentParser(
        description="generate elemwise impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["cuda", "hip"],
        default="cuda",
        help="generate cuda/hip elemwise special kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    if not os.path.isdir(args.output):
        os.makedirs(args.output)

    if args.type == "cuda":
        cpp_ext = "cu"
    else:
        assert args.type == "hip"
        cpp_ext = "cpp.hip"

    for dtype in DTYPES.keys():
        fname = "special_{}.{}".format(dtype, cpp_ext)
        fname = os.path.join(args.output, fname)
        with open(fname, "w") as fout:
            w = lambda s: print(s, file=fout)
            w("// generated by gen_elemwise_special_kern_impls.py")
            # float16 kernels are optional: guard them behind the macro.
            if dtype == "dt_float16" or dtype == "dt_bfloat16":
                w("#if !MEGDNN_DISABLE_FLOAT16")
            w('#include "../special_kerns.inl"')
            w("INST(::megdnn::dtype::{})".format(DTYPES[dtype][0]))
            w("#undef INST")
            # Closes namespaces opened inside special_kerns.inl.
            w("}")
            w("}")
            if dtype == "dt_float16" or dtype == "dt_bfloat16":
                w("#endif")
        print("generated {}".format(fname))
    # Touch the directory so build systems see the outputs as refreshed.
    os.utime(args.output)


if __name__ == "__main__":
    main()
# Arity id -> macro-name fragment.
ARITIES = {1: "UNARY", 2: "BINARY", 3: "TERNARY"}
# dtype name -> (megdnn dtype class, dtype category used to key MODES).
DTYPES = {
    "dt_int32": ("Int32", "INT"),
    "dt_uint8": ("Uint8", "INT"),
    "dt_int8": ("Int8", "INT"),
    "dt_int16": ("Int16", "INT"),
    "dt_bool": ("Bool", "BOOL"),
    "dt_float32": ("Float32", "FLOAT"),
    "dt_float16": ("Float16", "FLOAT"),
    "dt_bfloat16": ("BFloat16", "FLOAT"),
}
# (arity, dtype category) -> elemwise modes instantiated for that combination.
MODES = {
    (1, "INT"): ["RELU", "ABS", "NEGATE"],
    (2, "INT"): [
        "ABS_GRAD",
        "ADD",
        "FLOOR_DIV",
        "MAX",
        "MIN",
        "MOD",
        "MUL",
        "SIGMOID_GRAD",
        "SUB",
        "SWITCH_GT0",
        "TANH_GRAD",
        "LT",
        "LEQ",
        "EQ",
        "FUSE_ADD_RELU",
        "SHL",
        "SHR",
        "RMULH",
    ],
    (3, "INT"): ["COND_LEQ_MOV", "COND_LT_MOV"],
    (1, "FLOAT"): [
        "RELU",
        "ABS",
        "NEGATE",
        "ACOS",
        "ASIN",
        "CEIL",
        "COS",
        "EXP",
        "EXPM1",
        "FLOOR",
        "LOG",
        "LOG1P",
        "SIGMOID",
        "SIN",
        "TANH",
        "FAST_TANH",
        "ROUND",
        "ERF",
        "ERFINV",
        "ERFC",
        "ERFCINV",
        "H_SWISH",
        "SILU",
        "GELU",
    ],
    (2, "FLOAT"): [
        "ABS_GRAD",
        "ADD",
        "FLOOR_DIV",
        "MAX",
        "MIN",
        "MOD",
        "MUL",
        "SIGMOID_GRAD",
        "SUB",
        "SWITCH_GT0",
        "TANH_GRAD",
        "LT",
        "LEQ",
        "EQ",
        "FUSE_ADD_RELU",
        "TRUE_DIV",
        "POW",
        "LOG_SUM_EXP",
        "FUSE_ADD_TANH",
        "FAST_TANH_GRAD",
        "FUSE_ADD_SIGMOID",
        "ATAN2",
        "H_SWISH_GRAD",
        "FUSE_ADD_H_SWISH",
        "SILU_GRAD",
        "GELU_GRAD",
    ],
    (3, "FLOAT"): ["COND_LEQ_MOV", "COND_LT_MOV", "FUSE_MUL_ADD3"],
    (1, "BOOL"): ["NOT"],
    (2, "BOOL"): ["AND", "OR", "XOR", "LT", "LEQ", "EQ"],
    (3, "BOOL"): [],
}
......@@ -3,13 +3,14 @@
import argparse
import collections
import textwrap
import os
import hashlib
import struct
import io
import os
import struct
import textwrap
from gen_param_defs import IndentWriterBase, ParamDef, member_defs
from gen_param_defs import member_defs, ParamDef, IndentWriterBase
class ConverterWriter(IndentWriterBase):
_skip_current_param = False
......@@ -20,7 +21,7 @@ class ConverterWriter(IndentWriterBase):
def __call__(self, fout, defs):
super().__call__(fout)
self._write("// %s", self._get_header())
self._write('#include <flatbuffers/flatbuffers.h>')
self._write("#include <flatbuffers/flatbuffers.h>")
self._write("namespace mgb {")
self._write("namespace serialization {")
self._write("namespace fbs {")
......@@ -33,8 +34,9 @@ class ConverterWriter(IndentWriterBase):
self._last_param = p
self._param_fields = []
self._fb_fields = ["builder"]
self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {",
p.name, indent=1)
self._write(
"template<>\nstruct ParamConverter<megdnn::param::%s> {", p.name, indent=1
)
self._write("using MegDNNType = megdnn::param::%s;", p.name)
self._write("using FlatBufferType = fbs::param::%s;\n", p.name)
......@@ -42,22 +44,22 @@ class ConverterWriter(IndentWriterBase):
if self._skip_current_param:
self._skip_current_param = False
return
self._write("static MegDNNType to_param(const FlatBufferType* fb) {",
indent=1)
line = 'return {'
line += ', '.join(self._param_fields)
line += '};'
self._write("static MegDNNType to_param(const FlatBufferType* fb) {", indent=1)
line = "return {"
line += ", ".join(self._param_fields)
line += "};"
self._write(line)
self._write("}\n", indent=-1)
self._write(
"static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {",
indent=1)
line = 'return fbs::param::Create{}('.format(str(p.name))
line += ', '.join(self._fb_fields)
line += ');'
indent=1,
)
line = "return fbs::param::Create{}(".format(str(p.name))
line += ", ".join(self._fb_fields)
line += ");"
self._write(line)
self._write('}', indent=-1)
self._write("}", indent=-1)
self._write("};\n", indent=-1)
......@@ -68,18 +70,23 @@ class ConverterWriter(IndentWriterBase):
return
self._param_fields.append(
"static_cast<megdnn::param::{}::{}>(fb->{}())".format(
str(p.name), str(e.name), e.name_field))
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format(
key, e.name_field))
str(p.name), str(e.name), e.name_field
)
)
self._fb_fields.append(
"static_cast<fbs::param::{}>(param.{})".format(key, e.name_field)
)
def _on_member_field(self, f):
if self._skip_current_param:
return
if f.dtype.cname == 'DTypeEnum':
if f.dtype.cname == "DTypeEnum":
self._param_fields.append(
"intl::convert_dtype_to_megdnn(fb->{}())".format(f.name))
"intl::convert_dtype_to_megdnn(fb->{}())".format(f.name)
)
self._fb_fields.append(
"intl::convert_dtype_to_fbs(param.{})".format(f.name))
"intl::convert_dtype_to_fbs(param.{})".format(f.name)
)
else:
self._param_fields.append("fb->{}()".format(f.name))
self._fb_fields.append("param.{}".format(f.name))
......@@ -93,28 +100,33 @@ class ConverterWriter(IndentWriterBase):
enum_name = e.src_class + e.src_name
self._param_fields.append(
"static_cast<megdnn::param::{}::{}>(fb->{}())".format(
e.src_class, e.src_name, e.name_field))
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format(
enum_name, e.name_field))
e.src_class, e.src_name, e.name_field
)
)
self._fb_fields.append(
"static_cast<fbs::param::{}>(param.{})".format(enum_name, e.name_field)
)
def main():
    """CLI entry point: generate FlatBuffers<->MegBrain param converters.

    Reads the param-definition file named by ``input``, executes it to
    populate ParamDef.all_param_defs, then writes the converter source
    (tagged with a sha256 of the input) to ``output``.
    """
    # NOTE(review): the string is passed positionally, so argparse treats it
    # as ``prog`` rather than ``description``; kept as-is for compatibility.
    parser = argparse.ArgumentParser(
        "generate convert functions between FlatBuffers type and MegBrain type"
    )
    parser.add_argument("input")
    parser.add_argument("output")
    args = parser.parse_args()

    with open(args.input) as fin:
        inputs = fin.read()
    # exec of the input file is intentional: the param-def file is trusted
    # build-time input, never untrusted user data.
    exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
    input_hash = hashlib.sha256()
    input_hash.update(inputs.encode(encoding="UTF-8"))
    input_hash = input_hash.hexdigest()

    writer = ConverterWriter()
    with open(args.output, "w") as fout:
        writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)


if __name__ == "__main__":
    main()
......@@ -3,13 +3,14 @@
import argparse
import collections
import textwrap
import os
import hashlib
import struct
import io
import os
import struct
import textwrap
from gen_param_defs import IndentWriterBase, ParamDef, member_defs
from gen_param_defs import member_defs, ParamDef, IndentWriterBase
def _cname_to_fbname(cname):
return {
......@@ -22,17 +23,19 @@ def _cname_to_fbname(cname):
"bool": "bool",
}[cname]
def scramble_enum_member_name(name):
    """Rewrite an enum member declaration so it is a valid FlatBuffers name.

    Replaces a ``1<<k``-style value with just the shift amount, and renames
    members called MIN/MAX (reserved-ish names) to MIN_/MAX_ while keeping
    any ``= value`` part intact.
    """
    s = name.find("<<")
    if s != -1:
        # "A = 1<<2" -> "A = 2": keep everything through '=', then the shift amount.
        name = name[0 : name.find("=") + 1] + " " + name[s + 2 :]
    if name in ("MIN", "MAX"):
        return name + "_"
    o_name = name.split(" ")[0].split("=")[0]
    if o_name in ("MIN", "MAX"):
        return name.replace(o_name, o_name + "_")
    return name
class FlatBuffersWriter(IndentWriterBase):
_skip_current_param = False
_last_param = None
......@@ -66,12 +69,13 @@ class FlatBuffersWriter(IndentWriterBase):
self._write("}\n", indent=-1)
def _write_doc(self, doc):
if not isinstance(doc, member_defs.Doc) or not doc.doc: return
if not isinstance(doc, member_defs.Doc) or not doc.doc:
return
doc_lines = []
if doc.no_reformat:
doc_lines = doc.raw_lines
else:
doc = doc.doc.replace('\n', ' ')
doc = doc.doc.replace("\n", " ")
text_width = 80 - len(self._cur_indent) - 4
doc_lines = textwrap.wrap(doc, text_width)
for line in doc_lines:
......@@ -101,7 +105,8 @@ class FlatBuffersWriter(IndentWriterBase):
default = e.compose_combined_enum(e.default)
else:
default = scramble_enum_member_name(
str(e.members[e.default]).split(' ')[0].split('=')[0])
str(e.members[e.default]).split(" ")[0].split("=")[0]
)
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)
def _resolve_const(self, v):
......@@ -113,8 +118,12 @@ class FlatBuffersWriter(IndentWriterBase):
if self._skip_current_param:
return
self._write_doc(f.name)
self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname),
self._get_fb_default(self._resolve_const(f.default)))
self._write(
"%s:%s = %s;",
f.name,
_cname_to_fbname(f.dtype.cname),
self._get_fb_default(self._resolve_const(f.default)),
)
def _on_const_field(self, f):
self._cur_const_val[str(f.name)] = str(f.default)
......@@ -129,7 +138,8 @@ class FlatBuffersWriter(IndentWriterBase):
default = s.compose_combined_enum(e.get_default())
else:
default = scramble_enum_member_name(
str(s.members[e.get_default()]).split(' ')[0].split('=')[0])
str(s.members[e.get_default()]).split(" ")[0].split("=")[0]
)
self._write("%s:%s = %s;", e.name_field, enum_name, default)
def _get_fb_default(self, cppdefault):
......@@ -137,9 +147,9 @@ class FlatBuffersWriter(IndentWriterBase):
return cppdefault
d = cppdefault
if d.endswith('f'): # 1.f
if d.endswith("f"): # 1.f
return d[:-1]
if d.endswith('ull'):
if d.endswith("ull"):
return d[:-3]
if d.startswith("DTypeEnum::"):
return d[11:]
......@@ -148,21 +158,23 @@ class FlatBuffersWriter(IndentWriterBase):
def main():
parser = argparse.ArgumentParser(
'generate FlatBuffers schema of operator param from description file')
parser.add_argument('input')
parser.add_argument('output')
"generate FlatBuffers schema of operator param from description file"
)
parser.add_argument("input")
parser.add_argument("output")
args = parser.parse_args()
with open(args.input) as fin:
inputs = fin.read()
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
input_hash = hashlib.sha256()
input_hash.update(inputs.encode(encoding='UTF-8'))
input_hash.update(inputs.encode(encoding="UTF-8"))
input_hash = input_hash.hexdigest()
writer = FlatBuffersWriter()
with open(args.output, 'w') as fout:
with open(args.output, "w") as fout:
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
if __name__ == "__main__":
main()
#! /usr/local/env python3
import pickle
import numpy as np
import os
import argparse
import re
import collections
import os
import pickle
import re
import numpy as np
def define_template(**kwargs):
template = '''
template = """
float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}};
float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}};
float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}};
......@@ -17,21 +19,23 @@ def define_template(**kwargs):
const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}};
const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}};
const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}};
'''
"""
return template.format(**kwargs)
def cudnn_slt_template(**kwargs):
template = ("#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" +
" {define_cmd}\n" +
" {select_cmd}\n" +
" return true;\n" +
"#endif\n"
)
template = (
"#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n"
+ " {define_cmd}\n"
+ " {select_cmd}\n"
+ " return true;\n"
+ "#endif\n"
)
return template.format(**kwargs)
def select_template(**kwargs):
template = \
'''if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} &&
template = """if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} &&
cuda_minor == {cuda_minor}) {{
*layer_num_p = {layer_num};
*hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units;
......@@ -42,7 +46,7 @@ def select_template(**kwargs):
*beta_p = cuda{cuda_arch}_{conv_type}_beta;
*time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred;
*mask_p = cuda{cuda_arch}_{conv_type}_mask;
}} else '''
}} else """
return template.format(**kwargs)
......@@ -58,48 +62,48 @@ def fill_src():
if len(matrix_files) == 0:
print("Warning: no param files detected.")
for fpath in matrix_files:
cudnn_version = re.findall('cudnn([\d.]+)',fpath)[0]
cudnn_version = re.findall("cudnn([\d.]+)", fpath)[0]
gen_list[cudnn_version].append(fpath)
for cudnn in gen_list:
select_cmd = ("{\n" +
" " * 8 + "return false;\n" +
" " * 4 + "}")
select_cmd = "{\n" + " " * 8 + "return false;\n" + " " * 4 + "}"
define_cmd = ""
cudnn_major, cudnn_minor = cudnn.split('.')
cudnn_major, cudnn_minor = cudnn.split(".")
for fpath in gen_list[cudnn]:
cuda_arch = fpath.split("-")[1].replace(".", "_")
print('cudnn_version: {}, cuda_arch: {}'.format(cudnn,cuda_arch))
print("cudnn_version: {}, cuda_arch: {}".format(cudnn, cuda_arch))
conv_type = fpath.split("-")[2].split(".")[0]
with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj:
params = pickle.load(pobj)
crt_define_cmd, crt_select_cmd = gen_cmds(
cuda_arch, conv_type, params)
crt_define_cmd, crt_select_cmd = gen_cmds(cuda_arch, conv_type, params)
select_cmd = crt_select_cmd + select_cmd
define_cmd = crt_define_cmd + define_cmd
cudnn_slt_cmd += cudnn_slt_template(cudnn_major=cudnn_major,
cudnn_minor=cudnn_minor,
select_cmd=select_cmd,
define_cmd=define_cmd)
cudnn_slt_cmd += cudnn_slt_template(
cudnn_major=cudnn_major,
cudnn_minor=cudnn_minor,
select_cmd=select_cmd,
define_cmd=define_cmd,
)
#select_cmd = select_cmd
# select_cmd = select_cmd
with open(os.path.join(home, "get_params.template"), "r") as srcf:
src = srcf.read()
dst = src.replace("{cudnn_select}", cudnn_slt_cmd)
MegDNN_path = os.path.join(home, "../..")
with open(os.path.join(MegDNN_path,
"src/cuda/convolution/get_params.cpp"), "w") as dstf:
with open(
os.path.join(MegDNN_path, "src/cuda/convolution/get_params.cpp"), "w"
) as dstf:
dstf.write(dst)
def gen_cmds(cuda_arch, conv_type, params):
cuda_major, cuda_minor = cuda_arch.split("_")
alphastr = format_array(params['alpha']).rstrip()[:-1]
betastr = format_array(params['beta']).rstrip()[:-1]
W_list = params['W']
b_list = params['b']
Wstr = ''
bstr = ''
alphastr = format_array(params["alpha"]).rstrip()[:-1]
betastr = format_array(params["beta"]).rstrip()[:-1]
W_list = params["W"]
b_list = params["b"]
Wstr = ""
bstr = ""
layer_num = str(len(b_list) + 1)
layers_dim = [W_list[0].shape[1]]
matrices_dim = 0
......@@ -118,16 +122,27 @@ def gen_cmds(cuda_arch, conv_type, params):
out_dim = layers_dim[-1]
layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1]
select_cmd = select_template(conv_type=conv_type.upper(), cuda_major=cuda_major,
cuda_minor=cuda_minor, layer_num=layer_num,
cuda_arch=cuda_arch)
define_cmd = define_template(cuda_arch=cuda_arch, conv_type=conv_type.upper(),
hidden_num=hidden_num,
layer_num=layer_num, out_dim=out_dim,
layers_dim=layers_dim_str,
matrices_dim=matrices_dim, matrices=Wstr,
biases_dim=biases_dim, biases=bstr,
alpha=alphastr, beta=betastr)
select_cmd = select_template(
conv_type=conv_type.upper(),
cuda_major=cuda_major,
cuda_minor=cuda_minor,
layer_num=layer_num,
cuda_arch=cuda_arch,
)
define_cmd = define_template(
cuda_arch=cuda_arch,
conv_type=conv_type.upper(),
hidden_num=hidden_num,
layer_num=layer_num,
out_dim=out_dim,
layers_dim=layers_dim_str,
matrices_dim=matrices_dim,
matrices=Wstr,
biases_dim=biases_dim,
biases=bstr,
alpha=alphastr,
beta=betastr,
)
return (define_cmd, select_cmd)
......@@ -153,8 +168,9 @@ def format_array(array):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate cuDNN heuristic code by neural network into"
" {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp,"
" using parameter value from pickle files in"
" {MEGDNN_ROOT}/scripts/gen_heuristic/params/")
" {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp,"
" using parameter value from pickle files in"
" {MEGDNN_ROOT}/scripts/gen_heuristic/params/"
)
args = parser.parse_args()
main()
此差异已折叠。
......@@ -3,19 +3,17 @@
import argparse
import collections
import textwrap
import os
import hashlib
import struct
import io
import os
import struct
import textwrap
from gen_param_defs import member_defs, ParamDef, IndentWriterBase
from gen_param_defs import IndentWriterBase, ParamDef, member_defs
# FIXME: move supportToString flag definition into the param def source file
ENUM_TO_STRING_SPECIAL_RULES = [
("Elemwise", "Mode"),
("ElemwiseMultiType", "Mode")
]
ENUM_TO_STRING_SPECIAL_RULES = [("Elemwise", "Mode"), ("ElemwiseMultiType", "Mode")]
class ConverterWriter(IndentWriterBase):
_skip_current_param = False
......@@ -33,21 +31,21 @@ class ConverterWriter(IndentWriterBase):
self._write("#endif // MGB_PARAM")
def _ctype2attr(self, ctype, value):
if ctype == 'uint32_t':
return 'MgbUI32Attr', value
if ctype == 'uint64_t':
return 'MgbUI64Attr', value
if ctype == 'int32_t':
return 'MgbI32Attr', value
if ctype == 'float':
return 'MgbF32Attr', value
if ctype == 'double':
return 'MgbF64Attr', value
if ctype == 'bool':
return 'MgbBoolAttr', value
if ctype == 'DTypeEnum':
if ctype == "uint32_t":
return "MgbUI32Attr", value
if ctype == "uint64_t":
return "MgbUI64Attr", value
if ctype == "int32_t":
return "MgbI32Attr", value
if ctype == "float":
return "MgbF32Attr", value
if ctype == "double":
return "MgbF64Attr", value
if ctype == "bool":
return "MgbBoolAttr", value
if ctype == "DTypeEnum":
self._packed = False
return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value)
return "MgbDTypeAttr", "megdnn::DType::from_enum(megdnn::{})".format(value)
raise RuntimeError("unknown ctype")
def _on_param_begin(self, p):
......@@ -61,21 +59,26 @@ class ConverterWriter(IndentWriterBase):
self._skip_current_param = False
return
if self._packed:
self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1)
self._write(
'class {0}ParamBase<string accessor> : MgbPackedParamBase<"{0}", accessor> {{'.format(
p.name
),
indent=1,
)
else:
self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1)
self._write('def {0}Param: MgbParamBase<"{0}"> {{'.format(p.name), indent=1)
self._write("let fields = (ins", indent=1)
self._write(",\n{}".format(self._cur_indent).join(self._current_tparams))
self._write(");", indent=-1)
self._write("}\n", indent=-1)
if self._packed:
self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name))
self._write('def {0}Param : {0}ParamBase<"param">;\n'.format(p.name))
self._current_tparams = None
self._packed = None
self._const = None
def _wrapped_with_default_value(self, attr, default):
return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default)
return 'MgbDefaultValuedAttr<{}, "{}">'.format(attr, default)
def _on_member_enum(self, e):
p = self._last_param
......@@ -84,10 +87,12 @@ class ConverterWriter(IndentWriterBase):
# directly used by any operator, or other enum couldn't alias to this enum
td_class = "{}{}".format(p.name, e.name)
fullname = "::megdnn::param::{}".format(p.name)
enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name)
enum_def = 'MgbEnumAttr<"{}", "{}", ['.format(fullname, e.name)
def format(v):
return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0])
enum_def += ','.join(format(i) for i in e.members)
return '"{}"'.format(str(v).split(" ")[0].split("=")[0])
enum_def += ",".join(format(i) for i in e.members)
if e.combined:
enum_def += "], 1"
......@@ -95,7 +100,7 @@ class ConverterWriter(IndentWriterBase):
enum_def += "], 0"
if ENUM_TO_STRING_SPECIAL_RULES.count((p.name, e.name)):
enum_def += ", 1" # whether generate ToStringTrait
enum_def += ", 1" # whether generate ToStringTrait
enum_def += ">"
self._write("def {} : {};".format(td_class, enum_def))
......@@ -105,10 +110,12 @@ class ConverterWriter(IndentWriterBase):
# wrapped with default value
if e.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, e.compose_combined_enum(e.default))
fullname, e.name, e.compose_combined_enum(e.default)
)
else:
default_val = "{}::{}::{}".format(
fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0])
fullname, e.name, str(e.members[e.default]).split(" ")[0].split("=")[0]
)
wrapped = self._wrapped_with_default_value(td_class, default_val)
......@@ -123,51 +130,58 @@ class ConverterWriter(IndentWriterBase):
td_class = "{}{}".format(p.name, e.name)
fullname = "::megdnn::param::{}".format(p.name)
base_td_class = "{}{}".format(e.src_class, e.src_name)
enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class)
enum_def = 'MgbEnumAliasAttr<"{}", "{}", {}>'.format(
fullname, e.name, base_td_class
)
self._write("def {} : {};".format(td_class, enum_def))
# wrapped with default value
s = e.src_enum
if s.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, s.compose_combined_enum(e.get_default()))
fullname, e.name, s.compose_combined_enum(e.get_default())
)
else:
default_val = "{}::{}::{}".format(fullname, e.name, str(
s.members[e.get_default()]).split(' ')[0].split('=')[0])
default_val = "{}::{}::{}".format(
fullname,
e.name,
str(s.members[e.get_default()]).split(" ")[0].split("=")[0],
)
wrapped = self._wrapped_with_default_value(td_class, default_val)
self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
def _on_member_field(self, f):
if self._skip_current_param:
return
attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
if str(value) in self._const:
value = '::megdnn::param::{}::{}'.format(self._last_param.name, value)
value = "::megdnn::param::{}::{}".format(self._last_param.name, value)
wrapped = self._wrapped_with_default_value(attr, value)
self._current_tparams.append("{}:${}".format(wrapped, f.name))
def _on_const_field(self, f):
self._const.add(str(f.name))
def main():
parser = argparse.ArgumentParser('generate op param tablegen file')
parser.add_argument('input')
parser.add_argument('output')
parser = argparse.ArgumentParser("generate op param tablegen file")
parser.add_argument("input")
parser.add_argument("output")
args = parser.parse_args()
with open(args.input) as fin:
inputs = fin.read()
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc})
exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
input_hash = hashlib.sha256()
input_hash.update(inputs.encode(encoding='UTF-8'))
input_hash.update(inputs.encode(encoding="UTF-8"))
input_hash = input_hash.hexdigest()
writer = ConverterWriter()
with open(args.output, 'w') as fout:
with open(args.output, "w") as fout:
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)
if __name__ == "__main__":
main()
......@@ -19,6 +19,7 @@ device = {
"thread_number": 3,
}
class SshConnector:
"""imp ssh control master connector"""
......@@ -83,17 +84,17 @@ def main():
model_file = args.model_file
# copy model file
ssh.copy([args.model_file], workspace)
m = model_file.split('\\')[-1]
m = model_file.split("\\")[-1]
# run single thread
result = []
thread_number = [1, 2, 4]
for b in thread_number :
for b in thread_number:
cmd = []
cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format(
workspace, m, b
workspace, m, b
)
cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format(
workspace, m, b
workspace, m, b
)
cmd.append(cmd1)
cmd.append(cmd2)
......@@ -103,12 +104,20 @@ def main():
logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret))
result.append(ret)
thread_2 = result[0]/result[1]
thread_4 = result[0]/result[2]
thread_2 = result[0] / result[1]
thread_4 = result[0] / result[2]
if thread_2 > 1.6 or thread_4 > 3.0:
print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))
print(
"model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(
m, thread_2, thread_4
)
)
else:
print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))
print(
"model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(
m, thread_2, thread_4
)
)
if __name__ == "__main__":
......
......@@ -20,8 +20,12 @@ failed_files = Manager().list()
def process_file(file, clang_format, write):
original_source = open(file, "r").read()
source = original_source
source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source)
source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source)
source = re.sub(
r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source
)
source, count = re.subn(
r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source
)
result = subprocess.check_output(
[
......@@ -36,7 +40,9 @@ def process_file(file, clang_format, write):
result = result.decode("utf-8")
if count:
result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result)
result = re.sub(
r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result
)
result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result)
if write and original_source != result:
......@@ -109,19 +115,17 @@ def main():
raise ValueError("Invalid path {}".format(path))
# check version, we only support 12.0.1 now
version = subprocess.check_output(
[
args.clang_format,
"--version",
],
)
version = subprocess.check_output([args.clang_format, "--version",],)
version = version.decode("utf-8")
need_version = '12.0.1'
need_version = "12.0.1"
if version.find(need_version) < 0:
print('We only support {} now, please install {} version, find version: {}'
.format(need_version, need_version, version))
raise RuntimeError('clang-format version not equal {}'.format(need_version))
print(
"We only support {} now, please install {} version, find version: {}".format(
need_version, need_version, version
)
)
raise RuntimeError("clang-format version not equal {}".format(need_version))
process_map(
partial(process_file, clang_format=args.clang_format, write=args.write,),
......
......@@ -20,6 +20,7 @@ device = {
"thread_number": 3,
}
class SshConnector:
"""imp ssh control master connector"""
......@@ -54,6 +55,7 @@ class SshConnector:
except:
raise
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model_file", help="megengine model", required=True)
......@@ -78,10 +80,10 @@ def main():
model_file = args.model_file
# copy model file
ssh.copy([model_file], workspace)
m = model_file.split('\\')[-1]
m = model_file.split("\\")[-1]
# run single thread
cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format(
workspace, m
workspace, m
)
try:
raw_log = ssh.cmd([cmd])
......@@ -91,6 +93,7 @@ def main():
print("model: {} is static model.".format(m))
if __name__ == "__main__":
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
DATE_FORMAT = "%Y/%m/%d %H:%M:%S"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册