提交 8494a152 编写于 作者: M Megvii Engine Team

chore(scripts): clarify and fix default value of bit combined enum

GitOrigin-RevId: 3716bf9bb566a23c6916df611dae563934e824cf
上级 da167cbc
...@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase): ...@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase):
name = p + e name = p + e
e = self._enums[(p, e)] e = self._enums[(p, e)]
self._write_doc(e.name) self._write_doc(e.name)
self._write("enum %s%s : uint {", p, e.name, indent=1) attribute = "(bit_flags)" if e.combined else ""
self._write("enum %s%s : uint %s {", p, e.name, attribute, indent=1)
for idx, member in enumerate(e.members): for idx, member in enumerate(e.members):
self._write_doc(member) self._write_doc(member)
if e.combined: self._write("%s,", scramble_enum_member_name(str(member)))
self._write("%s=%d,", scramble_enum_member_name(str(member)),
1<<idx)
else:
self._write("%s,", scramble_enum_member_name(str(member)))
self._write("}\n", indent=-1) self._write("}\n", indent=-1)
def _write_doc(self, doc): def _write_doc(self, doc):
...@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase): ...@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase):
return return
self._write_doc(e.name) self._write_doc(e.name)
self._used_enum.add(key) self._used_enum.add(key)
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, if e.combined:
scramble_enum_member_name(str(e.members[e.default]))) default = e.compose_combined_enum(e.default)
else:
default = scramble_enum_member_name(str(e.members[e.default]))
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)
def _resolve_const(self, v): def _resolve_const(self, v):
while v in self._cur_const_val: while v in self._cur_const_val:
...@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase): ...@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase):
return return
self._used_enum.add((e.src_class, e.src_name)) self._used_enum.add((e.src_class, e.src_name))
enum_name = e.src_class + e.src_name enum_name = e.src_class + e.src_name
self._write( s = e.src_enum
"%s:%s = %s;", e.name_field, enum_name, if s.combined:
scramble_enum_member_name(str(e.src_enum.members[e.get_default()]))) default = s.compose_combined_enum(e.get_default())
else:
default = scramble_enum_member_name(str(s.members[e.get_default()]))
self._write("%s:%s = %s;", e.name_field, enum_name, default)
def _get_fb_default(self, cppdefault): def _get_fb_default(self, cppdefault):
if not isinstance(cppdefault, str): if not isinstance(cppdefault, str):
......
...@@ -73,11 +73,21 @@ class member_defs: ...@@ -73,11 +73,21 @@ class member_defs:
"""define an enum; the result would contain both an enum class def and its """define an enum; the result would contain both an enum class def and its
corresponding data field corresponding data field
:param default: index of default member value :param default:
for normal enum class: index of default member value
for bit combined class: tuple of index of default member value
For example, following representations of the default value for bit
combined class are all equivalent:
Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...)
Enum(members=('a', 'b', 'c'), default=(0, 1), ...)
Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...)
:attr name_field: name of the data field of this enum in the param :attr name_field: name of the data field of this enum in the param
struct struct
:attr member_alias: list of (member, alias) pairs :attr member_alias:
for normal enum class: list of (member, alias) pairs
for bit combined class: list of (tuple of members, alias) paris
""" """
__slots__ = ['name', 'name_field', 'members', 'default', __slots__ = ['name', 'name_field', 'members', 'default',
'member_alias', 'combined'] 'member_alias', 'combined']
...@@ -90,17 +100,11 @@ class member_defs: ...@@ -90,17 +100,11 @@ class member_defs:
name = member_defs.Doc.make(name) name = member_defs.Doc.make(name)
assert name.id[0].isupper() assert name.id[0].isupper()
members = tuple(map(member_defs.Doc.make, members)) members = tuple(map(member_defs.Doc.make, members))
if isinstance(default, str):
if default not in name_field:
raise ValueError(
"Default value '{}' does not exist.".format(default))
default = name_field.index(default)
assert isinstance(default, int)
self.name = name self.name = name
self.combined = combined self.combined = combined
self.name_field = self.get_name_field(name.id, name_field) self.name_field = self.get_name_field(name.id, name_field)
self.members = members self.members = members
self.default = default self.default = self.normalize_enum_value(default)
self.all_enums[(param_name, name.id)] = self self.all_enums[(param_name, name.id)] = self
...@@ -114,6 +118,43 @@ class member_defs: ...@@ -114,6 +118,43 @@ class member_defs:
assert isinstance(name_field, str) assert isinstance(name_field, str)
return name_field return name_field
def normalize_enum_value(self, value):
def normalize(v):
if isinstance(v, str):
if v not in self.members:
raise ValueError(
"enum member '{}' does not exist.".format(v))
v = self.members.index(v)
assert isinstance(v, int)
return v
if self.combined:
if isinstance(value, int):
value = self.decompose_combined_enum(value)
assert isinstance(value, tuple)
value = tuple(normalize(i) for i in value)
return value
else:
return normalize(value)
@staticmethod
def decompose_combined_enum(v):
"""Integer => tuple of the indexes of the enum members"""
assert isinstance(v, int)
idx = 0
members = []
while v > 0:
if v & 1:
members.append(idx)
idx += 1
v >>= 1
return tuple(members)
def compose_combined_enum(self, v):
"""tuple of members => Integer"""
assert self.combined and isinstance(v, tuple)
norm_v = self.normalize_enum_value(v)
return sum(1 << i for i in norm_v)
class Field(Base): class Field(Base):
"""define a normal data field""" """define a normal data field"""
__slots__ = ['name', 'dtype', 'default'] __slots__ = ['name', 'dtype', 'default']
...@@ -146,6 +187,10 @@ class member_defs: ...@@ -146,6 +187,10 @@ class member_defs:
src_name = name src_name = name
self.src_name = src_name self.src_name = src_name
self.default = default self.default = default
# TODO: remove this assertion if needed; adding mock param_defs in
# current testing framework is too complicated, and currently we
# only allow aliasing of normal enum
assert not self.src_enum.combined
@property @property
def src_enum(self): def src_enum(self):
...@@ -157,7 +202,7 @@ class member_defs: ...@@ -157,7 +202,7 @@ class member_defs:
set""" set"""
if self.default is None: if self.default is None:
return self.src_enum.default return self.src_enum.default
return self.default return self.src_enum.normalize_enum_value(self.default)
class ParamDef: class ParamDef:
...@@ -198,7 +243,7 @@ class ParamDef: ...@@ -198,7 +243,7 @@ class ParamDef:
self.name.id, name, name_field, members, default, member_alias)) self.name.id, name, name_field, members, default, member_alias))
return self return self
def add_bit_combination_enum(self, name, *members, default=0, def add_bit_combination_enum(self, name, *members, default=tuple(),
name_field=None, member_alias=[]): name_field=None, member_alias=[]):
self.members.append(member_defs.Enum( self.members.append(member_defs.Enum(
self.name.id, name, name_field, members, default, member_alias, True)) self.name.id, name, name_field, members, default, member_alias, True))
...@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase): ...@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase):
' for idx, v in enumerate(pdata):\n' ' for idx, v in enumerate(pdata):\n'
' if isinstance(v, _EnumBase):\n' ' if isinstance(v, _EnumBase):\n'
' pdata[idx] = _enum_member2num[id(v)]\n' ' pdata[idx] = _enum_member2num[id(v)]\n'
' elif isinstance(v, _BitCombinedEnumBase):\n'
' pdata[idx] = v._value_\n'
' return tag + self._packer.pack(*pdata)\n' ' return tag + self._packer.pack(*pdata)\n'
'\n' '\n'
) )
self._write( # it's hard to mix custom implemention into enum, just do copy-paste instead
'class _EnumBase(enum.Enum):\n' classbody = (
' @classmethod\n' ' @classmethod\n'
' def __normalize(cls, val):\n' ' def __normalize(cls, val):\n'
' if isinstance(val, str):\n' ' if isinstance(val, str):\n'
...@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase): ...@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase):
' return super()._missing_(value)\n' ' return super()._missing_(value)\n'
'\n' '\n'
) )
self._write(
'class _EnumBase(enum.Enum):\n' + classbody
)
self._write(
'class _BitCombinedEnumBase(enum.Flag):\n' + classbody
)
if not self._imperative: if not self._imperative:
self._write( self._write(
'def _as_dtype_num(dtype):\n' 'def _as_dtype_num(dtype):\n'
...@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase): ...@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase):
def _on_member_enum(self, e): def _on_member_enum(self, e):
qualname = '{}.{}'.format(self._cur_param_name, e.name) qualname = '{}.{}'.format(self._cur_param_name, e.name)
self._write('class %s(_EnumBase):', e.name, indent=1) if e.combined:
self._write('class %s(_BitCombinedEnumBase):', e.name, indent=1)
else:
self._write('class %s(_EnumBase):', e.name, indent=1)
self._write_doc(e.name) self._write_doc(e.name)
for idx, emem in enumerate(e.members): for idx, emem in enumerate(e.members):
self._write('%s = "%s"', emem, emem)
self._write_doc(emem)
if e.combined: if e.combined:
self._enum_member2num.append('id({}.{}):{}'.format( self._write('%s = 1 << %d', emem, idx)
qualname, emem, 1<<idx)) self._write_doc(emem)
else: else:
self._write('%s = "%s"', emem, emem)
self._write_doc(emem)
self._enum_member2num.append('id({}.{}):{}'.format( self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx)) qualname, emem, idx))
for emem, emem_alis in e.member_alias: for emem, emem_alias in e.member_alias:
self._write('%s = %s', emem_alis, emem) if e.combined:
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem))
else:
self._write('%s = %s', emem_alias, emem)
self._unindent() self._unindent()
self._write('') self._write('')
if e.combined:
default = e.compose_combined_enum(e.default)
else:
default = "'{}'".format(e.members[e.default])
self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field), cvt='{}.convert({})'.format(qualname, e.name_field),
fmt='I', fmt='I',
default="'{}'".format(e.members[e.default]), default=default,
type=qualname, type=qualname,
doc=None)) doc=None))
...@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase): ...@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase):
self._write('%s = %s.%s', e.name, e.src_class, e.src_name) self._write('%s = %s.%s', e.name, e.src_class, e.src_name)
s = e.src_enum s = e.src_enum
qualname = '{}.{}'.format(e.src_class, e.src_name) qualname = '{}.{}'.format(e.src_class, e.src_name)
if s.combined:
default = s.compose_combined_enum(e.get_default())
else:
default = "'{}'".format(s.members[e.get_default()])
self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field), cvt='{}.convert({})'.format(qualname, e.name_field),
fmt='I', fmt='I',
default="'{}'".format(s.members[e.get_default()]), default=default,
type=qualname, type=qualname,
doc=None)) doc=None))
...@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase): ...@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase):
v += ',' v += ','
self._write(v) self._write(v)
for mem, alias in e.member_alias: for mem, alias in e.member_alias:
self._write('%s = %s,', alias, mem) if e.combined:
self._write('%s = %s,', alias, e.compose_combined_enum(mem))
else:
self._write('%s = %s,', alias, mem)
self._write('};', indent=-1) self._write('};', indent=-1)
self._non_static_members.append(e) self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
str(e.name).upper(), len(e.members)) str(e.name).upper(), len(e.members))
self._add_ctor_args(e.name, if e.combined:
'{}::{}'.format(e.name, e.members[e.default]), default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default))
e.name_field) else:
default = '{}::{}'.format(e.name, e.members[e.default])
self._add_ctor_args(e.name, default, e.name_field)
def _on_member_enum_alias(self, e): def _on_member_enum_alias(self, e):
s = e.src_enum s = e.src_enum
...@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase): ...@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase):
self._non_static_members.append(e) self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
str(e.name).upper(), len(s.members)) str(e.name).upper(), len(s.members))
self._add_ctor_args(e.name, if s.combined:
'{}::{}'.format(e.name, default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default))
s.members[e.get_default()]), else:
e.name_field) default = '{}::{}'.format(e.name, s.members[e.get_default()])
self._add_ctor_args(e.name, default, e.name_field)
def _on_member_field(self, f): def _on_member_field(self, f):
self._non_static_members.append(f) self._non_static_members.append(f)
......
...@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase): ...@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase):
return return
# wrapped with default value # wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default) if e.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, e.compose_combined_enum(e.default))
else:
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default])
wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
...@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase): ...@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase):
self._write("def {} : {};".format(td_class, enum_def)) self._write("def {} : {};".format(td_class, enum_def))
# wrapped with default value # wrapped with default value
default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default()) s = e.src_enum
if s.combined:
default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, s.compose_combined_enum(e.get_default()))
else:
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()])
wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)
self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) self._current_tparams.append("{}:${}".format(wrapped, e.name_field))
......
...@@ -87,9 +87,13 @@ struct pyobj_convert_generic { ...@@ -87,9 +87,13 @@ struct pyobj_convert_generic {
} }
}; };
template<typename T, typename SFINAE=void>
struct EnumTrait;
template <typename T> template <typename T>
struct EnumTrait { struct EnumTrait<T, std::enable_if_t<std::is_enum_v<T>>> {
static constexpr bool is_bit_combined = false; static constexpr bool is_bit_combined = false;
static constexpr std::underlying_type_t<T> max = 0;
}; };
template <typename T> template <typename T>
...@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper { ...@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper {
return ret; return ret;
} }
} }
static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject*, PyObject*) { static PyObject* py_new_combined_enum(PyTypeObject* type, PyObject* args, PyObject*) {
PyObject* obj = type->tp_alloc(type, 0); if (!PyTuple_Size(args)) {
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = static_cast<T>(1); PyObject* obj = type->tp_alloc(type, 0);
return obj; reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
} return obj;
static int py_init(PyObject* self, PyObject* args, PyObject*) { }
int input = 1; else {
if (PyArg_ParseTuple(args, "|i", &input)){ PyObject* input;
reinterpret_cast<BitCombinedEnumWrapper*>(self)->value = if (!PyArg_ParseTuple(args, "|O", &input)) {
static_cast<T>(input); return nullptr;
}
T value;
try {
value = pyobj_convert_generic<T>::from(input);
} CATCH_ALL(nullptr);
PyObject* obj = type->tp_alloc(type, 0);
reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
return obj;
} }
return 0;
} }
static PyObject* py_repr(PyObject* self) { static PyObject* py_repr(PyObject* self) {
return pyobj_convert_generic<std::string>::to( return pyobj_convert_generic<std::string>::to(
...@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T, ...@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T,
static T from(PyObject* obj) { static T from(PyObject* obj) {
if (PyObject_TypeCheck(obj, &Wrapper::type)) { if (PyObject_TypeCheck(obj, &Wrapper::type)) {
return reinterpret_cast<Wrapper*>(obj)->value; return reinterpret_cast<Wrapper*>(obj)->value;
} else if(PyLong_Check(obj)) {
auto value = pyobj_convert_generic<std::underlying_type_t<T>>::from(obj);
mgb_throw_if(value > EnumTrait<T>::max, mgb::MegBrainError,
"out of range, cannot convert %zu to %s",
static_cast<uint32_t>(value), Wrapper::name);
return static_cast<T>(value);
} }
// try as string // try as string
// TODO: type checkcd // TODO: type checkcd
......
...@@ -90,10 +90,12 @@ void EnumAttrEmitter::emit_tpl_spl() { ...@@ -90,10 +90,12 @@ void EnumAttrEmitter::emit_tpl_spl() {
"template<> PyNumberMethods " "template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods={};\n", "$enumTpl<$opClass::$enumClass>::number_methods={};\n",
&ctx); &ctx);
os << tgfmt( os << tgfmt(R"(
"template<> struct EnumTrait<$opClass::$enumClass> { static constexpr " template<> struct EnumTrait<$opClass::$enumClass> {
"bool is_bit_combined = true;};\n", static constexpr bool is_bit_combined = true;
&ctx); static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1;
};
)", &ctx, attr->getEnumMembers().size());
} }
auto str2type = [&](auto&& i) -> std::string { auto str2type = [&](auto&& i) -> std::string {
...@@ -138,7 +140,6 @@ void $0(PyTypeObject& py_type) { ...@@ -138,7 +140,6 @@ void $0(PyTypeObject& py_type) {
// others should always use singleton // others should always use singleton
os << tgfmt(R"( os << tgfmt(R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum; e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods; auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or; number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and; number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
......
...@@ -6,7 +6,7 @@ decl_opr('Convolution', ...@@ -6,7 +6,7 @@ decl_opr('Convolution',
'convolution kernel in ' 'convolution kernel in '
'(out channel, in channel, kern row, kern col) format')], '(out channel, in channel, kern row, kern col) format')],
params=[('param', 'ConvolutionV0'), params=[('param', 'ConvolutionV0'),
('execution_polity', 'ExecutionPolicy')], ('execution_polity', 'ExecutionPolicyV0')],
desc='batched convolution on channeled 2D images') desc='batched convolution on channeled 2D images')
decl_opr('Convolution', decl_opr('Convolution',
...@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData', ...@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData',
'convolution kernel in ' 'convolution kernel in '
'(out channel, in channel, kern row, kern col) format')], '(out channel, in channel, kern row, kern col) format')],
params=[('param', 'ConvolutionV0'), params=[('param', 'ConvolutionV0'),
('execution_polity', 'ExecutionPolicy')], ('execution_polity', 'ExecutionPolicyV0')],
body=[ body=[
'a, b = all_inputs', 'a, b = all_inputs',
'all_inputs = [b, a]' 'all_inputs = [b, a]'
...@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward', ...@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward',
Doc('bias', 'bias'), Doc('bias', 'bias'),
], ],
params=[('param', 'ConvBiasV1'), params=[('param', 'ConvBiasV1'),
('execution_policy', 'ExecutionPolicy')], ('execution_policy', 'ExecutionPolicyV0')],
desc=('activation(convolution(src, filter) + bias) with specified ' desc=('activation(convolution(src, filter) + bias) with specified '
'dtype'), 'dtype'),
has_out_dtype=True) has_out_dtype=True)
......
...@@ -42,7 +42,12 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -42,7 +42,12 @@ pdef('PersistentOutputStorage').add_fields(
'when profile or heuristic algo selection it require the algos' 'when profile or heuristic algo selection it require the algos'
'must be reproducible'), 'must be reproducible'),
Doc('OPTMIZED', Doc('OPTMIZED',
'profile require algos are optmized to achieve fast-profile')). 'profile require algos are optmized to achieve fast-profile'),
default=('HEURISTIC',),
member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'),
(('PROFILE', 'REPRODUCIBLE'), 'PROFILE_REPRODUCIBLE'),
(('PROFILE', 'HEURISTIC'), 'PROFILE_HEURISTIC'),
]).
add_fields('uint64', add_fields('uint64',
Doc('workspace_limit', 'workspace limit in bytes'), Doc('workspace_limit', 'workspace limit in bytes'),
str(2**64-1)+'ull')) str(2**64-1)+'ull'))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册