提交 fb49a283 编写于 作者: M Megvii Engine Team

refactor(mgb/dnn): refactor enum used in serializing

GitOrigin-RevId: e57af4a59c9b4e090f3972b4d0cf01a2737f8355
上级 d69b5903
...@@ -23,8 +23,14 @@ def _cname_to_fbname(cname): ...@@ -23,8 +23,14 @@ def _cname_to_fbname(cname):
}[cname] }[cname]
def scramble_enum_member_name(name): def scramble_enum_member_name(name):
s = name.find('<<')
if s != -1:
name = name[0:name.find('=') + 1] + ' ' + name[s+2:]
if name in ("MIN", "MAX"): if name in ("MIN", "MAX"):
return name + "_" return name + "_"
o_name = name.split(' ')[0].split('=')[0]
if o_name in ("MIN", "MAX"):
return name.replace(o_name, o_name + "_")
return name return name
class FlatBuffersWriter(IndentWriterBase): class FlatBuffersWriter(IndentWriterBase):
...@@ -97,7 +103,8 @@ class FlatBuffersWriter(IndentWriterBase): ...@@ -97,7 +103,8 @@ class FlatBuffersWriter(IndentWriterBase):
if e.combined: if e.combined:
default = e.compose_combined_enum(e.default) default = e.compose_combined_enum(e.default)
else: else:
default = scramble_enum_member_name(str(e.members[e.default])) default = scramble_enum_member_name(
str(e.members[e.default]).split(' ')[0].split('=')[0])
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default) self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, default)
def _resolve_const(self, v): def _resolve_const(self, v):
...@@ -124,7 +131,8 @@ class FlatBuffersWriter(IndentWriterBase): ...@@ -124,7 +131,8 @@ class FlatBuffersWriter(IndentWriterBase):
if s.combined: if s.combined:
default = s.compose_combined_enum(e.get_default()) default = s.compose_combined_enum(e.get_default())
else: else:
default = scramble_enum_member_name(str(s.members[e.get_default()])) default = scramble_enum_member_name(
str(s.members[e.get_default()]).split(' ')[0].split('=')[0])
self._write("%s:%s = %s;", e.name_field, enum_name, default) self._write("%s:%s = %s;", e.name_field, enum_name, default)
def _get_fb_default(self, cppdefault): def _get_fb_default(self, cppdefault):
......
...@@ -121,10 +121,12 @@ class member_defs: ...@@ -121,10 +121,12 @@ class member_defs:
def normalize_enum_value(self, value): def normalize_enum_value(self, value):
def normalize(v): def normalize(v):
if isinstance(v, str): if isinstance(v, str):
if v not in self.members: for idx, m in enumerate(self.members):
m = str(m).split(' ')[0].split('=')[0]
if v == m :
return idx
raise ValueError( raise ValueError(
"enum member '{}' does not exist.".format(v)) "enum member '{}' does not exist.".format(v))
v = self.members.index(v)
assert isinstance(v, int) assert isinstance(v, int)
return v return v
if self.combined: if self.combined:
...@@ -524,21 +526,25 @@ class SerializedDType(_ParamDefBase): ...@@ -524,21 +526,25 @@ class SerializedDType(_ParamDefBase):
self._write_doc(e.name) self._write_doc(e.name)
for idx, emem in enumerate(e.members): for emem in e.members:
if e.combined: if e.combined:
self._write('%s = 1 << %d', emem, idx) self._write('%s', emem)
self._write_doc(emem) self._write_doc(emem)
else: else:
self._write('%s = "%s"', emem, emem) v = str(emem).split(' ')[0].split('=')[0]
n = int(str(emem).split('=')[1])
self._write('%s = "%s"', v, v)
self._write_doc(emem) self._write_doc(emem)
self._enum_member2num.append('id({}.{}):{}'.format( self._enum_member2num.append('id({}.{}):{}'.format(
qualname, emem, idx)) qualname, v, n))
for emem, emem_alias in e.member_alias: for emem, emem_alias in e.member_alias:
em_a = emem_alias.split(' ')[0].split('=')[0]
if e.combined: if e.combined:
self._write('%s = %s', emem_alias, e.compose_combined_enum(emem)) self._write('%s = %s', em_a, e.compose_combined_enum(emem))
else: else:
self._write('%s = %s', emem_alias, emem) em = str(emem).split(' ')[0].split('=')[0]
self._write('%s = %s', em_a, em)
self._unindent() self._unindent()
self._write('') self._write('')
...@@ -546,7 +552,7 @@ class SerializedDType(_ParamDefBase): ...@@ -546,7 +552,7 @@ class SerializedDType(_ParamDefBase):
if e.combined: if e.combined:
default = e.compose_combined_enum(e.default) default = e.compose_combined_enum(e.default)
else: else:
default = "'{}'".format(e.members[e.default]) default = "'{}'".format(str(e.members[e.default]).split(' ')[0].split('=')[0])
self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
...@@ -564,7 +570,7 @@ class SerializedDType(_ParamDefBase): ...@@ -564,7 +570,7 @@ class SerializedDType(_ParamDefBase):
if s.combined: if s.combined:
default = s.compose_combined_enum(e.get_default()) default = s.compose_combined_enum(e.get_default())
else: else:
default = "'{}'".format(s.members[e.get_default()]) default = "'{}'".format(str(s.members[e.get_default()]).split(' ')[0].split('=')[0])
self._cur_fields.append(self.FieldDef( self._cur_fields.append(self.FieldDef(
name=e.name_field, name=e.name_field,
cvt='{}.convert({})'.format(qualname, e.name_field), cvt='{}.convert({})'.format(qualname, e.name_field),
...@@ -700,11 +706,9 @@ class CPPWriter(IndentWriterBase): ...@@ -700,11 +706,9 @@ class CPPWriter(IndentWriterBase):
def _on_member_enum(self, e): def _on_member_enum(self, e):
self._write_doc(e.name) self._write_doc(e.name)
self._write('enum class %s: uint32_t {', e.name, indent=1) self._write('enum class %s: uint32_t {', e.name, indent=1)
for idx, i in enumerate(e.members): for i in e.members:
self._write_doc(i) self._write_doc(i)
v = '{} = {}'.format(i, idx) v = str(i)
if e.combined:
v = '{} = 1 << {}'.format(i, idx)
if i is not e.members[-1] or e.member_alias: if i is not e.members[-1] or e.member_alias:
v += ',' v += ','
self._write(v) self._write(v)
...@@ -712,7 +716,7 @@ class CPPWriter(IndentWriterBase): ...@@ -712,7 +716,7 @@ class CPPWriter(IndentWriterBase):
if e.combined: if e.combined:
self._write('%s = %s,', alias, e.compose_combined_enum(mem)) self._write('%s = %s,', alias, e.compose_combined_enum(mem))
else: else:
self._write('%s = %s,', alias, mem) self._write('%s = %s,', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0])
self._write('};', indent=-1) self._write('};', indent=-1)
self._non_static_members.append(e) self._non_static_members.append(e)
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;',
...@@ -720,7 +724,9 @@ class CPPWriter(IndentWriterBase): ...@@ -720,7 +724,9 @@ class CPPWriter(IndentWriterBase):
if e.combined: if e.combined:
default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default)) default = 'static_cast<{}>({})'.format(e.name, e.compose_combined_enum(e.default))
else: else:
default = '{}::{}'.format(e.name, e.members[e.default]) value = str(e.members[e.default])
value = value.split(' ')[0].split('=')[0]
default = '{}::{}'.format(e.name, value)
self._add_ctor_args(e.name, default, e.name_field) self._add_ctor_args(e.name, default, e.name_field)
def _on_member_enum_alias(self, e): def _on_member_enum_alias(self, e):
...@@ -732,7 +738,9 @@ class CPPWriter(IndentWriterBase): ...@@ -732,7 +738,9 @@ class CPPWriter(IndentWriterBase):
if s.combined: if s.combined:
default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default)) default = 'static_cast<{}>({})'.format(e.name, s.compose_combined_enum(e.default))
else: else:
default = '{}::{}'.format(e.name, s.members[e.get_default()]) value = str(s.members[e.get_default()])
value = value.split(' ')[0].split('=')[0]
default = '{}::{}'.format(e.name, value)
self._add_ctor_args(e.name, default, e.name_field) self._add_ctor_args(e.name, default, e.name_field)
def _on_member_field(self, f): def _on_member_field(self, f):
...@@ -754,11 +762,12 @@ class CPPEnumValueWriter(CPPWriter): ...@@ -754,11 +762,12 @@ class CPPEnumValueWriter(CPPWriter):
def _on_member_enum(self, e): def _on_member_enum(self, e):
self._write_doc(e.name) self._write_doc(e.name)
self._write('struct %s {', e.name, indent=1) self._write('struct %s {', e.name, indent=1)
for idx, val in enumerate(e.members): for val in e.members:
self._write_doc(val) self._write_doc(val)
self._write('static const uint32_t %s = %d;', val, idx) v = str(val)
self._write('static const uint32_t %s;', v)
for mem, alias in e.member_alias: for mem, alias in e.member_alias:
self._write('static const uint32_t %s = %s;', alias, mem) self._write('static const uint32_t %s = %s;', str(alias).split(' ')[0].split('=')[0], str(mem).split(' ')[0].split('=')[0])
self._write('};', indent=-1) self._write('};', indent=-1)
def _on_member_enum_alias(self, e): def _on_member_enum_alias(self, e):
...@@ -848,9 +857,11 @@ class CPPParamJsonFuncWriter(IndentWriterBase): ...@@ -848,9 +857,11 @@ class CPPParamJsonFuncWriter(IndentWriterBase):
members = e.src_enum.members members = e.src_enum.members
else: else:
members = e.members members = e.members
for idx, i in enumerate(members): for i in members:
v = str(i)
v = v.split(' ')[0].split('=')[0]
self._write('case %s::%s::%s: return "%s";', self._write('case %s::%s::%s: return "%s";',
self._param_name, e.name, i, i, indent=0) self._param_name, e.name, v, v, indent=0)
self._write('default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));', self._write('default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));',
self._param_name, e.name, indent=0) self._param_name, e.name, indent=0)
self._write('}', indent=-1) self._write('}', indent=-1)
......
...@@ -89,7 +89,7 @@ class ConverterWriter(IndentWriterBase): ...@@ -89,7 +89,7 @@ class ConverterWriter(IndentWriterBase):
fullname = "::megdnn::param::{}".format(p.name) fullname = "::megdnn::param::{}".format(p.name)
enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name)
def format(v): def format(v):
return '\"{}\"'.format(str(v)) return '\"{}\"'.format(str(v).split(' ')[0].split('=')[0])
enum_def += ','.join(format(i) for i in e.members) enum_def += ','.join(format(i) for i in e.members)
if e.combined: if e.combined:
...@@ -110,7 +110,8 @@ class ConverterWriter(IndentWriterBase): ...@@ -110,7 +110,8 @@ class ConverterWriter(IndentWriterBase):
default_val = "static_cast<{}::{}>({})".format( default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, e.compose_combined_enum(e.default)) fullname, e.name, e.compose_combined_enum(e.default))
else: else:
default_val = "{}::{}::{}".format(fullname, e.name, e.members[e.default]) default_val = "{}::{}::{}".format(
fullname, e.name, str(e.members[e.default]).split(' ')[0].split('=')[0])
wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)
...@@ -134,7 +135,8 @@ class ConverterWriter(IndentWriterBase): ...@@ -134,7 +135,8 @@ class ConverterWriter(IndentWriterBase):
default_val = "static_cast<{}::{}>({})".format( default_val = "static_cast<{}::{}>({})".format(
fullname, e.name, s.compose_combined_enum(e.get_default())) fullname, e.name, s.compose_combined_enum(e.get_default()))
else: else:
default_val = "{}::{}::{}".format(fullname, e.name, s.members[e.get_default()]) default_val = "{}::{}::{}".format(fullname, e.name, str(
s.members[e.get_default()]).split(' ')[0].split('=')[0])
wrapped = self._wrapped_with_default_value(td_class, default_val) wrapped = self._wrapped_with_default_value(td_class, default_val)
......
此差异已折叠。
...@@ -241,14 +241,17 @@ private: ...@@ -241,14 +241,17 @@ private:
if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) { if (auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr)) {
body += formatv(" switch ({0}){{\n", "$_self." + it.name); body += formatv(" switch ({0}){{\n", "$_self." + it.name);
for (auto&& enumMember: enumAttr->getEnumMembers()) { for (auto&& enumMember: enumAttr->getEnumMembers()) {
body += formatv( size_t d1 = enumMember.find(' ');
" case {0}::{1}::{2}:\n", size_t d2 = enumMember.find('=');
getCppClassName(), enumAttr->getEnumName(), enumMember size_t d = d1 <= d2 ? d1 : d2;
); body += formatv(" case {0}::{1}::{2}:\n",
body += formatv( getCppClassName(),
" props_.emplace_back(\"{0}\", \"{1}\");\n", enumAttr->getEnumName(),
it.name, enumMember enumMember.substr(0, d));
); body +=
formatv(" props_.emplace_back(\"{0}\", "
"\"{1}\");\n",
it.name, enumMember.substr(0, d));
body += " break;\n"; body += " break;\n";
} }
body += " default: break;\n"; body += " default: break;\n";
......
...@@ -177,9 +177,13 @@ void OpDefEmitter::emit_tpl_spl() { ...@@ -177,9 +177,13 @@ void OpDefEmitter::emit_tpl_spl() {
std::vector<std::string> case_body; std::vector<std::string> case_body;
std::string ename = formatv("{0}::{1}", std::string ename = formatv("{0}::{1}",
op.getCppClassName(), attr->getEnumName()); op.getCppClassName(), attr->getEnumName());
llvm::for_each(attr->getEnumMembers(), [&](auto&& v){ llvm::for_each(attr->getEnumMembers(), [&](auto&& v) {
case_body.push_back(formatv( size_t d1 = v.find(' ');
"case {0}::{1}: return \"{1}\";", ename, v)); size_t d2 = v.find('=');
size_t d = d1 <= d2 ? d1 : d2;
case_body.push_back(
formatv("case {0}::{1}: return \"{1}\";", ename,
v.substr(0, d)));
}); });
os << formatv(R"( os << formatv(R"(
template <> template <>
......
...@@ -50,14 +50,15 @@ void OpDefEmitter::emit() { ...@@ -50,14 +50,15 @@ void OpDefEmitter::emit() {
); );
std::vector<std::string> body; std::vector<std::string> body;
for (auto&& i: attr->getEnumMembers()) { for (auto&& i: attr->getEnumMembers()) {
os << formatv( size_t d1 = i.find(' ');
"\n .value(\"{2}\", {0}::{1}::{2})", size_t d2 = i.find('=');
className, attr->getEnumName(), i size_t d = d1 <= d2 ? d1 : d2;
); os << formatv("\n .value(\"{2}\", {0}::{1}::{2})",
className, attr->getEnumName(),
i.substr(0, d));
body.push_back(formatv( body.push_back(formatv(
"if (str == \"{2}\") return {0}::{1}::{2};", "if (str == \"{2}\") return {0}::{1}::{2};",
className, attr->getEnumName(), i className, attr->getEnumName(), i.substr(0, d)));
));
} }
if (attr->getEnumCombinedFlag()) { if (attr->getEnumCombinedFlag()) {
//! define operator | //! define operator |
......
...@@ -102,7 +102,10 @@ void EnumAttrEmitter::emit_tpl_spl() { ...@@ -102,7 +102,10 @@ void EnumAttrEmitter::emit_tpl_spl() {
&ctx); &ctx);
auto quote = [&](auto&& i) -> std::string { auto quote = [&](auto&& i) -> std::string {
return formatv("\"{0}\"", i); size_t d1 = i.find(' ');
size_t d2 = i.find('=');
size_t d = d1 <= d2 ? d1 : d2;
return formatv("\"{0}\"", i.substr(0, d));
}; };
os << tgfmt(R"( os << tgfmt(R"(
template<> const char* template<> const char*
...@@ -110,7 +113,11 @@ $enumTpl<$opClass::$enumClass>::members[] = {$0}; ...@@ -110,7 +113,11 @@ $enumTpl<$opClass::$enumClass>::members[] = {$0};
)", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", ")); )", &ctx, llvm::join(llvm::map_range(attr->getEnumMembers(), quote), ", "));
auto mem2value = [&](auto&& i) -> std::string { auto mem2value = [&](auto&& i) -> std::string {
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx, i); size_t d1 = i.find(' ');
size_t d2 = i.find('=');
size_t d = d1 <= d2 ? d1 : d2;
return tgfmt("{normalize_enum(\"$0\"), $opClass::$enumClass::$0}", &ctx,
i.substr(0, d));
}; };
os << tgfmt(R"( os << tgfmt(R"(
template<> std::unordered_map<std::string, $opClass::$enumClass> template<> std::unordered_map<std::string, $opClass::$enumClass>
...@@ -192,12 +199,15 @@ os << tgfmt(R"( ...@@ -192,12 +199,15 @@ os << tgfmt(R"(
auto&& members = attr->getEnumMembers(); auto&& members = attr->getEnumMembers();
for (size_t idx = 0; idx < members.size(); ++ idx) { for (size_t idx = 0; idx < members.size(); ++ idx) {
size_t d1 = members[idx].find(' ');
size_t d2 = members[idx].find('=');
size_t d = d1 <= d2 ? d1 : d2;
os << tgfmt(R"({ os << tgfmt(R"({
PyObject* inst = e_type->tp_alloc(e_type, 0); PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0; reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0); mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst; $enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})", &ctx, members[idx], idx); })", &ctx, members[idx].substr(0, d), idx);
} }
} }
......
...@@ -136,12 +136,13 @@ class HeaderGen: ...@@ -136,12 +136,13 @@ class HeaderGen:
mode_list = [i.strip() for i in fin] mode_list = [i.strip() for i in fin]
for i in mode_list: for i in mode_list:
i = i.split(' ')[0].split('=')[0]
if i in self._elemwise_modes: if i in self._elemwise_modes:
content = '_cb({})'.format(i) content = '_cb({})'.format(i)
else: else:
content = '' content = ''
self._write_def( self._write_def(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i), content) '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i.split(' ')[0].split('=')[0]), content)
self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)', self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)') '_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')
......
...@@ -20,14 +20,14 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -20,14 +20,14 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('ExecutionPolicy', version=0, is_legacy=True). (pdef('ExecutionPolicy', version=0, is_legacy=True).
add_enum('Strategy', add_enum('Strategy',
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'), Doc('HEURISTIC = 0', 'use heuristic to choose the fastest algorithm'),
Doc('HEURISTIC_REPRODUCIBLE', 'use heuristic to choose the fastest algorithm, ' Doc('HEURISTIC_REPRODUCIBLE = 1', 'use heuristic to choose the fastest algorithm, '
'and the chosen algorithm is reproducible'), 'and the chosen algorithm is reproducible'),
Doc('PROFILE', Doc('PROFILE = 2',
'run possible algorithms on real device to find the best'), 'run possible algorithms on real device to find the best'),
Doc('PROFILE_REPRODUCIBLE', Doc('PROFILE_REPRODUCIBLE = 3',
'the fastest of profile result that is also reproducible'), 'the fastest of profile result that is also reproducible'),
Doc('PROFILE_HEURISTIC', Doc('PROFILE_HEURISTIC = 4',
'use profile result and heuristic to choose the fastest algorithm')). 'use profile result and heuristic to choose the fastest algorithm')).
add_fields('uint64', add_fields('uint64',
Doc('workspace_limit', 'workspace limit in bytes'), Doc('workspace_limit', 'workspace limit in bytes'),
...@@ -35,13 +35,13 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -35,13 +35,13 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1). (pdef('ExecutionPolicy', 'specify how to select an algorithm for an operator', version=1).
add_bit_combination_enum('Strategy', add_bit_combination_enum('Strategy',
Doc('HEURISTIC', 'use heuristic to choose the fastest algorithm'), Doc('HEURISTIC = 1 << 0', 'use heuristic to choose the fastest algorithm'),
Doc('PROFILE', Doc('PROFILE = 1 << 1',
'run possible algorithms on real device to find the best'), 'run possible algorithms on real device to find the best'),
Doc('REPRODUCIBLE', Doc('REPRODUCIBLE = 1 << 2',
'when profile or heuristic algo selection it require the algos' 'when profile or heuristic algo selection it require the algos'
'must be reproducible'), 'must be reproducible'),
Doc('OPTIMIZED', Doc('OPTIMIZED = 1 << 3',
'profile require algos are optmized to achieve fast-profile'), 'profile require algos are optmized to achieve fast-profile'),
default=('HEURISTIC',), default=('HEURISTIC',),
member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'), member_alias=[(('HEURISTIC', 'REPRODUCIBLE'), 'HEURISTIC_REPRODUCIBLE'),
...@@ -66,19 +66,19 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -66,19 +66,19 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('CollectiveComm', 'collective communication between multiple computing ' (pdef('CollectiveComm', 'collective communication between multiple computing '
'nodes on localhost') 'nodes on localhost')
.add_enum(Doc('Mode', 'mode of collective communication'), .add_enum(Doc('Mode', 'mode of collective communication'),
Doc('REDUCE_SUM', 'reduce by sum to output computing node'), Doc('REDUCE_SUM = 0', 'reduce by sum to output computing node'),
Doc('BROADCAST', 'copy input value to each output computing node'), Doc('BROADCAST = 1', 'copy input value to each output computing node'),
Doc('ALL_GATHER', 'each output comp node gets the concatenated ' Doc('ALL_GATHER = 2', 'each output comp node gets the concatenated '
'value of all inputs'), 'value of all inputs'),
Doc('REDUCE_SCATTER_SUM', Doc('REDUCE_SCATTER_SUM = 3',
'reduce inputs by sum and each output gets one part of it'), 'reduce inputs by sum and each output gets one part of it'),
Doc('ALL_REDUCE_SUM', 'every output gets the sum of all inputs'), Doc('ALL_REDUCE_SUM = 4', 'every output gets the sum of all inputs'),
Doc('ALL_REDUCE_MAX', 'every output gets the max of all inputs'), Doc('ALL_REDUCE_MAX = 5', 'every output gets the max of all inputs'),
Doc('ALL_REDUCE_MIN', 'every output gets the min of all inputs'), Doc('ALL_REDUCE_MIN = 6', 'every output gets the min of all inputs'),
Doc('ALL_REDUCE_PROD', 'every output gets the prod of all inputs'), Doc('ALL_REDUCE_PROD = 7', 'every output gets the prod of all inputs'),
Doc('GATHER', 'concat inputs to one node'), Doc('GATHER = 8', 'concat inputs to one node'),
Doc('SCATTER', 'scatter input to each output computing node'), Doc('SCATTER = 9', 'scatter input to each output computing node'),
Doc('ALL_TO_ALL', 'scatter inputs and gather them on each computing node'), Doc('ALL_TO_ALL = 10', 'scatter inputs and gather them on each computing node'),
name_field='mode')) name_field='mode'))
(pdef('FakeSerializedDType', (pdef('FakeSerializedDType',
...@@ -91,13 +91,13 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -91,13 +91,13 @@ pdef('PersistentOutputStorage').add_fields(
'evaluate a predicate and branch keys to setup ExecutionMask objects ' 'evaluate a predicate and branch keys to setup ExecutionMask objects '
'with associated predicate proxy vars (PPVs)') 'with associated predicate proxy vars (PPVs)')
.add_enum(Doc('Mode', 'how to compare predicate var with branch keys'), .add_enum(Doc('Mode', 'how to compare predicate var with branch keys'),
Doc('CASE', Doc('CASE = 0',
'The outputs correspond to branch keys, ' 'The outputs correspond to branch keys, '
'and the one which equals predicate would be activated. ' 'and the one which equals predicate would be activated. '
'This behaves like a case-statement in many languages.'), 'This behaves like a case-statement in many languages.'),
Doc('CASE_FALLBACK', 'like :attr:`CASE`, but add an extra output ' Doc('CASE_FALLBACK = 1', 'like :attr:`CASE`, but add an extra output '
'that would be activated if no branch is matched'), 'that would be activated if no branch is matched'),
Doc('PIECEWISE', 'One more outputs would be produced than the ' Doc('PIECEWISE = 2', 'One more outputs would be produced than the '
'number of branch keys, representing the interval in which the ' 'number of branch keys, representing the interval in which the '
'predicate var fits in. The intervals are defined as ' 'predicate var fits in. The intervals are defined as '
r':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, ' r':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, '
...@@ -112,20 +112,20 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -112,20 +112,20 @@ pdef('PersistentOutputStorage').add_fields(
(pdef('CondExecPredLogical', (pdef('CondExecPredLogical',
'compute a logical function over a set of PPVs') 'compute a logical function over a set of PPVs')
.add_enum('Mode', Doc('OR', 'logical or'), .add_enum('Mode', Doc('OR = 0', 'logical or'),
Doc('AND', 'logical and'), Doc('AND = 1', 'logical and'),
Doc('XOR', 'exclusive-or'), Doc('XOR = 2', 'exclusive-or'),
Doc('NOR', 'not or(inputs)'), Doc('NOR = 3', 'not or(inputs)'),
Doc('NAND', 'not and(inputs)'), Doc('NAND = 4', 'not and(inputs)'),
Doc('XNOR', 'not xor(inputs)')) Doc('XNOR = 5', 'not xor(inputs)'))
) )
(pdef('CondExecMark', (pdef('CondExecMark',
'add ExecutionMask of the input PPV to this opr and readers of the ' 'add ExecutionMask of the input PPV to this opr and readers of the '
'outputs of this opr') 'outputs of this opr')
.add_enum(Doc('GradMode', 'mode for computing the gradient'), .add_enum(Doc('GradMode', 'mode for computing the gradient'),
Doc('SUM', 'normal gradient mode: sum all the activated components'), Doc('SUM = 0', 'normal gradient mode: sum all the activated components'),
Doc('SUM_COND_OUT', 'use :attr:`CondExecMerge.SUM_COND_OUT` mode so ' Doc('SUM_COND_OUT = 1', 'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
'oprs that depend on the gradient opr would not be executed ' 'oprs that depend on the gradient opr would not be executed '
'if the forward var is not used.'), 'if the forward var is not used.'),
name_field='grad_mode') name_field='grad_mode')
...@@ -135,10 +135,10 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -135,10 +135,10 @@ pdef('PersistentOutputStorage').add_fields(
execution into account, this option can be used to bypass static execution into account, this option can be used to bypass static
inference errors. This is currently only used by automatically inference errors. This is currently only used by automatically
generated gradient oprs."""), generated gradient oprs."""),
Doc('SHAPE_VALUE', 'enable both shape and value inference'), Doc('SHAPE_VALUE = 0', 'enable both shape and value inference'),
Doc('SHAPE_ONLY', Doc('SHAPE_ONLY = 1',
'only enable shape inference (disable value inference)'), 'only enable shape inference (disable value inference)'),
Doc('NONE', 'disable both shape and value inference'), Doc('NONE = 2', 'disable both shape and value inference'),
name_field='static_infer') name_field='static_infer')
) )
...@@ -147,17 +147,17 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -147,17 +147,17 @@ pdef('PersistentOutputStorage').add_fields(
'number of output vars (i.e. vars per branch)'), 'number of output vars (i.e. vars per branch)'),
1) 1)
.add_enum('Mode', .add_enum('Mode',
Doc('EXACT_ONE', 'copy the var whose mask is activated to the output' Doc('EXACT_ONE = 0', 'copy the var whose mask is activated to the output'
', requiring that exactly one branch is active'), ', requiring that exactly one branch is active'),
Doc('EXACT_ONE_SAME_SHAPE', 'like :attr:`EXACT_ONE` with the ' Doc('EXACT_ONE_SAME_SHAPE = 1', 'like :attr:`EXACT_ONE` with the '
'requirement that all branches have the same shape, so shape ' 'requirement that all branches have the same shape, so shape '
'inference can be easier'), 'inference can be easier'),
Doc('SUM', 'sum all the active branches into output var; require ' Doc('SUM = 2', 'sum all the active branches into output var; require '
'all branches to have the same shape. Extra shape vars are ' 'all branches to have the same shape. Extra shape vars are '
'needed in this mod, so the outputs can be initialized to zero ' 'needed in this mod, so the outputs can be initialized to zero '
'when no input is active (and their shapes are probably ' 'when no input is active (and their shapes are probably '
'unknown).'), 'unknown).'),
Doc('SUM_COND_OUT', 'like :attr:`SUM` but also add an ExecutionMask' Doc('SUM_COND_OUT = 3', 'like :attr:`SUM` but also add an ExecutionMask'
' to the readers of output vars, so they would be skipped if ' ' to the readers of output vars, so they would be skipped if '
' no branch is taken') ' no branch is taken')
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册