From fb49a2834fc115f623a12156ea18791f1e781365 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 7 Sep 2021 17:11:01 +0800 Subject: [PATCH] refactor(mgb/dnn): refactor enum used in serializing GitOrigin-RevId: e57af4a59c9b4e090f3972b4d0cf01a2737f8355 --- dnn/scripts/gen_flatbuffers_schema.py | 12 +- dnn/scripts/gen_param_defs.py | 59 +- dnn/scripts/gen_tablegen.py | 8 +- dnn/scripts/opr_param_defs.py | 562 +++++++++--------- imperative/tablegen/helper.h | 19 +- imperative/tablegen/targets/cpp_class.cpp | 10 +- imperative/tablegen/targets/pybind11.cpp | 15 +- .../tablegen/targets/python_c_extension.cpp | 16 +- tools/gen_header_for_bin_reduce.py | 3 +- tools/param_defs/mgb_opr_param_defs.py | 76 +-- 10 files changed, 410 insertions(+), 370 deletions(-) diff --git a/dnn/scripts/gen_flatbuffers_schema.py b/dnn/scripts/gen_flatbuffers_schema.py index 9794c7565..d6165d0d1 100755 --- a/dnn/scripts/gen_flatbuffers_schema.py +++ b/dnn/scripts/gen_flatbuffers_schema.py @@ -23,8 +23,14 @@ def _cname_to_fbname(cname): }[cname] def scramble_enum_member_name(name): + s = name.find('<<') + if s != -1: + name = name[0:name.find('=') + 1] + ' ' + name[s+2:] if name in ("MIN", "MAX"): return name + "_" + o_name = name.split(' ')[0].split('=')[0] + if o_name in ("MIN", "MAX"): + return name.replace(o_name, o_name + "_") return name class FlatBuffersWriter(IndentWriterBase): @@ -97,7 +103,8 @@ class FlatBuffersWriter(IndentWriterBase): if e.combined: default = e.compose_combined_enum(e.default) else: - default = scramble_enum_member_name(str(e.members[e.default])) + default = scramble_enum_member_name( + str(e.members[e.default]).split(' ')[0].split('=')[0]) self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) def _resolve_const(self, v): @@ -124,7 +131,8 @@ class FlatBuffersWriter(IndentWriterBase): if s.combined: default = s.compose_combined_enum(e.get_default()) else: - default = scramble_enum_member_name(str(s.members[e.get_default()])) + default = scramble_enum_member_name( + str(s.members[e.get_default()]).split(' ')[0].split('=')[0]) self._write("%s:%s = %s;", e.name_field, enum_name, default) def _get_fb_default(self, cppdefault): diff --git a/dnn/scripts/gen_param_defs.py b/dnn/scripts/gen_param_defs.py index dff1cc85c..f3e1b5efb 100755 --- a/dnn/scripts/gen_param_defs.py +++ b/dnn/scripts/gen_param_defs.py @@ -121,10 +121,12 @@ class member_defs: def normalize_enum_value(self, value): def normalize(v): if isinstance(v, str): - if v not in self.members: - raise ValueError( - "enum member '{}' does not exist.".format(v)) - v = self.members.index(v) + for idx, m in enumerate(self.members): + m = str(m).split(' ')[0].split('=')[0] + if v == m : + return idx + raise ValueError( + "enum member '{}' does not exist.".format(v)) assert isinstance(v, int) return v if self.combined: @@ -524,21 +526,25 @@ class SerializedDType(_ParamDefBase): self._write_doc(e.name) - for idx, emem in enumerate(e.members): + for emem in e.members: if e.combined: - self._write('%s = 1 << %d', emem, idx) + self._write('%s', emem) self._write_doc(emem) else: - self._write('%s = "%s"', emem, emem) + v = str(emem).split(' ')[0].split('=')[0] + n = int(str(emem).split('=')[1]) + self._write('%s = "%s"', v, v) self._write_doc(emem) self._enum_member2num.append('id({}.{}):{}'.format( - qualname, emem, idx)) + qualname, v, n)) for emem, emem_alias in e.member_alias: + em_a = emem_alias.split(' ')[0].split('=')[0] if e.combined: - self._write('%s = %s', emem_alias, e.compose_combined_enum(emem)) + self._write('%s 
= %s', em_a, e.compose_combined_enum(emem)) else: - self._write('%s = %s', emem_alias, emem) + em = str(emem).split(' ')[0].split('=')[0] + self._write('%s = %s', em_a, em) self._unindent() self._write('') @@ -546,7 +552,7 @@ class SerializedDType(_ParamDefBase): if e.combined: default = e.compose_combined_enum(e.default) else: - default = "'{}'".format(e.members[e.default]) + default = "'{}'".format(str(e.members[e.default]).split(' ')[0].split('=')[0]) self._cur_fields.append(self.FieldDef( name=e.name_field, @@ -564,7 +570,7 @@ class SerializedDType(_ParamDefBase): if s.combined: default = s.compose_combined_enum(e.get_default()) else: - default = "'{}'".format(s.members[e.get_default()]) + default = "'{}'".format(str(s.members[e.get_default()]).split(' ')[0].split('=')[0]) self._cur_fields.append(self.FieldDef( name=e.name_field, cvt='{}.convert({})'.format(qualname, e.name_field), @@ -700,11 +706,9 @@ class CPPWriter(IndentWriterBase): def _on_member_enum(self, e): self._write_doc(e.name) self._write('enum class %s: uint32_t {', e.name, indent=1) - for idx, i in enumerate(e.members): + for i in e.members: self._write_doc(i) - v = '{} = {}'.format(i, idx) - if e.combined: - v = '{} = 1 << {}'.format(i, idx) + v = str(i) if i is not e.members[-1] or e.member_alias: v += ',' self._write(v) @@ -712,7 +716,7 @@ class CPPWriter(IndentWriterBase): if e.combined: self._write('%s = %s,', alias, e.compose_combined_enum(mem)) else: - self._write('%s = %s,', alias, mem) + self._write('%s = %s,', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0]) self._write('};', indent=-1) self._non_static_members.append(e) self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', @@ -720,7 +724,9 @@ class CPPWriter(IndentWriterBase): if e.combined: default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default)) else: - default = '{}::{}'.format(e.name, e.members[e.default]) + value = str(e.members[e.default]) + value = value.split(' ')[0].split('=')[0] + default = '{}::{}'.format(e.name, value) self._add_ctor_args(e.name, default, e.name_field) def _on_member_enum_alias(self, e): @@ -732,7 +738,9 @@ class CPPWriter(IndentWriterBase): if s.combined: default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default)) else: - default = '{}::{}'.format(e.name, s.members[e.get_default()]) + value = str(s.members[e.get_default()]) + value = value.split(' ')[0].split('=')[0] + default = '{}::{}'.format(e.name, value) self._add_ctor_args(e.name, default, e.name_field) def _on_member_field(self, f): @@ -754,11 +762,12 @@ class CPPEnumValueWriter(CPPWriter): def _on_member_enum(self, e): self._write_doc(e.name) self._write('struct %s {', e.name, indent=1) - for idx, val in enumerate(e.members): + for val in e.members: self._write_doc(val) - self._write('static const uint32_t %s = %d;', val, idx) + v = str(val) + self._write('static const uint32_t %s;', v) for mem, alias in e.member_alias: - self._write('static const uint32_t %s = %s;', alias, mem) + self._write('static const uint32_t %s = %s;', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0]) self._write('};', indent=-1) def _on_member_enum_alias(self, e): @@ -848,9 +857,11 @@ class CPPParamJsonFuncWriter(IndentWriterBase): members = e.src_enum.members else: members = e.members - for idx, i in enumerate(members): + for i in members: + v = str(i) + v = v.split(' ')[0].split('=')[0] self._write('case %s::%s::%s: return "%s";', - self._param_name, e.name, i, i, indent=0) + 
self._param_name, e.name, v, v, indent=0) self._write('default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast(arg));', self._param_name, e.name, indent=0) self._write('}', indent=-1) diff --git a/dnn/scripts/gen_tablegen.py b/dnn/scripts/gen_tablegen.py index fa032c20e..911cf749c 100755 --- a/dnn/scripts/gen_tablegen.py +++ b/dnn/scripts/gen_tablegen.py @@ -89,7 +89,7 @@ class ConverterWriter(IndentWriterBase): fullname = "::megdnn::param::{}".format(p.name) enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) def format(v): - return '\"{}\"'.format(str(v)) + return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0]) enum_def += ','.join(format(i) for i in e.members) if e.combined: @@ -110,7 +110,8 @@ class ConverterWriter(IndentWriterBase): default_val = "static_cast<{}::{}>({})".format( fullname, e.name, e.compose_combined_enum(e.default)) else: - default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default]) + default_val = "{}::{}::{}".format( + fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0]) wrapped = self._wrapped_with_default_value(td_class, default_val) @@ -134,7 +135,8 @@ class ConverterWriter(IndentWriterBase): default_val = "static_cast<{}::{}>({})".format( fullname, e.name, s.compose_combined_enum(e.get_default())) else: - default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()]) + default_val = "{}::{}::{}".format(fullname, e.name, str( + s.members[e.get_default()]).split(' ')[0].split('=')[0]) wrapped = self._wrapped_with_default_value(td_class, default_val) diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 83fafb665..32c42750a 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -3,7 +3,7 @@ pdef('Empty') pdef('Axis').add_fields('int32', 'axis', 0) (pdef('Convolution', version=0, is_legacy=True). - add_enum('Mode', 'CROSS_CORRELATION', 'CONVOLUTION'). + add_enum('Mode', 'CROSS_CORRELATION = 0', 'CONVOLUTION = 1'). add_fields( 'uint32', Doc('pad_h', 'padding on one side on the first dimension'), 0, @@ -16,41 +16,41 @@ pdef('Axis').add_fields('int32', 'axis', 0) 'on the second dimension'), 1 ). add_enum('DataType', - Doc('FLOAT', 'input/output both float32/float16'), - 'INT8x8x16', - 'INT8x8x32', - Doc('FLOAT_IO16xC32', 'input/output both float16, the internal ' + Doc('FLOAT = 0', 'input/output both float32/float16'), + 'INT8x8x16 = 1', + 'INT8x8x32 = 2', + Doc('FLOAT_IO16xC32 = 3', 'input/output both float16, the internal ' 'compute is float32'), - Doc('QUINT8x8x32', 'input QuantizedAsymm8, output QuantizedS32'), - Doc('INT8x8xX', 'input int8, output specified by tensor DType'), - Doc('QUINT4x4x32', 'input QuantizedAsymm4, output QuantizedS32'), + Doc('QUINT8x8x32 = 4', 'input QuantizedAsymm8, output QuantizedS32'), + Doc('INT8x8xX = 5', 'input int8, output specified by tensor DType'), + Doc('QUINT4x4x32 = 6', 'input QuantizedAsymm4, output QuantizedS32'), name_field='data_type'). add_enum('Sparse', - Doc('DENSE', 'dense convolution: filter shape should be ' + Doc('DENSE = 0', 'dense convolution: filter shape should be ' '[oc, ic, spatial...] if format is NCHW, ' '[oc, spatial..., ic] if format is NHWC'), - Doc('GROUP', 'group convolution: filter shape should be ' + Doc('GROUP = 1', 'group convolution: filter shape should be ' '[group, oc_per_group, ic_per_group, spatial...] if format is NCHW, ' '[group, oc_per_group, spatial..., ic_per_group] if format is NHWC') ). 
add_enum(Doc('Format', 'convolution data/filter/output format; see ' ':class:`RelayoutFormat` for more details'), - 'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', - 'NCHW44','NCHW44_DOT', - Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), - Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), - Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), - Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), - Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), - Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), - Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'), - Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' + 'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', + 'NCHW44 = 7','NCHW44_DOT = 8', + Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights tranformed by winograd'), + Doc('NCHW88_WINOGRAD = 10', 'NCHW88 layout with weights tranformed by winograd'), + Doc('NCHW44_WINOGRAD = 11', 'NCHW44 layout with weights tranformed by winograd'), + Doc('NCHW4_NCHW32 = 12', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), + Doc('NCHW32_NCHW4 = 13', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), + Doc('NCHW4_NCHW = 14', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), + Doc('NCHW4_NHWC = 15', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'), + Doc('NHWC_NCHW = 16', 'NHWC_NCHW means input tensors are nhwc layout, ' 'output tensor is nchw layout'), - Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' + Doc('NHWC_NCHW4_IC_SMALL = 17', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' 'output tensor is nchw4 layout, padding c=4'), - Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, ' + Doc('NCHW_NCHW4_IC_SMALL = 18', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, ' 'output tensor is nchw4 layout, padding c=4'), - Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' + Doc('CHWN4 = 19', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) ) @@ -72,9 +72,9 @@ pdef('Axis').add_fields('int32', 'axis', 0) add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. ' 'different combinations of intermediate result ' 'data types.'), - Doc('DEFAULT', 'No special requirements on the precision of ' + Doc('DEFAULT = 0', 'No special requirements on the precision of ' 'intermediate results.'), - Doc('FLOAT32', 'Use Float32 accumulator and intermediate result. ' + Doc('FLOAT32 = 1', 'Use Float32 accumulator and intermediate result. ' 'Only supported when input and output is Float16.'), name_field='compute_mode') ) @@ -95,21 +95,21 @@ pdef('Axis').add_fields('int32', 'axis', 0) add_enum_alias('Sparse', 'ConvolutionV0'). 
add_enum(Doc('Format', 'convolution data/filter/output format; see ' ':class:`RelayoutFormat` for more details'), - 'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', - 'NCHW44','NCHW44_DOT', - Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), - Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), - Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), - Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'), - Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' + 'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6', + 'NCHW44 = 7','NCHW44_DOT = 8', + Doc('NCHW4_NCHW32 = 9', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), + Doc('NCHW32_NCHW4 = 10', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), + Doc('NCHW4_NCHW = 11', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), + Doc('NCHW4_NHWC = 12', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'), + Doc('NHWC_NCHW = 13', 'NHWC_NCHW means input tensors are nhwc layout, ' 'output tensor is nchw layout'), - Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' + Doc('NHWC_NCHW4_IC_SMALL = 14', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' 'output tensor is nchw4 layout, padding c=4'), - Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, ' + Doc('NCHW_NCHW4_IC_SMALL = 15', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, ' 'output tensor is nchw4 layout, padding c=4'), - Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' + Doc('CHWN4 = 16', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'), - Doc('NCHW64', 'NCHW64 is designed for convolution implementation to utilizing TensorCore ' + Doc('NCHW64 = 17', 'NCHW64 is designed for convolution implementation to utilizing TensorCore ' 'instructions for 4-bit integers on Nvidia platforms')). add_enum_alias('ComputeMode', 'ConvolutionV1',name_field='compute_mode') ) @@ -129,15 +129,15 @@ pdef('Axis').add_fields('int32', 'axis', 0) ) (pdef('ConvPooling'). - add_enum('Method', 'WITH_TEXTURE_OBJ', 'WITH_SHARED_MEM'). + add_enum('Method', 'WITH_TEXTURE_OBJ = 0', 'WITH_SHARED_MEM = 1'). add_enum_alias('ConvMode', 'ConvolutionV0', 'Mode'). - add_enum('PoolMode', 'AVERAGE', 'MAX'). - add_enum('NonlineMode', 'IDENTITY', 'RELU', 'SIGMOID'). + add_enum('PoolMode', 'AVERAGE = 0', 'MAX = 1'). + add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2'). add_fields('uint32', 'pool_shape_h', 1, 'pool_shape_w', 1, 'pool_stride_h', 1, 'pool_stride_w', 1, \ 'pool_pad_h', 0, 'pool_pad_w', 0, 'conv_stride_h', 1, 'conv_stride_w', 1, 'conv_pad_h', 0, 'conv_pad_w', 0)) (pdef('ConvBias', 'legacy conv_bias', version=0, is_legacy=True). - add_enum('NonlineMode', 'IDENTITY', 'RELU', 'SIGMOID', 'H_SWISH'). + add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2', 'H_SWISH = 3'). add_enum_alias('Mode', 'ConvolutionV0'). 
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1)) @@ -215,9 +215,9 @@ pdef('Axis').add_fields('int32', 'axis', 0) ) (pdef('SeparableConv'). add_enum_alias('Mode', 'ConvolutionV0'). - add_enum('BorderMode', 'BORDER_REPLICATE', 'BORDER_REFLECT', - 'BORDER_REFLECT_101','BORDER_WRAP', - 'BORDER_CONSTANT', 'BORDER_TRANSPARENT','BORDER_ISOLATED'). + add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1', + 'BORDER_REFLECT_101 = 2','BORDER_WRAP = 3', + 'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5','BORDER_ISOLATED = 6'). add_fields('bool', 'is_symm_kernel', 'true'). add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1, 'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1)) @@ -233,11 +233,11 @@ pdef('Axis').add_fields('int32', 'axis', 0) (pdef('Pooling', version=0, is_legacy=True). add_enum( 'Mode', - Doc('MAX', 'maximum value inside pooling window'), - Doc('AVERAGE', + Doc('MAX = 0', 'maximum value inside pooling window'), + Doc('AVERAGE = 1', 'arithmetic mean of all values inside pooling window. Padding values ' 'are taken into account and are viewed as zero'), - Doc('AVERAGE_COUNT_EXCLUDE_PADDING', + Doc('AVERAGE_COUNT_EXCLUDE_PADDING = 2', 'arithmetic mean of all values inside pooling window. No padding is' 'used.') ). @@ -273,15 +273,15 @@ pdef('Axis').add_fields('int32', 'axis', 0) (pdef('BN'). add_enum( 'ParamDim', - Doc('DIM_11HW', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'), - Doc('DIM_1CHW', 'Dim of params (Sigma, Mu) is 1 x C x H x W'), - Doc('DIM_1C11', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'), + Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'), + Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'), + Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'), name_field='param_dim' ). add_enum( 'FwdMode', - Doc('TRAINING', 'Training phase.'), - Doc('INFERENCE', 'Inference phase.'), + Doc('TRAINING = 0', 'Training phase.'), + Doc('INFERENCE = 1', 'Inference phase.'), name_field='fwd_mode' ). add_fields('float64', 'epsilon', '1e-4f'). @@ -293,22 +293,22 @@ pdef('Axis').add_fields('int32', 'axis', 0) (pdef('ROIPooling'). add_enum( 'Mode', - Doc('MAX', 'maximum value inside pooling window; pooling result would ' + Doc('MAX = 0', 'maximum value inside pooling window; pooling result would ' 'be 0 if pooling window is empty'), - Doc('AVERAGE', + Doc('AVERAGE = 1', 'arithmetic mean of all values inside pooling window; pooling result ' 'would be 0 if pooling window is empty') ). add_fields('float32', 'scale', '1.f')) -INTERP_MODES = ['NEAREST', 'LINEAR', 'AREA', 'CUBIC', 'LANCZOS4'] -BORDER_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'), - Doc('REFLECT', 'fedcba|abcdefgh|hgfedcb'), - Doc('REFLECT_101', 'gfedcb|abcdefgh|gfedcba'), - Doc('WRAP', 'cdefgh|abcdefgh|abcdefg'), - Doc('CONSTANT', 'iiiiii|abcdefgh|iiiiiii'), - Doc('TRANSPARENT', ''), - Doc('ISOLATED', '')] +INTERP_MODES = ['NEAREST = 0', 'LINEAR = 1', 'AREA = 2', 'CUBIC = 3', 'LANCZOS4 = 4'] +BORDER_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), + Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'), + Doc('REFLECT_101 = 2', 'gfedcb|abcdefgh|gfedcba'), + Doc('WRAP = 3', 'cdefgh|abcdefgh|abcdefg'), + Doc('CONSTANT = 4', 'iiiiii|abcdefgh|iiiiiii'), + Doc('TRANSPARENT = 5', ''), + Doc('ISOLATED = 6', '')] (pdef('WarpPerspective', version=1, is_legacy=True). 
add_enum('InterpolationMode', *INTERP_MODES, name_field='imode', default=1, @@ -328,181 +328,181 @@ BORDER_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'), add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f')) -pdef('SpatialTfGridGenerator').add_enum('Mode', 'AFFINE') -pdef('SpatialTfSampler').add_enum('Mode', 'BILINEAR') +pdef('SpatialTfGridGenerator').add_enum('Mode', 'AFFINE = 0') +pdef('SpatialTfSampler').add_enum('Mode', 'BILINEAR = 0') pdef('AddUpdate').add_fields( 'float32', 'alpha', '1.f', 'beta', '1.f', 'bias', '0.f') pdef('Elemwise').add_enum( 'Mode', - Doc('RELU', 'unary: max(x, 0)'), - Doc('ABS', 'unary: abs(x)'), - Doc('ACOS', 'unary: acos(x)'), - Doc('ASIN', 'unary: asin(x)'), - Doc('CEIL', 'unary: ceil(x)'), - Doc('COS', 'unary: cos(x)'), - Doc('EXP', 'unary: exp(x)'), - Doc('EXPM1', 'unary: numerically stable exp(x)-1'), - Doc('FLOOR', 'unary: floor(x)'), - Doc('LOG', 'unary: natural logarithm, log(x)'), - Doc('LOG1P', 'unary: numerically stable log(x+1)'), - Doc('NEGATE', 'unary: -x'), - Doc('SIGMOID', 'unary: 1/(1+exp(-x))'), - Doc('SIN', 'unary: sin(x)'), - Doc('TANH', 'unary: tanh(x)'), - - Doc('ABS_GRAD', 'binary: x > 0 ? y : -y'), - Doc('ADD', 'binary: x + y'), - Doc('FLOOR_DIV', 'binary: floor(x / y)'), - Doc('MAX', 'binary: max(x, y)'), - Doc('MIN', 'binary: min(x, y)'), - Doc('MOD', 'binary: x % y or fmodf(x, y)'), - Doc('MUL', 'binary: x * y'), - Doc('POW', 'binary: pow(x, y)'), - Doc('SIGMOID_GRAD', 'binary: x * (1 - x) * y'), - Doc('SUB', 'binary: x - y'), - Doc('SWITCH_GT0', 'binary: (x > 0) * y'), - Doc('TANH_GRAD', 'binary: (1 - x * x) * y'), - Doc('TRUE_DIV', 'binary: x / y'), - Doc('LOG_SUM_EXP', 'binary: numerically stable log(exp(x) + exp(y))'), - - Doc('LT', 'binary: x < y'), - Doc('LEQ', 'binary: x <= y'), - Doc('EQ', 'binary: x == y'), - - Doc('SHL', 'bitwise binary: x << y. ' + Doc('RELU = 0', 'unary: max(x, 0)'), + Doc('ABS = 1', 'unary: abs(x)'), + Doc('ACOS = 2', 'unary: acos(x)'), + Doc('ASIN = 3', 'unary: asin(x)'), + Doc('CEIL = 4', 'unary: ceil(x)'), + Doc('COS = 5', 'unary: cos(x)'), + Doc('EXP = 6', 'unary: exp(x)'), + Doc('EXPM1 = 7', 'unary: numerically stable exp(x)-1'), + Doc('FLOOR = 8', 'unary: floor(x)'), + Doc('LOG = 9', 'unary: natural logarithm, log(x)'), + Doc('LOG1P = 10', 'unary: numerically stable log(x+1)'), + Doc('NEGATE = 11', 'unary: -x'), + Doc('SIGMOID = 12', 'unary: 1/(1+exp(-x))'), + Doc('SIN = 13', 'unary: sin(x)'), + Doc('TANH = 14', 'unary: tanh(x)'), + + Doc('ABS_GRAD = 15', 'binary: x > 0 ? y : -y'), + Doc('ADD = 16', 'binary: x + y'), + Doc('FLOOR_DIV = 17', 'binary: floor(x / y)'), + Doc('MAX = 18', 'binary: max(x, y)'), + Doc('MIN = 19', 'binary: min(x, y)'), + Doc('MOD = 20', 'binary: x % y or fmodf(x, y)'), + Doc('MUL = 21', 'binary: x * y'), + Doc('POW = 22', 'binary: pow(x, y)'), + Doc('SIGMOID_GRAD = 23', 'binary: x * (1 - x) * y'), + Doc('SUB = 24', 'binary: x - y'), + Doc('SWITCH_GT0 = 25', 'binary: (x > 0) * y'), + Doc('TANH_GRAD = 26', 'binary: (1 - x * x) * y'), + Doc('TRUE_DIV = 27', 'binary: x / y'), + Doc('LOG_SUM_EXP = 28', 'binary: numerically stable log(exp(x) + exp(y))'), + + Doc('LT = 29', 'binary: x < y'), + Doc('LEQ = 30', 'binary: x <= y'), + Doc('EQ = 31', 'binary: x == y'), + + Doc('SHL = 32', 'bitwise binary: x << y. ' 'Note that result is undefined if y < 0 or y >= bitwidth. 
Logical ' 'shift is performed for unsigned intergers, and arithmetic shift for ' 'signed ones.'), - Doc('SHR', 'bitwise binary: x >> y; see SHL mode for more details'), + Doc('SHR = 33', 'bitwise binary: x >> y; see SHL mode for more details'), - Doc('COND_LEQ_MOV', 'ternary: x <= y ? z : 0'), - Doc('FUSE_MUL_ADD3', + Doc('COND_LEQ_MOV = 34', 'ternary: x <= y ? z : 0'), + Doc('FUSE_MUL_ADD3 = 35', 'compute ``a * b + c`` where c must either have same layout as ' 'a or b, or be a scalar'), - Doc('FUSE_MUL_ADD4', + Doc('FUSE_MUL_ADD4 = 36', 'compute ``a * A + b * B`` where a and b must have equal layout, ' 'and A and B must have equal layout. In the inputs ``b`` and ``B`` ' 'can be swapped'), - Doc('FUSE_ADD_RELU', 'binary: max(x+y, 0)'), - Doc('FUSE_ADD_SIGMOID', 'binary: 1/(1+exp(-(x+y)))'), - Doc('FUSE_ADD_TANH', 'binary: tanh(x+y)'), - Doc('FAST_TANH', 'unary: rational approximation of tanh(x)'), - Doc('FAST_TANH_GRAD', 'binary: grad of the rational approximation of tanh(x)'), + Doc('FUSE_ADD_RELU = 37', 'binary: max(x+y, 0)'), + Doc('FUSE_ADD_SIGMOID = 38', 'binary: 1/(1+exp(-(x+y)))'), + Doc('FUSE_ADD_TANH = 39', 'binary: tanh(x+y)'), + Doc('FAST_TANH = 40', 'unary: rational approximation of tanh(x)'), + Doc('FAST_TANH_GRAD = 41', 'binary: grad of the rational approximation of tanh(x)'), - Doc('ROUND', 'unary: round(x), the nearest integer value to x, rounding ' + Doc('ROUND = 42', 'unary: round(x), the nearest integer value to x, rounding ' 'halfway cases away from zero. Float only.'), - Doc('RMULH', 'binary: rounded higher l bits of x * y, where l is the bit ' + Doc('RMULH = 43', 'binary: rounded higher l bits of x * y, where l is the bit ' 'length of x.'), - Doc('ATAN2','binary: atan2(y,x)'), - Doc('ERF', 'unary: erf(x)'), - Doc('ERFINV', 'unary: inverse function of erf(x)'), - Doc('ERFC', 'unary: erfc(x)'), - Doc('ERFCINV', 'unary: inverse function of erfc(x)'), - Doc('H_SWISH', 'unary: x * clip(x + 3, 0, 6) / 6'), - Doc('H_SWISH_GRAD', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'), - Doc('FUSE_ADD_H_SWISH', 'binary: hswish(x+y)'), - - Doc('NOT', 'unary: !x'), - Doc('AND', 'binary: x && y'), - Doc('OR', 'binary: x || y'), - Doc('XOR', 'binary: x ^ y'), - Doc('SILU', 'unary: x / (1 + exp(-x))'), - Doc('SILU_GRAD', 'binary: grad(x / (1 + exp(-x))'), - Doc('GELU', 'unary: x Phi(x)'), - Doc('GELU_GRAD', 'binary: grad(x Phi(x))'), + Doc('ATAN2 = 44','binary: atan2(y,x)'), + Doc('ERF = 45', 'unary: erf(x)'), + Doc('ERFINV = 46', 'unary: inverse function of erf(x)'), + Doc('ERFC = 47', 'unary: erfc(x)'), + Doc('ERFCINV = 48', 'unary: inverse function of erfc(x)'), + Doc('H_SWISH = 49', 'unary: x * clip(x + 3, 0, 6) / 6'), + Doc('H_SWISH_GRAD = 50', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'), + Doc('FUSE_ADD_H_SWISH = 51', 'binary: hswish(x+y)'), + + Doc('NOT = 52', 'unary: !x'), + Doc('AND = 53', 'binary: x && y'), + Doc('OR = 54', 'binary: x || y'), + Doc('XOR = 55', 'binary: x ^ y'), + Doc('SILU = 56', 'unary: x / (1 + exp(-x))'), + Doc('SILU_GRAD = 57', 'binary: grad(x / (1 + exp(-x))'), + Doc('GELU = 58', 'unary: x Phi(x)'), + Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'), ) pdef('ElemwiseMultiType').add_enum( 'Mode', - Doc('FUSE_MUL_ADD3_INT16x32x32x32', + Doc('FUSE_MUL_ADD3_INT16x32x32x32 = 0', 'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and ' '``c`` int32, and the result is int32. This mode is optimized for ' 'the channel-broadacsted case, i.e. 
``a`` has shape (A, B, C) and ' '``b`` and ``c`` have shape (1, C, 1)'), - Doc('FUSE_MUL_ADD3_IXxF32xF32xI8', + Doc('FUSE_MUL_ADD3_IXxF32xF32xI8 = 1', 'compuate ``a * b + c`` where the inputs ``a`` is an integer type ' '``b`` and ``c`` are both ``float32``, the result is ' '``int8``. This is currently only optimized for ``(1, x)`` ' 'broadcast for ``b`` and ``c``. Computation is carried in floating ' 'points and results are rounded towards zero with saturated cast to ' 'int.'), - Doc('ROUND_SHR_SATURATE_IXxI8xI8', + Doc('ROUND_SHR_SATURATE_IXxI8xI8 = 2', 'Compute ``a >> b``, round the result according to lower ``b`` bits ' 'of ``a``` and make a saturating conversion to int8. Where ``a`` should' ' be an integer tensor and ``b`` should be an int8 scalar.'), - Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8', + Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8 = 3', 'Fused operation of an int16 elemwise add, an int16 rounding multiply ' 'high and an int16 to int8 rounding right shift with saturation.'), - Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8', + Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8 = 4', 'Fused operation of an int32 elemwise add, an int32 rounding multiply ' 'high and an int32 to int8 rounding right shift with saturation.'), - Doc('ROUND_SHR_SATURATE_IXxI8xI16', + Doc('ROUND_SHR_SATURATE_IXxI8xI16 = 5', 'Compute ``a >> b``, round the result according to lower ``b`` bits of ' '``a``` and make a saturating conversion to int16. Where ``a`` should' ' be an integer tensor and ``b`` should be an int8 scalar.'), - Doc('QADD', 'Fused elemwise add two quantized int8 with specified' + Doc('QADD = 6', 'Fused elemwise add two quantized int8 with specified' 'output quantized dtype'), - Doc('QFUSE_ADD_RELU', 'Fused elemwise add two quantized int8 followed' + Doc('QFUSE_ADD_RELU = 7', 'Fused elemwise add two quantized int8 followed' ' by ReLU and typecvt to specified dtype'), - Doc('QMUL', 'Fused elemwise multiply two quantized int8 with specified' + Doc('QMUL = 8', 'Fused elemwise multiply two quantized int8 with specified' 'output quantized dtype'), - Doc('QMIN', 'Fused elemwise min two quantized int8 with specified' + Doc('QMIN = 9', 'Fused elemwise min two quantized int8 with specified' 'output quantized dtype'), - Doc('QMAX', 'quantized: max(x, y), with specified output quantized dtype'), - Doc('QSUB', 'quantized: x - y'), - Doc('QTRUE_DIV', 'quantized: x / y'), - Doc('QFUSE_ADD_SIGMOID', 'quantized: sigmoid(x + y)'), - Doc('QFUSE_ADD_TANH', 'quantized: tanh(x + y)'), - Doc('QRELU', 'quantized: x > 0 ? x : 0'), - Doc('QABS', 'quantized: x > 0 ? 
x : -x'), - Doc('QSIGMOID', 'quantized: sigmoid(x)'), - Doc('QEXP', 'quantized: exp(x)'), - Doc('QTANH', 'quantized: tanh(x)'), - Doc('QFUSE_MUL_ADD3', 'quantized: x * y + z'), - Doc('QFAST_TANH', 'quantized: fast_tanh(x)'), - Doc('QNEGATE', 'quantized: -x'), - Doc('QACOS', 'quantized: acos(x)'), - Doc('QASIN', 'quantized: asin(x)'), - Doc('QCEIL', 'quantized: ceil(x)'), - Doc('QCOS', 'quantized: cos(x)'), - Doc('QEXPM1', 'quantized: expm1(x)'), - Doc('QFLOOR', 'quantized: floor(x)'), - Doc('QLOG', 'quantized: log(x)'), - Doc('QLOG1P', 'quantized: log1p(x)'), - Doc('QSIN', 'quantized: sin(x)'), - Doc('QROUND', 'quantized: round(x)'), - Doc('QERF', 'quantized: erf(x)'), - Doc('QERFINV', 'quantized: erfinv(x)'), - Doc('QERFC', 'quantized: erfc(x)'), - Doc('QERFCINV', 'quantized: erfcinv(x)'), - Doc('QABS_GRAD', 'quantized: abs_grad'), - Doc('QFLOOR_DIV', 'quantized floor_div'), - Doc('QMOD', 'quantized mod'), - Doc('QSIGMOID_GRAD', 'quantized sigmoid_grad'), - Doc('QSWITCH_GT0', 'quantized switch_gt0'), - Doc('QTANH_GRAD', 'quantized tanh_grad'), - Doc('QLT', 'quantized lt'), - Doc('QLEQ', 'quantized leq'), - Doc('QEQ', 'quantized eq'), - Doc('QPOW', 'quantized pow'), - Doc('QLOG_SUM_EXP', 'quantized log_sum_exp'), - Doc('QFAST_TANH_GRAD', 'quantized fast_tanh_grad'), - Doc('QATAN2', 'quantized atan2'), - Doc('QCOND_LEQ_MOV', 'quantized cond_leq_mov'), - Doc('QH_SWISH', 'quantized h_swish'), - Doc('QFUSE_ADD_H_SWISH', 'quantized h_swish(x+y)'), - Doc('QH_SWISH_GRAD', 'quantized h_swish_grad') + Doc('QMAX = 10', 'quantized: max(x, y), with specified output quantized dtype'), + Doc('QSUB = 11', 'quantized: x - y'), + Doc('QTRUE_DIV = 12', 'quantized: x / y'), + Doc('QFUSE_ADD_SIGMOID = 13', 'quantized: sigmoid(x + y)'), + Doc('QFUSE_ADD_TANH = 14', 'quantized: tanh(x + y)'), + Doc('QRELU = 15', 'quantized: x > 0 ? x : 0'), + Doc('QABS = 16', 'quantized: x > 0 ? 
x : -x'), + Doc('QSIGMOID = 17', 'quantized: sigmoid(x)'), + Doc('QEXP = 18', 'quantized: exp(x)'), + Doc('QTANH = 19', 'quantized: tanh(x)'), + Doc('QFUSE_MUL_ADD3 = 20', 'quantized: x * y + z'), + Doc('QFAST_TANH = 21', 'quantized: fast_tanh(x)'), + Doc('QNEGATE = 22', 'quantized: -x'), + Doc('QACOS = 23', 'quantized: acos(x)'), + Doc('QASIN = 24', 'quantized: asin(x)'), + Doc('QCEIL = 25', 'quantized: ceil(x)'), + Doc('QCOS = 26', 'quantized: cos(x)'), + Doc('QEXPM1 = 27', 'quantized: expm1(x)'), + Doc('QFLOOR = 28', 'quantized: floor(x)'), + Doc('QLOG = 29', 'quantized: log(x)'), + Doc('QLOG1P = 30', 'quantized: log1p(x)'), + Doc('QSIN = 31', 'quantized: sin(x)'), + Doc('QROUND = 32', 'quantized: round(x)'), + Doc('QERF = 33', 'quantized: erf(x)'), + Doc('QERFINV = 34', 'quantized: erfinv(x)'), + Doc('QERFC = 35', 'quantized: erfc(x)'), + Doc('QERFCINV = 36', 'quantized: erfcinv(x)'), + Doc('QABS_GRAD = 37', 'quantized: abs_grad'), + Doc('QFLOOR_DIV = 38', 'quantized floor_div'), + Doc('QMOD = 39', 'quantized mod'), + Doc('QSIGMOID_GRAD = 40', 'quantized sigmoid_grad'), + Doc('QSWITCH_GT0 = 41', 'quantized switch_gt0'), + Doc('QTANH_GRAD = 42', 'quantized tanh_grad'), + Doc('QLT = 43', 'quantized lt'), + Doc('QLEQ = 44', 'quantized leq'), + Doc('QEQ = 45', 'quantized eq'), + Doc('QPOW = 46', 'quantized pow'), + Doc('QLOG_SUM_EXP = 47', 'quantized log_sum_exp'), + Doc('QFAST_TANH_GRAD = 48', 'quantized fast_tanh_grad'), + Doc('QATAN2 = 49', 'quantized atan2'), + Doc('QCOND_LEQ_MOV = 50', 'quantized cond_leq_mov'), + Doc('QH_SWISH = 51', 'quantized h_swish'), + Doc('QFUSE_ADD_H_SWISH = 52', 'quantized h_swish(x+y)'), + Doc('QH_SWISH_GRAD = 53', 'quantized h_swish_grad') ) pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) (pdef('DctChannelSelect', '2d discrete cosine transform', version=0, is_legacy=True).add_enum_alias('Format', 'ConvolutionV0'). - add_enum('FastImpl', 'NONE', 'FIX_32_MASK').add_fields('int32', 'dct_block_size', 8)) + add_enum('FastImpl', 'NONE = 0', 'FIX_32_MASK = 1').add_fields('int32', 'dct_block_size', 8)) (pdef('DctChannelSelect', '2d discrete cosine transform', version=1).add_enum_alias('Format', 'Convolution'). add_enum_alias('FastImpl', 'DctChannelSelectV0').add_fields('int32', 'dct_block_size', 8)) @@ -510,13 +510,13 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) (pdef('MatrixMul', version=0, is_legacy=True). add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). add_enum('DataType', - Doc('FLOAT', 'input/output both float32/float16'), - 'INT8x8x16', - 'INT8x8x32', - Doc('FLOAT_IO16xC32', 'input/output both float16, the internal compute is ' + Doc('FLOAT = 0', 'input/output both float32/float16'), + 'INT8x8x16 = 1', + 'INT8x8x32 = 2', + Doc('FLOAT_IO16xC32 = 3', 'input/output both float16, the internal compute is ' 'float32'), - Doc('QUINT8x8x32', 'input QuantizedAsymm8, output QuantizedS32'), - Doc('QUINT4x4x32', 'input QuantizedAsymm4, output QuantizedS32'), + Doc('QUINT8x8x32 = 4', 'input QuantizedAsymm8, output QuantizedS32'), + Doc('QUINT4x4x32 = 5', 'input QuantizedAsymm4, output QuantizedS32'), name_field='data_type')) (pdef('MatrixMul', version=1, is_legacy=True). @@ -524,9 +524,9 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. 
' 'different combinations of intermediate result ' 'data types.'), - Doc('DEFAULT', 'No special requirements on the precision of ' + Doc('DEFAULT = 0', 'No special requirements on the precision of ' 'intermediate results.'), - Doc('FLOAT32', 'Use Float32 accumulator and intermediate result. ' + Doc('FLOAT32 = 1', 'Use Float32 accumulator and intermediate result. ' 'Only supported when input and output is Float16.'), name_field='compute_mode')) @@ -534,14 +534,14 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). add_enum_alias('ComputeMode', 'MatrixMulV1', name_field='compute_mode'). add_enum('Format', - Doc('DEFAULT', 'Normal matrix mul: (M, K) x (K, N) = (M, N)'), - Doc('MK4', 'Split 4 from M and K, better for neon compute:' + Doc('DEFAULT = 0', 'Normal matrix mul: (M, K) x (K, N) = (M, N)'), + Doc('MK4 = 1', 'Split 4 from M and K, better for neon compute:' '(M/4, K/4, 4(k), 4(m)) x (K/4, N, 4(k)). if transposeA the ' 'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), - Doc('MK8', 'Split 8 from M and K, better for neon compute:' + Doc('MK8 = 2', 'Split 8 from M and K, better for neon compute:' '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' 'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), - Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:' + Doc('MK4_DOT = 3', 'Split 4 from M and K, better for neon dotprod:' 'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the ' 'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) ) @@ -560,9 +560,9 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) (pdef('Reduce', 'legacy reduce', version=0, is_legacy=True). add_enum('Mode', - 'SUM', - Doc('SUM_SQR', 'sum of x * x for each element x'), - 'PRODUCT', 'MIN', 'MAX'). + 'SUM = 0', + Doc('SUM_SQR = 1', 'sum of x * x for each element x'), + 'PRODUCT = 2', 'MIN = 3', 'MAX = 4'). add_fields('int32', Doc('axis', 'axis along which reduction is performed; if -1 is given, ' @@ -571,16 +571,16 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) (pdef('Reduce', 'reduce along given axis', version=1, is_legacy=True). add_enum('Mode', - 'SUM', - Doc('SUM_SQR', 'sum of x * x for each element x'), - 'PRODUCT', 'MIN', 'MAX', 'MEAN'). + 'SUM = 0', + Doc('SUM_SQR = 1', 'sum of x * x for each element x'), + 'PRODUCT = 2', 'MIN = 3', 'MAX = 4', 'MEAN = 5'). add_fields('int32', Doc('axis', 'axis along which reduction is performed; if -1 is given, ' 'reduce to given target shape (only used in megbrain)'), -1). add_enum('DataType', - Doc('DEFAULT', + Doc('DEFAULT = 0', ''' input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode. Currently, ```DEFAULT``` mode means: @@ -607,26 +607,26 @@ Currently, ```DEFAULT``` mode means: ''' ), - Doc('FLOAT_IO16xC32', 'Deprecated. This was replaced by ' + Doc('FLOAT_IO16xC32 = 1', 'Deprecated. 
This was replaced by '
     'FLOAT_O16xC32, and input\'s dtype decided by actual input tensor.'),
-    Doc('FLOAT_O32xC32', 'compute/output both are float32'),
-    Doc('FLOAT_O16xC32', 'compute are float32, output float16'),
-    Doc('QUINT_I8xO32', 'input quint8, compute and output are qint32'),
-    Doc('QINT_I8xO32', 'input qint8, compute and output are qint32'),
+    Doc('FLOAT_O32xC32 = 2', 'compute/output both are float32'),
+    Doc('FLOAT_O16xC32 = 3', 'compute are float32, output float16'),
+    Doc('QUINT_I8xO32 = 4', 'input quint8, compute and output are qint32'),
+    Doc('QINT_I8xO32 = 5', 'input qint8, compute and output are qint32'),
     name_field='data_type'))
 
 (pdef('Cumsum', 'calculate accumulated sum along given axis', version=0, is_legacy=True).
@@ -691,12 +691,12 @@ Currently, ```DEFAULT``` mode means:
 
 (pdef('CondTake').
  add_enum('Mode',
-          Doc('EQ', 'take if ``abs(data-val)<eps``'),
-          Doc('NEQ', 'take if ``abs(data-val)>=eps``'),
-          Doc('LT', 'take if ``data<val``'),
-          Doc('LEQ', 'take if ``data<=val``'),
-          Doc('GT', 'take if ``data>val``'),
-          Doc('GEQ', 'take if ``data>=val``')).
+          Doc('EQ = 0', 'take if ``abs(data-val)<eps``'),
+          Doc('NEQ = 1', 'take if ``abs(data-val)>=eps``'),
+          Doc('LT = 2', 'take if ``data<val``'),
+          Doc('LEQ = 3', 'take if ``data<=val``'),
+          Doc('GT = 4', 'take if ``data>val``'),
+          Doc('GEQ = 5', 'take if ``data>=val``')).
   add_fields('float32',
              Doc('val', 'the value to be compared with; note that for integer '
                  'data, val is also converted to int'), 0).
@@ -704,7 +704,7 @@ Currently, ```DEFAULT``` mode means:
              1e-6))
 
 
-pdef('Argsort').add_enum('Order', 'ASCENDING', 'DESCENDING')
+pdef('Argsort').add_enum('Order', 'ASCENDING = 0', 'DESCENDING = 1')
 
 (pdef('IndexingRemap').
add_fields('bool', @@ -791,17 +791,17 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) .add_fields('uint32', 'row_from', 0, 'row_to', 0, 'col_from', 0, 'col_to', 0)) (pdef('CvtColor') - .add_enum('Mode', 'RGB2GRAY', 'RGB2YUV', 'YUV2RGB', 'GRAY2RGB', 'RGBA2RGB', - 'RGBA2BGR', 'RGBA2GRAY', 'RGB2BGR', 'BGR2GRAY', 'BGR2RGB', - Doc('YUV2GRAY_NV21', 'For historical reasons, referred to as YCC by opencv'), - 'YUV2RGB_NV21', 'YUV2BGR_NV21', 'YUV2GRAY_NV12', 'YUV2RGB_NV12', - 'YUV2BGR_NV12', 'YUV2GRAY_YV12', 'YUV2RGB_YV12', 'YUV2BGR_YV12', - 'YUV2GRAY_YU12', 'YUV2RGB_YU12', 'YUV2BGR_YU12', - 'YCrCb2RGB', 'YCrCb2BGR', - Doc('BT601_YUV2RGB_NV21', 'BT601 yuv format, referred to as YUV by opencv'), - 'BT601_YUV2BGR_NV21', 'BT601_YUV2RGB_NV12', 'BT601_YUV2BGR_NV12', - 'BT601_YUV2RGB_YV12', 'BT601_YUV2BGR_YV12', 'BT601_YUV2RGB_YU12', - 'BT601_YUV2BGR_YU12', + .add_enum('Mode', 'RGB2GRAY = 0', 'RGB2YUV = 1', 'YUV2RGB = 2', 'GRAY2RGB = 3', 'RGBA2RGB = 4', + 'RGBA2BGR = 5', 'RGBA2GRAY = 6', 'RGB2BGR = 7', 'BGR2GRAY = 8', 'BGR2RGB = 9', + Doc('YUV2GRAY_NV21 = 10', 'For historical reasons, referred to as YCC by opencv'), + 'YUV2RGB_NV21 = 11', 'YUV2BGR_NV21 = 12', 'YUV2GRAY_NV12 = 13', 'YUV2RGB_NV12 = 14', + 'YUV2BGR_NV12 = 15', 'YUV2GRAY_YV12 = 16', 'YUV2RGB_YV12 = 17', 'YUV2BGR_YV12 = 18', + 'YUV2GRAY_YU12 = 19', 'YUV2RGB_YU12 = 20', 'YUV2BGR_YU12 = 21', + 'YCrCb2RGB = 22', 'YCrCb2BGR = 23', + Doc('BT601_YUV2RGB_NV21 = 24', 'BT601 yuv format, referred to as YUV by opencv'), + 'BT601_YUV2BGR_NV21 = 25', 'BT601_YUV2RGB_NV12 = 26', 'BT601_YUV2BGR_NV12 = 27', + 'BT601_YUV2RGB_YV12 = 28', 'BT601_YUV2BGR_YV12 = 29', 'BT601_YUV2RGB_YU12 = 30', + 'BT601_YUV2BGR_YU12 = 31', member_alias=[('YUV2GRAY_NV21', 'BT601_YUV2GRAY_NV21'), ('YUV2GRAY_NV12', 'BT601_YUV2GRAY_NV12'), ('YUV2GRAY_YV12', 'BT601_YUV2GRAY_YV12'), @@ -855,7 +855,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) .add_fields('float32', 'scalar', '0.f')) (pdef('Convolution3D'). - add_enum('Mode', 'CROSS_CORRELATION', 'CONVOLUTION'). + add_enum('Mode', 'CROSS_CORRELATION = 0', 'CONVOLUTION = 1'). add_fields( 'uint32', Doc('pad_d', 'padding on one side on the first dimension'), 0, @@ -872,32 +872,32 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) 'on the third dimension'), 1 ). add_enum('Sparse', - Doc('DENSE', 'dense convolution: filter shape should be ' + Doc('DENSE = 0', 'dense convolution: filter shape should be ' '[oc, ic, spatial...] if format is NCDHW, ' '[oc, spatial..., ic] if format is NDHWC'), - Doc('GROUP', 'group convolution: filter shape should be ' + Doc('GROUP = 1', 'group convolution: filter shape should be ' '[group, oc_per_group, ic_per_group, spatial...] if format is NCDHW, ' '[group, oc_per_group, spatial..., ic_per_group] if format is NDHWC') ). add_enum('DataType', - Doc('FLOAT', 'input/output both float32/float16'), - Doc('FLOAT_IO16xC32', 'input/output both float16, the internal ' + Doc('FLOAT = 0', 'input/output both float32/float16'), + Doc('FLOAT_IO16xC32 = 1', 'input/output both float16, the internal ' 'compute is float32'), name_field='data_type'). - add_enum('Format', 'NCDHW', 'NDHWC') + add_enum('Format', 'NCDHW = 0', 'NDHWC = 1') ) (pdef('Conv3DBias'). - add_enum('NonlineMode', 'IDENTITY', 'RELU', 'SIGMOID'). + add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2'). add_enum_alias('Mode', 'Convolution3D'). 
add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0, 'stride_d', 1, 'stride_h', 1, 'stride_w', 0)) (pdef('SeparableConv3D'). add_enum_alias('Mode', 'Convolution3D'). - add_enum('BorderMode', 'BORDER_REPLICATE', 'BORDER_REFLECT', - 'BORDER_REFLECT_101','BORDER_WRAP', - 'BORDER_CONSTANT', 'BORDER_TRANSPARENT','BORDER_ISOLATED'). + add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1', + 'BORDER_REFLECT_101 = 2','BORDER_WRAP = 3', + 'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5','BORDER_ISOLATED = 6'). add_fields('bool', 'is_symm_kernel', 'true'). add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0, 'stride_d', 0, 'stride_h', 1, 'stride_w', 1, @@ -907,11 +907,11 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) (pdef('TopK'). add_enum( 'Mode', - Doc('KTH_ONLY', "only the value of the k'th element would be computed"), - Doc('VALUE_IDX_NOSORT', + Doc('KTH_ONLY = 0', "only the value of the k'th element would be computed"), + Doc('VALUE_IDX_NOSORT = 1', 'all the top-k values and corresponding indices would be computed; ' 'no order is guaranteed'), - Doc('VALUE_IDX_SORTED', + Doc('VALUE_IDX_SORTED = 2', 'all the top-k values and corresponding indices sorted')) ) @@ -983,37 +983,37 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o (pdef('RelayoutFormat', 'Change the tensor layout format', version=0, is_legacy=True). add_enum( Doc('Mode', RELAYOUT_FORMAT_MODE_DOC), - 'NHWC_NHWCD4', - 'NHWCD4_NHWC', - 'NHWC_NHWCD4I', - 'NCHW_NHWCD4', - 'NCHW_NHWCD4I', - 'NHWCD4I_NCHW', - 'NHWCD4_NCHW', - 'INTER_WEIGHT_DENSE', - 'INTER_WEIGHT_DENSEI', - 'INTER_WEIGHT_GROUP', - 'INTER_WEIGHT_GROUPI', - 'INTER_WEIGHT_CHAN', - 'INTER_WEIGHT_CHANI', - 'INTER_WEIGHT_DENSEI_DOT', - 'INTER_WEIGHT_GROUPI_DOT', - 'NCHW4_CHWN4', - 'CHWN4_NCHW4', - 'NCHW_NCHW88_CONV_DENSE_WEIGHT', - 'NCHW_NCHW88_CONV_CHAN_WEIGHT', - 'NCHW_NCHW88_CONV_GROUP_WEIGHT', - 'NCHW_NCHW88', - 'NCHW88_NCHW', - 'NCHW_NCHW4_IC_SMALL', - 'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT', - 'NCHW_NCHW4', - 'NCHW4_NCHW', - 'NCHW_NCHW4_WEIGHT', - 'NCHW_NCHW64', - 'NCHW64_NCHW', - 'NCHW_NHWC', - 'NHWC_NCHW', + 'NHWC_NHWCD4 = 0', + 'NHWCD4_NHWC = 1', + 'NHWC_NHWCD4I = 2', + 'NCHW_NHWCD4 = 3', + 'NCHW_NHWCD4I = 4', + 'NHWCD4I_NCHW = 5', + 'NHWCD4_NCHW = 6', + 'INTER_WEIGHT_DENSE = 7', + 'INTER_WEIGHT_DENSEI = 8', + 'INTER_WEIGHT_GROUP = 9', + 'INTER_WEIGHT_GROUPI = 10', + 'INTER_WEIGHT_CHAN = 11', + 'INTER_WEIGHT_CHANI = 12', + 'INTER_WEIGHT_DENSEI_DOT = 13', + 'INTER_WEIGHT_GROUPI_DOT = 14', + 'NCHW4_CHWN4 = 15', + 'CHWN4_NCHW4 = 16', + 'NCHW_NCHW88_CONV_DENSE_WEIGHT = 17', + 'NCHW_NCHW88_CONV_CHAN_WEIGHT = 18', + 'NCHW_NCHW88_CONV_GROUP_WEIGHT = 19', + 'NCHW_NCHW88 = 20', + 'NCHW88_NCHW = 21', + 'NCHW_NCHW4_IC_SMALL = 22', + 'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT = 23', + 'NCHW_NCHW4 = 24', + 'NCHW4_NCHW = 25', + 'NCHW_NCHW4_WEIGHT = 26', + 'NCHW_NCHW64 = 27', + 'NCHW64_NCHW = 28', + 'NCHW_NHWC = 29', + 'NHWC_NCHW = 30', ) ) @@ -1077,7 +1077,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o (pdef('ROIAlign',version=0,is_legacy=True). - add_enum('Mode', 'MAX', 'AVERAGE', name_field='mode'). + add_enum('Mode', 'MAX = 0', 'AVERAGE = 1', name_field='mode'). add_enum_alias('Format', 'ConvolutionV0'). add_fields('float32', 'spatial_scale', '1.0'). add_fields('float32', 'offset', '0.0'). 
@@ -1173,9 +1173,9 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o pdef('Fill').add_fields('float32', 'value', '0') -PADDING_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'), - Doc('REFLECT', 'fedcba|abcdefgh|hgfedcb'), - Doc('CONSTANT', 'iiiiii|abcdefgh|iiiiiii')] +PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), + Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'), + Doc('CONSTANT = 2', 'iiiiii|abcdefgh|iiiiiii')] (pdef('Padding'). add_fields('uint32', Doc('front_offset_dim0','offset in dim 0'), 0). add_fields('uint32', Doc('front_offset_dim1','offset in dim 1'), 0). diff --git a/imperative/tablegen/helper.h b/imperative/tablegen/helper.h index c881618e5..418bf83fb 100644 --- a/imperative/tablegen/helper.h +++ b/imperative/tablegen/helper.h @@ -241,14 +241,17 @@ private: if (auto* enumAttr = llvm::dyn_cast(&it.attr)) { body += formatv(" switch ({0}){{\n", "$_self." + it.name); for (auto&& enumMember: enumAttr->getEnumMembers()) { - body += formatv( - " case {0}::{1}::{2}:\n", - getCppClassName(), enumAttr->getEnumName(), enumMember - ); - body += formatv( - " props_.emplace_back(\"{0}\", \"{1}\");\n", - it.name, enumMember - ); + size_t d1 = enumMember.find(' '); + size_t d2 = enumMember.find('='); + size_t d = d1 <= d2 ? d1 : d2; + body += formatv(" case {0}::{1}::{2}:\n", + getCppClassName(), + enumAttr->getEnumName(), + enumMember.substr(0, d)); + body += + formatv(" props_.emplace_back(\"{0}\", " + "\"{1}\");\n", + it.name, enumMember.substr(0, d)); body += " break;\n"; } body += " default: break;\n"; diff --git a/imperative/tablegen/targets/cpp_class.cpp b/imperative/tablegen/targets/cpp_class.cpp index e7285f14d..fa437209a 100644 --- a/imperative/tablegen/targets/cpp_class.cpp +++ b/imperative/tablegen/targets/cpp_class.cpp @@ -177,9 +177,13 @@ void OpDefEmitter::emit_tpl_spl() { std::vector case_body; std::string ename = formatv("{0}::{1}", op.getCppClassName(), attr->getEnumName()); - llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ - case_body.push_back(formatv( - "case {0}::{1}: return \"{1}\";", ename, v)); + llvm::for_each(attr->getEnumMembers(), [&](auto&& v) { + size_t d1 = v.find(' '); + size_t d2 = v.find('='); + size_t d = d1 <= d2 ? d1 : d2; + case_body.push_back( + formatv("case {0}::{1}: return \"{1}\";", ename, + v.substr(0, d))); }); os << formatv(R"( template <> diff --git a/imperative/tablegen/targets/pybind11.cpp b/imperative/tablegen/targets/pybind11.cpp index 714b4f960..9399648c6 100644 --- a/imperative/tablegen/targets/pybind11.cpp +++ b/imperative/tablegen/targets/pybind11.cpp @@ -50,14 +50,15 @@ void OpDefEmitter::emit() { ); std::vector body; for (auto&& i: attr->getEnumMembers()) { - os << formatv( - "\n .value(\"{2}\", {0}::{1}::{2})", - className, attr->getEnumName(), i - ); + size_t d1 = i.find(' '); + size_t d2 = i.find('='); + size_t d = d1 <= d2 ? d1 : d2; + os << formatv("\n .value(\"{2}\", {0}::{1}::{2})", + className, attr->getEnumName(), + i.substr(0, d)); body.push_back(formatv( - "if (str == \"{2}\") return {0}::{1}::{2};", - className, attr->getEnumName(), i - )); + "if (str == \"{2}\") return {0}::{1}::{2};", + className, attr->getEnumName(), i.substr(0, d))); } if (attr->getEnumCombinedFlag()) { //! 
define operator | diff --git a/imperative/tablegen/targets/python_c_extension.cpp b/imperative/tablegen/targets/python_c_extension.cpp index ff1356225..49ba7f490 100644 --- a/imperative/tablegen/targets/python_c_extension.cpp +++ b/imperative/tablegen/targets/python_c_extension.cpp @@ -102,7 +102,10 @@ void EnumAttrEmitter::emit_tpl_spl() { &ctx); auto quote = [&](auto&& i) -> std::string { - return formatv("\"{0}\"", i); + size_t d1 = i.find(' '); + size_t d2 = i.find('='); + size_t d = d1 <= d2 ? d1 : d2; + return formatv("\"{0}\"", i.substr(0, d)); }; os << tgfmt(R"( template<> const char* @@ -110,7 +113,11 @@ $enumTpl<$opClass::$enumClass>::members[] = {$0}; )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", ")); auto mem2value = [&](auto&& i) -> std::string { - return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i); + size_t d1 = i.find(' '); + size_t d2 = i.find('='); + size_t d = d1 <= d2 ? d1 : d2; + return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, + i.substr(0, d)); }; os << tgfmt(R"( template<> std::unordered_map @@ -192,12 +199,15 @@ os << tgfmt(R"( auto&& members = attr->getEnumMembers(); for (size_t idx = 0; idx < members.size(); ++ idx) { + size_t d1 = members[idx].find(' '); + size_t d2 = members[idx].find('='); + size_t d = d1 <= d2 ? d1 : d2; os << tgfmt(R"({ PyObject* inst = e_type->tp_alloc(e_type, 0); reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0); $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst; -})", &ctx, members[idx], idx); +})", &ctx, members[idx].substr(0, d), idx); } } diff --git a/tools/gen_header_for_bin_reduce.py b/tools/gen_header_for_bin_reduce.py index 337b83db5..38c7ba16e 100755 --- a/tools/gen_header_for_bin_reduce.py +++ b/tools/gen_header_for_bin_reduce.py @@ -136,12 +136,13 @@ class HeaderGen: mode_list = [i.strip() for i in fin] for i in mode_list: + i = i.split(' ')[0].split('=')[0] if i in self._elemwise_modes: content = '_cb({})'.format(i) else: content = '' self._write_def( - '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i), content) + '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i.split(' ')[0].split('=')[0]), content) self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)', '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)') diff --git a/tools/param_defs/mgb_opr_param_defs.py b/tools/param_defs/mgb_opr_param_defs.py index 16adfcb00..03f5bda60 100644 --- a/tools/param_defs/mgb_opr_param_defs.py +++ b/tools/param_defs/mgb_opr_param_defs.py @@ -20,14 +20,14 @@ pdef('PersistentOutputStorage').add_fields( (pdef('ExecutionPolicy', version=0, is_legacy=True). add_enum('Strategy', - Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'), - Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, ' + Doc('HEURISTIC = 0', 'use heuristic to choose the fastest algorithm'), + Doc('HEURISTIC_REPRODUCIBLE = 1', 'use heuristic to choose the fastest algorithm, ' 'and the chosen algorithm is reproducible'), - Doc('PROFILE', + Doc('PROFILE = 2', 'run possible algorithms on real device to find the best'), - Doc('PROFILE_REPRODUCIBLE', + Doc('PROFILE_REPRODUCIBLE = 3', 'the fastest of profile result that is also reproducible'), - Doc('PROFILE_HEURISTIC', + Doc('PROFILE_HEURISTIC = 4', 'use profile result and heuristic to choose the fastest algorithm')). 
add_fields('uint64', Doc('workspace_limit', 'workspace limit in bytes'), @@ -35,13 +35,13 @@ pdef('PersistentOutputStorage').add_fields( (pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1). add_bit_combination_enum('Strategy', - Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'), - Doc('PROFILE', + Doc('HEURISTIC = 1 << 0', 'use heuristic to choose the fastest algorithm'), + Doc('PROFILE = 1 << 1', 'run possible algorithms on real device to find the best'), - Doc('REPRODUCIBLE', + Doc('REPRODUCIBLE = 1 << 2', 'when profile or heuristic algo selection it require the algos' 'must be reproducible'), - Doc('OPTIMIZED', + Doc('OPTIMIZED = 1 << 3', 'profile require algos are optmized to achieve fast-profile'), default=('HEURISTIC',), member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'), @@ -66,19 +66,19 @@ pdef('PersistentOutputStorage').add_fields( (pdef('CollectiveComm', 'collective communication between multiple computing ' 'nodes on localhost') .add_enum(Doc('Mode', 'mode of collective communication'), - Doc('REDUCE_SUM', 'reduce by sum to output computing node'), - Doc('BROADCAST', 'copy input value to each output computing node'), - Doc('ALL_GATHER', 'each output comp node gets the concatenated ' + Doc('REDUCE_SUM = 0', 'reduce by sum to output computing node'), + Doc('BROADCAST = 1', 'copy input value to each output computing node'), + Doc('ALL_GATHER = 2', 'each output comp node gets the concatenated ' 'value of all inputs'), - Doc('REDUCE_SCATTER_SUM', + Doc('REDUCE_SCATTER_SUM = 3', 'reduce inputs by sum and each output gets one part of it'), - Doc('ALL_REDUCE_SUM', 'every output gets the sum of all inputs'), - Doc('ALL_REDUCE_MAX', 'every output gets the max of all inputs'), - Doc('ALL_REDUCE_MIN', 'every output gets the min of all inputs'), - Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'), - Doc('GATHER', 'concat inputs to one node'), - Doc('SCATTER', 'scatter input to each output computing node'), - Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'), + Doc('ALL_REDUCE_SUM = 4', 'every output gets the sum of all inputs'), + Doc('ALL_REDUCE_MAX = 5', 'every output gets the max of all inputs'), + Doc('ALL_REDUCE_MIN = 6', 'every output gets the min of all inputs'), + Doc('ALL_REDUCE_PROD = 7', 'every output gets the prod of all inputs'), + Doc('GATHER = 8', 'concat inputs to one node'), + Doc('SCATTER = 9', 'scatter input to each output computing node'), + Doc('ALL_TO_ALL = 10', 'scatter inputs and gather them on each computing node'), name_field='mode')) (pdef('FakeSerializedDType', @@ -91,13 +91,13 @@ pdef('PersistentOutputStorage').add_fields( 'evaluate a predicate and branch keys to setup ExecutionMask objects ' 'with associated predicate proxy vars (PPVs)') .add_enum(Doc('Mode', 'how to compare predicate var with branch keys'), - Doc('CASE', + Doc('CASE = 0', 'The outputs correspond to branch keys, ' 'and the one which equals predicate would be activated. ' 'This behaves like a case-statement in many languages.'), - Doc('CASE_FALLBACK', 'like :attr:`CASE`, but add an extra output ' + Doc('CASE_FALLBACK = 1', 'like :attr:`CASE`, but add an extra output ' 'that would be activated if no branch is matched'), - Doc('PIECEWISE', 'One more outputs would be produced than the ' + Doc('PIECEWISE = 2', 'One more outputs would be produced than the ' 'number of branch keys, representing the interval in which the ' 'predicate var fits in. 
The intervals are defined as ' r':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, ' @@ -112,20 +112,20 @@ pdef('PersistentOutputStorage').add_fields( (pdef('CondExecPredLogical', 'compute a logical function over a set of PPVs') - .add_enum('Mode', Doc('OR', 'logical or'), - Doc('AND', 'logical and'), - Doc('XOR', 'exclusive-or'), - Doc('NOR', 'not or(inputs)'), - Doc('NAND', 'not and(inputs)'), - Doc('XNOR', 'not xor(inputs)')) + .add_enum('Mode', Doc('OR = 0', 'logical or'), + Doc('AND = 1', 'logical and'), + Doc('XOR = 2', 'exclusive-or'), + Doc('NOR = 3', 'not or(inputs)'), + Doc('NAND = 4', 'not and(inputs)'), + Doc('XNOR = 5', 'not xor(inputs)')) ) (pdef('CondExecMark', 'add ExecutionMask of the input PPV to this opr and readers of the ' 'outputs of this opr') .add_enum(Doc('GradMode', 'mode for computing the gradient'), - Doc('SUM', 'normal gradient mode: sum all the activated components'), - Doc('SUM_COND_OUT', 'use :attr:`CondExecMerge.SUM_COND_OUT` mode so ' + Doc('SUM = 0', 'normal gradient mode: sum all the activated components'), + Doc('SUM_COND_OUT = 1', 'use :attr:`CondExecMerge.SUM_COND_OUT` mode so ' 'oprs that depend on the gradient opr would not be executed ' 'if the forward var is not used.'), name_field='grad_mode') @@ -135,10 +135,10 @@ pdef('PersistentOutputStorage').add_fields( execution into account, this option can be used to bypass static inference errors. This is currently only used by automatically generated gradient oprs."""), - Doc('SHAPE_VALUE', 'enable both shape and value inference'), - Doc('SHAPE_ONLY', + Doc('SHAPE_VALUE = 0', 'enable both shape and value inference'), + Doc('SHAPE_ONLY = 1', 'only enable shape inference (disable value inference)'), - Doc('NONE', 'disable both shape and value inference'), + Doc('NONE = 2', 'disable both shape and value inference'), name_field='static_infer') ) @@ -147,17 +147,17 @@ pdef('PersistentOutputStorage').add_fields( 'number of output vars (i.e. vars per branch)'), 1) .add_enum('Mode', - Doc('EXACT_ONE', 'copy the var whose mask is activated to the output' + Doc('EXACT_ONE = 0', 'copy the var whose mask is activated to the output' ', requiring that exactly one branch is active'), - Doc('EXACT_ONE_SAME_SHAPE', 'like :attr:`EXACT_ONE` with the ' + Doc('EXACT_ONE_SAME_SHAPE = 1', 'like :attr:`EXACT_ONE` with the ' 'requirement that all branches have the same shape, so shape ' 'inference can be easier'), - Doc('SUM', 'sum all the active branches into output var; require ' + Doc('SUM = 2', 'sum all the active branches into output var; require ' 'all branches to have the same shape. Extra shape vars are ' 'needed in this mod, so the outputs can be initialized to zero ' 'when no input is active (and their shapes are probably ' 'unknown).'), - Doc('SUM_COND_OUT', 'like :attr:`SUM` but also add an ExecutionMask' + Doc('SUM_COND_OUT = 3', 'like :attr:`SUM` but also add an ExecutionMask' ' to the readers of output vars, so they would be skipped if ' ' no branch is taken') ) -- GitLab
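The convention this patch establishes is that enum members in the param-def scripts are declared as 'NAME = VALUE' strings ('MAX = 4', or 'HEURISTIC = 1 << 0' for bit-combination enums), pinning the serialized value in the declaration instead of deriving it from the member's position; every consumer then recovers the bare identifier with the split(' ')[0].split('=')[0] idiom that recurs throughout the diff. A minimal standalone sketch of the convention, assuming (as the scripts do) that Doc objects stringify to their raw declaration text; the helper name split_enum_member is illustrative, not part of the patch:

    def split_enum_member(member):
        # Doc objects in the param defs stringify to the raw declaration,
        # e.g. 'MAX = 4' or 'HEURISTIC = 1 << 0'.
        text = str(member)
        # Bare identifier: cut at the first space or '='.
        name = text.split(' ')[0].split('=')[0]
        # Declared value expression, if any.
        value = text.split('=', 1)[1].strip() if '=' in text else None
        return name, value

    assert split_enum_member('MAX = 4') == ('MAX', '4')
    assert split_enum_member('HEURISTIC = 1 << 0') == ('HEURISTIC', '1 << 0')
    assert split_enum_member('NCHW') == ('NCHW', None)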
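The new branch in gen_flatbuffers_schema.py's scramble_enum_member_name is easiest to read with comments inline. The copy below reproduces the patched logic; the bit_flags rationale is an inference from how FlatBuffers treats combined enums, not something the patch states:

    def scramble_enum_member_name(name):
        # '1 << K' members keep only the shift amount K, presumably because
        # the schema declares combined enums as FlatBuffers bit_flags, whose
        # member values are bit positions rather than masks.
        s = name.find('<<')
        if s != -1:
            name = name[0:name.find('=') + 1] + ' ' + name[s + 2:]
        # Bare MIN/MAX already received a trailing underscore before this
        # patch, avoiding clashes with like-named identifiers in generated
        # code...
        if name in ("MIN", "MAX"):
            return name + "_"
        # ...and the same scrambling must now also apply when MIN/MAX carry
        # an explicit value, e.g. 'MAX = 4' becomes 'MAX_ = 4'.
        o_name = name.split(' ')[0].split('=')[0]
        if o_name in ("MIN", "MAX"):
            return name.replace(o_name, o_name + "_")
        return name

    assert scramble_enum_member_name('MAX = 4') == 'MAX_ = 4'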
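What the explicit values buy is wire-format stability. The Python writer in gen_param_defs.py now fills _enum_member2num from the declared number (int(str(emem).split('=')[1])) rather than from enumerate(), so removing or reordering a member no longer renumbers the survivors. A hypothetical illustration (the member lists are made up for the example):

    def member2num(members):
        # Map bare member names to their declared wire values.
        return {m.split(' ')[0].split('=')[0]: int(m.split('=')[1])
                for m in members}

    v1 = ['SUM = 0', 'SUM_SQR = 1', 'PRODUCT = 2', 'MIN = 3', 'MAX = 4']
    v2 = ['SUM = 0', 'PRODUCT = 2', 'MIN = 3', 'MAX = 4']  # SUM_SQR dropped

    # Declared values survive the removal unchanged; under the old
    # positional scheme PRODUCT would have silently become 1.
    assert member2num(v2) == {'SUM': 0, 'PRODUCT': 2, 'MIN': 3, 'MAX': 4}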
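The C++ emitters (imperative/tablegen/helper.h and targets/*) spell the same trim as d = min(find(' '), find('=')), which also covers value-less members: std::string::npos compares greater than any real position, and substr(0, npos) returns the whole string. Python's str.find() returns -1 on a miss instead, so an equivalent sketch has to special-case it:

    def trim_member(member):
        # Collect hit positions only; -1 (miss) must not win the min().
        cuts = [p for p in (member.find(' '), member.find('=')) if p != -1]
        return member[:min(cuts)] if cuts else member

    assert trim_member('NCHW44_DOT = 8') == 'NCHW44_DOT'
    assert trim_member('NCDHW') == 'NCDHW'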