From 8494a1529e8ed35e83500798e53a4289b9d9c6ed Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 2 Apr 2021 15:20:33 +0800 Subject: [PATCH] chore(scripts): clarify and fix default value of bit combined enum GitOrigin-RevId: 3716bf9bb566a23c6916df611dae563934e824cf --- dnn/scripts/gen_flatbuffers_schema.py | 25 ++-- dnn/scripts/gen_param_defs.py | 136 ++++++++++++++---- dnn/scripts/gen_tablegen.py | 15 +- imperative/python/src/ops.cpp | 41 ++++-- .../tablegen/targets/python_c_extension.cpp | 11 +- src/opr/impl/dnn/dnn.oprdecl | 6 +- tools/param_defs/mgb_opr_param_defs.py | 7 +- 7 files changed, 177 insertions(+), 64 deletions(-) diff --git a/dnn/scripts/gen_flatbuffers_schema.py b/dnn/scripts/gen_flatbuffers_schema.py index 3c43561b4..9794c7565 100755 --- a/dnn/scripts/gen_flatbuffers_schema.py +++ b/dnn/scripts/gen_flatbuffers_schema.py @@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase): name = p + e e = self._enums[(p, e)] self._write_doc(e.name) - self._write("enum %s%s : uint {", p, e.name, indent=1) + attribute = "(bit_flags)" if e.combined else "" + self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1) for idx, member in enumerate(e.members): self._write_doc(member) - if e.combined: - self._write("%s=%d,", scramble_enum_member_name(str(member)), - 1< tuple of the indexes of the enum members""" + assert isinstance(v, int) + idx = 0 + members = [] + while v > 0: + if v & 1: + members.append(idx) + idx += 1 + v >>= 1 + return tuple(members) + + def compose_combined_enum(self, v): + """tuple of members => Integer""" + assert self.combined and isinstance(v, tuple) + norm_v = self.normalize_enum_value(v) + return sum(1 << i for i in norm_v) + class Field(Base): """define a normal data field""" __slots__ = ['name', 'dtype', 'default'] @@ -146,6 +187,10 @@ class member_defs: src_name = name self.src_name = src_name self.default = default + # TODO: remove this assertion if needed; adding mock param_defs in + # current testing framework is too complicated, and currently we + # only allow aliasing of normal enum + assert not self.src_enum.combined @property def src_enum(self): @@ -157,7 +202,7 @@ class member_defs: set""" if self.default is None: return self.src_enum.default - return self.default + return self.src_enum.normalize_enum_value(self.default) class ParamDef: @@ -198,7 +243,7 @@ class ParamDef: self.name.id, name, name_field, members, default, member_alias)) return self - def add_bit_combination_enum(self, name, *members, default=0, + def add_bit_combination_enum(self, name, *members, default=tuple(), name_field=None, member_alias=[]): self.members.append(member_defs.Enum( self.name.id, name, name_field, members, default, member_alias, True)) @@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase): ' for idx, v in enumerate(pdata):\n' ' if isinstance(v, _EnumBase):\n' ' pdata[idx] = _enum_member2num[id(v)]\n' + ' elif isinstance(v, _BitCombinedEnumBase):\n' + ' pdata[idx] = v._value_\n' ' return tag + self._packer.pack(*pdata)\n' '\n' ) - self._write( - 'class _EnumBase(enum.Enum):\n' + # it's hard to mix custom implemention into enum, just do copy-paste instead + classbody = ( ' @classmethod\n' ' def __normalize(cls, val):\n' ' if isinstance(val, str):\n' @@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase): ' return super()._missing_(value)\n' '\n' ) + self._write( + 'class _EnumBase(enum.Enum):\n' + classbody + ) + self._write( + 'class _BitCombinedEnumBase(enum.Flag):\n' + classbody + ) if not self._imperative: self._write( 'def _as_dtype_num(dtype):\n' @@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase): def _on_member_enum(self, e): qualname = '{}.{}'.format(self._cur_param_name, e.name) - self._write('class %s(_EnumBase):', e.name, indent=1) + if e.combined: + self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1) + else: + self._write('class %s(_EnumBase):', e.name, indent=1) + self._write_doc(e.name) for idx, emem in enumerate(e.members): - self._write('%s = "%s"', emem, emem) - self._write_doc(emem) if e.combined: - self._enum_member2num.append('id({}.{}):{}'.format( - qualname, emem, 1< +struct EnumTrait; + template -struct EnumTrait { +struct EnumTrait>> { static constexpr bool is_bit_combined = false; + static constexpr std::underlying_type_t max = 0; }; template @@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper { return ret; } } - static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { - PyObject* obj = type->tp_alloc(type, 0); - reinterpret_cast(obj)->value = static_cast(1); - return obj; - } - static int py_init(PyObject* self, PyObject* args, PyObject*) { - int input = 1; - if (PyArg_ParseTuple(args, "|i", &input)){ - reinterpret_cast(self)->value = - static_cast(input); + static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) { + if (!PyTuple_Size(args)) { + PyObject* obj = type->tp_alloc(type, 0); + reinterpret_cast(obj)->value = T(); + return obj; + } + else { + PyObject* input; + if (!PyArg_ParseTuple(args, "|O", &input)) { + return nullptr; + } + T value; + try { + value = pyobj_convert_generic::from(input); + } CATCH_ALL(nullptr); + PyObject* obj = type->tp_alloc(type, 0); + reinterpret_cast(obj)->value = value; + return obj; } - return 0; } static PyObject* py_repr(PyObject* self) { return pyobj_convert_generic::to( @@ -325,6 +336,12 @@ struct pyobj_convert_generic(obj)->value; + } else if(PyLong_Check(obj)) { + auto value = pyobj_convert_generic>::from(obj); + mgb_throw_if(value > EnumTrait::max, mgb::MegBrainError, + "out of range, cannot convert %zu to %s", + static_cast(value), Wrapper::name); + return static_cast(value); } // try as string // TODO: type checkcd diff --git a/imperative/tablegen/targets/python_c_extension.cpp b/imperative/tablegen/targets/python_c_extension.cpp index 1de71a858..130962e0e 100644 --- a/imperative/tablegen/targets/python_c_extension.cpp +++ b/imperative/tablegen/targets/python_c_extension.cpp @@ -90,10 +90,12 @@ void EnumAttrEmitter::emit_tpl_spl() { "template<> PyNumberMethods " "$enumTpl<$opClass::$enumClass>::number_methods={};\n", &ctx); - os << tgfmt( - "template<> struct EnumTrait<$opClass::$enumClass> { static constexpr " - "bool is_bit_combined = true;};\n", - &ctx); + os << tgfmt(R"( +template<> struct EnumTrait<$opClass::$enumClass> { + static constexpr bool is_bit_combined = true; + static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1; +}; +)", &ctx, attr->getEnumMembers().size()); } auto str2type = [&](auto&& i) -> std::string { @@ -138,7 +140,6 @@ void $0(PyTypeObject& py_type) { // others should always use singleton os << tgfmt(R"( e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; - e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init; auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; diff --git a/src/opr/impl/dnn/dnn.oprdecl b/src/opr/impl/dnn/dnn.oprdecl index 16d6e3ff1..c511ed496 100644 --- a/src/opr/impl/dnn/dnn.oprdecl +++ b/src/opr/impl/dnn/dnn.oprdecl @@ -6,7 +6,7 @@ decl_opr('Convolution', 'convolution kernel in ' '(out channel, in channel, kern row, kern col) format')], params=[('param', 'ConvolutionV0'), - ('execution_polity', 'ExecutionPolicy')], + ('execution_polity', 'ExecutionPolicyV0')], desc='batched convolution on channeled 2D images') decl_opr('Convolution', @@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData', 'convolution kernel in ' '(out channel, in channel, kern row, kern col) format')], params=[('param', 'ConvolutionV0'), - ('execution_polity', 'ExecutionPolicy')], + ('execution_polity', 'ExecutionPolicyV0')], body=[ 'a, b = all_inputs', 'all_inputs = [b, a]' @@ -201,7 +201,7 @@ decl_opr('ConvBiasForward', Doc('bias', 'bias'), ], params=[('param', 'ConvBiasV1'), - ('execution_policy', 'ExecutionPolicy')], + ('execution_policy', 'ExecutionPolicyV0')], desc=('activation(convolution(src, filter) + bias) with specified ' 'dtype'), has_out_dtype=True) diff --git a/tools/param_defs/mgb_opr_param_defs.py b/tools/param_defs/mgb_opr_param_defs.py index d8fd2026e..cd46edd2a 100644 --- a/tools/param_defs/mgb_opr_param_defs.py +++ b/tools/param_defs/mgb_opr_param_defs.py @@ -42,7 +42,12 @@ pdef('PersistentOutputStorage').add_fields( 'when profile or heuristic algo selection it require the algos' 'must be reproducible'), Doc('OPTMIZED', - 'profile require algos are optmized to achieve fast-profile')). + 'profile require algos are optmized to achieve fast-profile'), + default=('HEURISTIC',), + member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'), + (('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'), + (('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'), + ]). add_fields('uint64', Doc('workspace_limit', 'workspace limit in bytes'), str(2**64-1)+'ull')) -- GitLab