Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fb49a283
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fb49a283
编写于
9月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mgb/dnn): refactor enum used in serializing
GitOrigin-RevId: e57af4a59c9b4e090f3972b4d0cf01a2737f8355
上级
d69b5903
变更
10
展开全部
显示空白变更内容
内联
并排
Showing
10 changed file
with
410 addition
and
370 deletion
+410
-370
dnn/scripts/gen_flatbuffers_schema.py
dnn/scripts/gen_flatbuffers_schema.py
+10
-2
dnn/scripts/gen_param_defs.py
dnn/scripts/gen_param_defs.py
+35
-24
dnn/scripts/gen_tablegen.py
dnn/scripts/gen_tablegen.py
+5
-3
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+281
-281
imperative/tablegen/helper.h
imperative/tablegen/helper.h
+11
-8
imperative/tablegen/targets/cpp_class.cpp
imperative/tablegen/targets/cpp_class.cpp
+7
-3
imperative/tablegen/targets/pybind11.cpp
imperative/tablegen/targets/pybind11.cpp
+8
-7
imperative/tablegen/targets/python_c_extension.cpp
imperative/tablegen/targets/python_c_extension.cpp
+13
-3
tools/gen_header_for_bin_reduce.py
tools/gen_header_for_bin_reduce.py
+2
-1
tools/param_defs/mgb_opr_param_defs.py
tools/param_defs/mgb_opr_param_defs.py
+38
-38
未找到文件。
dnn/scripts/gen_flatbuffers_schema.py
浏览文件 @
fb49a283
...
...
@@ -23,8 +23,14 @@ def _cname_to_fbname(cname):
}[
cname
]
def
scramble_enum_member_name
(
name
):
s
=
name
.
find
(
'<<'
)
if
s
!=
-
1
:
name
=
name
[
0
:
name
.
find
(
'='
)
+
1
]
+
' '
+
name
[
s
+
2
:]
if
name
in
(
"MIN"
,
"MAX"
):
return
name
+
"_"
o_name
=
name
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
o_name
in
(
"MIN"
,
"MAX"
):
return
name
.
replace
(
o_name
,
o_name
+
"_"
)
return
name
class
FlatBuffersWriter
(
IndentWriterBase
):
...
...
@@ -97,7 +103,8 @@ class FlatBuffersWriter(IndentWriterBase):
if
e
.
combined
:
default
=
e
.
compose_combined_enum
(
e
.
default
)
else
:
default
=
scramble_enum_member_name
(
str
(
e
.
members
[
e
.
default
]))
default
=
scramble_enum_member_name
(
str
(
e
.
members
[
e
.
default
]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_write
(
"%s:%s%s = %s;"
,
e
.
name_field
,
p
.
name
,
e
.
name
,
default
)
def
_resolve_const
(
self
,
v
):
...
...
@@ -124,7 +131,8 @@ class FlatBuffersWriter(IndentWriterBase):
if
s
.
combined
:
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
else
:
default
=
scramble_enum_member_name
(
str
(
s
.
members
[
e
.
get_default
()]))
default
=
scramble_enum_member_name
(
str
(
s
.
members
[
e
.
get_default
()]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_write
(
"%s:%s = %s;"
,
e
.
name_field
,
enum_name
,
default
)
def
_get_fb_default
(
self
,
cppdefault
):
...
...
dnn/scripts/gen_param_defs.py
浏览文件 @
fb49a283
...
...
@@ -121,10 +121,12 @@ class member_defs:
def
normalize_enum_value
(
self
,
value
):
def
normalize
(
v
):
if
isinstance
(
v
,
str
):
if
v
not
in
self
.
members
:
for
idx
,
m
in
enumerate
(
self
.
members
):
m
=
str
(
m
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
v
==
m
:
return
idx
raise
ValueError
(
"enum member '{}' does not exist."
.
format
(
v
))
v
=
self
.
members
.
index
(
v
)
assert
isinstance
(
v
,
int
)
return
v
if
self
.
combined
:
...
...
@@ -524,21 +526,25 @@ class SerializedDType(_ParamDefBase):
self
.
_write_doc
(
e
.
name
)
for
idx
,
emem
in
enumerate
(
e
.
members
)
:
for
emem
in
e
.
members
:
if
e
.
combined
:
self
.
_write
(
'%s
= 1 << %d'
,
emem
,
idx
)
self
.
_write
(
'%s
'
,
emem
)
self
.
_write_doc
(
emem
)
else
:
self
.
_write
(
'%s = "%s"'
,
emem
,
emem
)
v
=
str
(
emem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
n
=
int
(
str
(
emem
).
split
(
'='
)[
1
])
self
.
_write
(
'%s = "%s"'
,
v
,
v
)
self
.
_write_doc
(
emem
)
self
.
_enum_member2num
.
append
(
'id({}.{}):{}'
.
format
(
qualname
,
emem
,
idx
))
qualname
,
v
,
n
))
for
emem
,
emem_alias
in
e
.
member_alias
:
em_a
=
emem_alias
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
e
.
combined
:
self
.
_write
(
'%s = %s'
,
em
em_alias
,
e
.
compose_combined_enum
(
emem
))
self
.
_write
(
'%s = %s'
,
em
_a
,
e
.
compose_combined_enum
(
emem
))
else
:
self
.
_write
(
'%s = %s'
,
emem_alias
,
emem
)
em
=
str
(
emem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
self
.
_write
(
'%s = %s'
,
em_a
,
em
)
self
.
_unindent
()
self
.
_write
(
''
)
...
...
@@ -546,7 +552,7 @@ class SerializedDType(_ParamDefBase):
if
e
.
combined
:
default
=
e
.
compose_combined_enum
(
e
.
default
)
else
:
default
=
"'{}'"
.
format
(
e
.
members
[
e
.
default
])
default
=
"'{}'"
.
format
(
str
(
e
.
members
[
e
.
default
]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
name
=
e
.
name_field
,
...
...
@@ -564,7 +570,7 @@ class SerializedDType(_ParamDefBase):
if
s
.
combined
:
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
else
:
default
=
"'{}'"
.
format
(
s
.
members
[
e
.
get_default
()
])
default
=
"'{}'"
.
format
(
s
tr
(
s
.
members
[
e
.
get_default
()]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
name
=
e
.
name_field
,
cvt
=
'{}.convert({})'
.
format
(
qualname
,
e
.
name_field
),
...
...
@@ -700,11 +706,9 @@ class CPPWriter(IndentWriterBase):
def
_on_member_enum
(
self
,
e
):
self
.
_write_doc
(
e
.
name
)
self
.
_write
(
'enum class %s: uint32_t {'
,
e
.
name
,
indent
=
1
)
for
i
dx
,
i
in
enumerate
(
e
.
members
)
:
for
i
in
e
.
members
:
self
.
_write_doc
(
i
)
v
=
'{} = {}'
.
format
(
i
,
idx
)
if
e
.
combined
:
v
=
'{} = 1 << {}'
.
format
(
i
,
idx
)
v
=
str
(
i
)
if
i
is
not
e
.
members
[
-
1
]
or
e
.
member_alias
:
v
+=
','
self
.
_write
(
v
)
...
...
@@ -712,7 +716,7 @@ class CPPWriter(IndentWriterBase):
if
e
.
combined
:
self
.
_write
(
'%s = %s,'
,
alias
,
e
.
compose_combined_enum
(
mem
))
else
:
self
.
_write
(
'%s = %s,'
,
alias
,
mem
)
self
.
_write
(
'%s = %s,'
,
str
(
alias
).
split
(
' '
)[
0
].
split
(
'='
)[
0
],
str
(
mem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
)
self
.
_write
(
'};'
,
indent
=-
1
)
self
.
_non_static_members
.
append
(
e
)
self
.
_write
(
'static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;'
,
...
...
@@ -720,7 +724,9 @@ class CPPWriter(IndentWriterBase):
if
e
.
combined
:
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
else
:
default
=
'{}::{}'
.
format
(
e
.
name
,
e
.
members
[
e
.
default
])
value
=
str
(
e
.
members
[
e
.
default
])
value
=
value
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
default
=
'{}::{}'
.
format
(
e
.
name
,
value
)
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
def
_on_member_enum_alias
(
self
,
e
):
...
...
@@ -732,7 +738,9 @@ class CPPWriter(IndentWriterBase):
if
s
.
combined
:
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
s
.
compose_combined_enum
(
e
.
default
))
else
:
default
=
'{}::{}'
.
format
(
e
.
name
,
s
.
members
[
e
.
get_default
()])
value
=
str
(
s
.
members
[
e
.
get_default
()])
value
=
value
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
default
=
'{}::{}'
.
format
(
e
.
name
,
value
)
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
def
_on_member_field
(
self
,
f
):
...
...
@@ -754,11 +762,12 @@ class CPPEnumValueWriter(CPPWriter):
def
_on_member_enum
(
self
,
e
):
self
.
_write_doc
(
e
.
name
)
self
.
_write
(
'struct %s {'
,
e
.
name
,
indent
=
1
)
for
idx
,
val
in
enumerate
(
e
.
members
)
:
for
val
in
e
.
members
:
self
.
_write_doc
(
val
)
self
.
_write
(
'static const uint32_t %s = %d;'
,
val
,
idx
)
v
=
str
(
val
)
self
.
_write
(
'static const uint32_t %s;'
,
v
)
for
mem
,
alias
in
e
.
member_alias
:
self
.
_write
(
'static const uint32_t %s = %s;'
,
alias
,
mem
)
self
.
_write
(
'static const uint32_t %s = %s;'
,
str
(
alias
).
split
(
' '
)[
0
].
split
(
'='
)[
0
],
str
(
mem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
)
self
.
_write
(
'};'
,
indent
=-
1
)
def
_on_member_enum_alias
(
self
,
e
):
...
...
@@ -848,9 +857,11 @@ class CPPParamJsonFuncWriter(IndentWriterBase):
members
=
e
.
src_enum
.
members
else
:
members
=
e
.
members
for
idx
,
i
in
enumerate
(
members
):
for
i
in
members
:
v
=
str
(
i
)
v
=
v
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
self
.
_write
(
'case %s::%s::%s: return "%s";'
,
self
.
_param_name
,
e
.
name
,
i
,
i
,
indent
=
0
)
self
.
_param_name
,
e
.
name
,
v
,
v
,
indent
=
0
)
self
.
_write
(
'default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));'
,
self
.
_param_name
,
e
.
name
,
indent
=
0
)
self
.
_write
(
'}'
,
indent
=-
1
)
...
...
dnn/scripts/gen_tablegen.py
浏览文件 @
fb49a283
...
...
@@ -89,7 +89,7 @@ class ConverterWriter(IndentWriterBase):
fullname
=
"::megdnn::param::{}"
.
format
(
p
.
name
)
enum_def
=
"MgbEnumAttr<
\"
{}
\"
,
\"
{}
\"
, ["
.
format
(
fullname
,
e
.
name
)
def
format
(
v
):
return
'
\"
{}
\"
'
.
format
(
str
(
v
))
return
'
\"
{}
\"
'
.
format
(
str
(
v
)
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
)
enum_def
+=
','
.
join
(
format
(
i
)
for
i
in
e
.
members
)
if
e
.
combined
:
...
...
@@ -110,7 +110,8 @@ class ConverterWriter(IndentWriterBase):
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
else
:
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
e
.
members
[
e
.
default
])
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
str
(
e
.
members
[
e
.
default
]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
...
...
@@ -134,7 +135,8 @@ class ConverterWriter(IndentWriterBase):
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
s
.
compose_combined_enum
(
e
.
get_default
()))
else
:
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
s
.
members
[
e
.
get_default
()])
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
str
(
s
.
members
[
e
.
get_default
()]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
...
...
dnn/scripts/opr_param_defs.py
浏览文件 @
fb49a283
此差异已折叠。
点击以展开。
imperative/tablegen/helper.h
浏览文件 @
fb49a283
...
...
@@ -241,14 +241,17 @@ private:
if
(
auto
*
enumAttr
=
llvm
::
dyn_cast
<
MgbEnumAttrMixin
>
(
&
it
.
attr
))
{
body
+=
formatv
(
" switch ({0}){{
\n
"
,
"$_self."
+
it
.
name
);
for
(
auto
&&
enumMember
:
enumAttr
->
getEnumMembers
())
{
body
+=
formatv
(
" case {0}::{1}::{2}:
\n
"
,
getCppClassName
(),
enumAttr
->
getEnumName
(),
enumMember
);
body
+=
formatv
(
" props_.emplace_back(
\"
{0}
\"
,
\"
{1}
\"
);
\n
"
,
it
.
name
,
enumMember
);
size_t
d1
=
enumMember
.
find
(
' '
);
size_t
d2
=
enumMember
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
body
+=
formatv
(
" case {0}::{1}::{2}:
\n
"
,
getCppClassName
(),
enumAttr
->
getEnumName
(),
enumMember
.
substr
(
0
,
d
));
body
+=
formatv
(
" props_.emplace_back(
\"
{0}
\"
, "
"
\"
{1}
\"
);
\n
"
,
it
.
name
,
enumMember
.
substr
(
0
,
d
));
body
+=
" break;
\n
"
;
}
body
+=
" default: break;
\n
"
;
...
...
imperative/tablegen/targets/cpp_class.cpp
浏览文件 @
fb49a283
...
...
@@ -177,9 +177,13 @@ void OpDefEmitter::emit_tpl_spl() {
std
::
vector
<
std
::
string
>
case_body
;
std
::
string
ename
=
formatv
(
"{0}::{1}"
,
op
.
getCppClassName
(),
attr
->
getEnumName
());
llvm
::
for_each
(
attr
->
getEnumMembers
(),
[
&
](
auto
&&
v
){
case_body
.
push_back
(
formatv
(
"case {0}::{1}: return
\"
{1}
\"
;"
,
ename
,
v
));
llvm
::
for_each
(
attr
->
getEnumMembers
(),
[
&
](
auto
&&
v
)
{
size_t
d1
=
v
.
find
(
' '
);
size_t
d2
=
v
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
case_body
.
push_back
(
formatv
(
"case {0}::{1}: return
\"
{1}
\"
;"
,
ename
,
v
.
substr
(
0
,
d
)));
});
os
<<
formatv
(
R"(
template <>
...
...
imperative/tablegen/targets/pybind11.cpp
浏览文件 @
fb49a283
...
...
@@ -50,14 +50,15 @@ void OpDefEmitter::emit() {
);
std
::
vector
<
std
::
string
>
body
;
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
os
<<
formatv
(
"
\n
.value(
\"
{2}
\"
, {0}::{1}::{2})"
,
className
,
attr
->
getEnumName
(),
i
);
size_t
d1
=
i
.
find
(
' '
);
size_t
d2
=
i
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
os
<<
formatv
(
"
\n
.value(
\"
{2}
\"
, {0}::{1}::{2})"
,
className
,
attr
->
getEnumName
(),
i
.
substr
(
0
,
d
));
body
.
push_back
(
formatv
(
"if (str ==
\"
{2}
\"
) return {0}::{1}::{2};"
,
className
,
attr
->
getEnumName
(),
i
));
className
,
attr
->
getEnumName
(),
i
.
substr
(
0
,
d
)));
}
if
(
attr
->
getEnumCombinedFlag
())
{
//! define operator |
...
...
imperative/tablegen/targets/python_c_extension.cpp
浏览文件 @
fb49a283
...
...
@@ -102,7 +102,10 @@ void EnumAttrEmitter::emit_tpl_spl() {
&
ctx
);
auto
quote
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
formatv
(
"
\"
{0}
\"
"
,
i
);
size_t
d1
=
i
.
find
(
' '
);
size_t
d2
=
i
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
return
formatv
(
"
\"
{0}
\"
"
,
i
.
substr
(
0
,
d
));
};
os
<<
tgfmt
(
R"(
template<> const char*
...
...
@@ -110,7 +113,11 @@ $enumTpl<$opClass::$enumClass>::members[] = {$0};
)"
,
&
ctx
,
llvm
::
join
(
llvm
::
map_range
(
attr
->
getEnumMembers
(),
quote
),
", "
));
auto
mem2value
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
tgfmt
(
"{normalize_enum(
\"
$0
\"
), $opClass::$enumClass::$0}"
,
&
ctx
,
i
);
size_t
d1
=
i
.
find
(
' '
);
size_t
d2
=
i
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
return
tgfmt
(
"{normalize_enum(
\"
$0
\"
), $opClass::$enumClass::$0}"
,
&
ctx
,
i
.
substr
(
0
,
d
));
};
os
<<
tgfmt
(
R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
...
...
@@ -192,12 +199,15 @@ os << tgfmt(R"(
auto
&&
members
=
attr
->
getEnumMembers
();
for
(
size_t
idx
=
0
;
idx
<
members
.
size
();
++
idx
)
{
size_t
d1
=
members
[
idx
].
find
(
' '
);
size_t
d2
=
members
[
idx
].
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
os
<<
tgfmt
(
R"({
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})"
,
&
ctx
,
members
[
idx
],
idx
);
})"
,
&
ctx
,
members
[
idx
]
.
substr
(
0
,
d
)
,
idx
);
}
}
...
...
tools/gen_header_for_bin_reduce.py
浏览文件 @
fb49a283
...
...
@@ -136,12 +136,13 @@ class HeaderGen:
mode_list
=
[
i
.
strip
()
for
i
in
fin
]
for
i
in
mode_list
:
i
=
i
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
i
in
self
.
_elemwise_modes
:
content
=
'_cb({})'
.
format
(
i
)
else
:
content
=
''
self
.
_write_def
(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'
.
format
(
i
),
content
)
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'
.
format
(
i
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
),
content
)
self
.
_write_def
(
'MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)'
,
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)'
)
...
...
tools/param_defs/mgb_opr_param_defs.py
浏览文件 @
fb49a283
...
...
@@ -20,14 +20,14 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'ExecutionPolicy'
,
version
=
0
,
is_legacy
=
True
).
add_enum
(
'Strategy'
,
Doc
(
'HEURISTIC'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'HEURISTIC_REPRODUCIBLE'
,
'use heuristic to choose the fastest algorithm, '
Doc
(
'HEURISTIC
= 0
'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'HEURISTIC_REPRODUCIBLE
= 1
'
,
'use heuristic to choose the fastest algorithm, '
'and the chosen algorithm is reproducible'
),
Doc
(
'PROFILE'
,
Doc
(
'PROFILE
= 2
'
,
'run possible algorithms on real device to find the best'
),
Doc
(
'PROFILE_REPRODUCIBLE'
,
Doc
(
'PROFILE_REPRODUCIBLE
= 3
'
,
'the fastest of profile result that is also reproducible'
),
Doc
(
'PROFILE_HEURISTIC'
,
Doc
(
'PROFILE_HEURISTIC
= 4
'
,
'use profile result and heuristic to choose the fastest algorithm'
)).
add_fields
(
'uint64'
,
Doc
(
'workspace_limit'
,
'workspace limit in bytes'
),
...
...
@@ -35,13 +35,13 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'ExecutionPolicy'
,
'specify how to select an algorithm for an operator'
,
version
=
1
).
add_bit_combination_enum
(
'Strategy'
,
Doc
(
'HEURISTIC'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'PROFILE'
,
Doc
(
'HEURISTIC
= 1 << 0
'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'PROFILE
= 1 << 1
'
,
'run possible algorithms on real device to find the best'
),
Doc
(
'REPRODUCIBLE'
,
Doc
(
'REPRODUCIBLE
= 1 << 2
'
,
'when profile or heuristic algo selection it require the algos'
'must be reproducible'
),
Doc
(
'OPTIMIZED'
,
Doc
(
'OPTIMIZED
= 1 << 3
'
,
'profile require algos are optmized to achieve fast-profile'
),
default
=
(
'HEURISTIC'
,),
member_alias
=
[((
'HEURISTIC'
,
'REPRODUCIBLE'
),
'HEURISTIC_REPRODUCIBLE'
),
...
...
@@ -66,19 +66,19 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'CollectiveComm'
,
'collective communication between multiple computing '
'nodes on localhost'
)
.
add_enum
(
Doc
(
'Mode'
,
'mode of collective communication'
),
Doc
(
'REDUCE_SUM'
,
'reduce by sum to output computing node'
),
Doc
(
'BROADCAST'
,
'copy input value to each output computing node'
),
Doc
(
'ALL_GATHER'
,
'each output comp node gets the concatenated '
Doc
(
'REDUCE_SUM
= 0
'
,
'reduce by sum to output computing node'
),
Doc
(
'BROADCAST
= 1
'
,
'copy input value to each output computing node'
),
Doc
(
'ALL_GATHER
= 2
'
,
'each output comp node gets the concatenated '
'value of all inputs'
),
Doc
(
'REDUCE_SCATTER_SUM'
,
Doc
(
'REDUCE_SCATTER_SUM
= 3
'
,
'reduce inputs by sum and each output gets one part of it'
),
Doc
(
'ALL_REDUCE_SUM'
,
'every output gets the sum of all inputs'
),
Doc
(
'ALL_REDUCE_MAX'
,
'every output gets the max of all inputs'
),
Doc
(
'ALL_REDUCE_MIN'
,
'every output gets the min of all inputs'
),
Doc
(
'ALL_REDUCE_PROD'
,
'every output gets the prod of all inputs'
),
Doc
(
'GATHER'
,
'concat inputs to one node'
),
Doc
(
'SCATTER'
,
'scatter input to each output computing node'
),
Doc
(
'ALL_TO_ALL'
,
'scatter inputs and gather them on each computing node'
),
Doc
(
'ALL_REDUCE_SUM
= 4
'
,
'every output gets the sum of all inputs'
),
Doc
(
'ALL_REDUCE_MAX
= 5
'
,
'every output gets the max of all inputs'
),
Doc
(
'ALL_REDUCE_MIN
= 6
'
,
'every output gets the min of all inputs'
),
Doc
(
'ALL_REDUCE_PROD
= 7
'
,
'every output gets the prod of all inputs'
),
Doc
(
'GATHER
= 8
'
,
'concat inputs to one node'
),
Doc
(
'SCATTER
= 9
'
,
'scatter input to each output computing node'
),
Doc
(
'ALL_TO_ALL
= 10
'
,
'scatter inputs and gather them on each computing node'
),
name_field
=
'mode'
))
(
pdef
(
'FakeSerializedDType'
,
...
...
@@ -91,13 +91,13 @@ pdef('PersistentOutputStorage').add_fields(
'evaluate a predicate and branch keys to setup ExecutionMask objects '
'with associated predicate proxy vars (PPVs)'
)
.
add_enum
(
Doc
(
'Mode'
,
'how to compare predicate var with branch keys'
),
Doc
(
'CASE'
,
Doc
(
'CASE
= 0
'
,
'The outputs correspond to branch keys, '
'and the one which equals predicate would be activated. '
'This behaves like a case-statement in many languages.'
),
Doc
(
'CASE_FALLBACK'
,
'like :attr:`CASE`, but add an extra output '
Doc
(
'CASE_FALLBACK
= 1
'
,
'like :attr:`CASE`, but add an extra output '
'that would be activated if no branch is matched'
),
Doc
(
'PIECEWISE'
,
'One more outputs would be produced than the '
Doc
(
'PIECEWISE
= 2
'
,
'One more outputs would be produced than the '
'number of branch keys, representing the interval in which the '
'predicate var fits in. The intervals are defined as '
r
':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, '
...
...
@@ -112,20 +112,20 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'CondExecPredLogical'
,
'compute a logical function over a set of PPVs'
)
.
add_enum
(
'Mode'
,
Doc
(
'OR'
,
'logical or'
),
Doc
(
'AND'
,
'logical and'
),
Doc
(
'XOR'
,
'exclusive-or'
),
Doc
(
'NOR'
,
'not or(inputs)'
),
Doc
(
'NAND'
,
'not and(inputs)'
),
Doc
(
'XNOR'
,
'not xor(inputs)'
))
.
add_enum
(
'Mode'
,
Doc
(
'OR
= 0
'
,
'logical or'
),
Doc
(
'AND
= 1
'
,
'logical and'
),
Doc
(
'XOR
= 2
'
,
'exclusive-or'
),
Doc
(
'NOR
= 3
'
,
'not or(inputs)'
),
Doc
(
'NAND
= 4
'
,
'not and(inputs)'
),
Doc
(
'XNOR
= 5
'
,
'not xor(inputs)'
))
)
(
pdef
(
'CondExecMark'
,
'add ExecutionMask of the input PPV to this opr and readers of the '
'outputs of this opr'
)
.
add_enum
(
Doc
(
'GradMode'
,
'mode for computing the gradient'
),
Doc
(
'SUM'
,
'normal gradient mode: sum all the activated components'
),
Doc
(
'SUM_COND_OUT'
,
'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
Doc
(
'SUM
= 0
'
,
'normal gradient mode: sum all the activated components'
),
Doc
(
'SUM_COND_OUT
= 1
'
,
'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
'oprs that depend on the gradient opr would not be executed '
'if the forward var is not used.'
),
name_field
=
'grad_mode'
)
...
...
@@ -135,10 +135,10 @@ pdef('PersistentOutputStorage').add_fields(
execution into account, this option can be used to bypass static
inference errors. This is currently only used by automatically
generated gradient oprs."""
),
Doc
(
'SHAPE_VALUE'
,
'enable both shape and value inference'
),
Doc
(
'SHAPE_ONLY'
,
Doc
(
'SHAPE_VALUE
= 0
'
,
'enable both shape and value inference'
),
Doc
(
'SHAPE_ONLY
= 1
'
,
'only enable shape inference (disable value inference)'
),
Doc
(
'NONE'
,
'disable both shape and value inference'
),
Doc
(
'NONE
= 2
'
,
'disable both shape and value inference'
),
name_field
=
'static_infer'
)
)
...
...
@@ -147,17 +147,17 @@ pdef('PersistentOutputStorage').add_fields(
'number of output vars (i.e. vars per branch)'
),
1
)
.
add_enum
(
'Mode'
,
Doc
(
'EXACT_ONE'
,
'copy the var whose mask is activated to the output'
Doc
(
'EXACT_ONE
= 0
'
,
'copy the var whose mask is activated to the output'
', requiring that exactly one branch is active'
),
Doc
(
'EXACT_ONE_SAME_SHAPE'
,
'like :attr:`EXACT_ONE` with the '
Doc
(
'EXACT_ONE_SAME_SHAPE
= 1
'
,
'like :attr:`EXACT_ONE` with the '
'requirement that all branches have the same shape, so shape '
'inference can be easier'
),
Doc
(
'SUM'
,
'sum all the active branches into output var; require '
Doc
(
'SUM
= 2
'
,
'sum all the active branches into output var; require '
'all branches to have the same shape. Extra shape vars are '
'needed in this mod, so the outputs can be initialized to zero '
'when no input is active (and their shapes are probably '
'unknown).'
),
Doc
(
'SUM_COND_OUT'
,
'like :attr:`SUM` but also add an ExecutionMask'
Doc
(
'SUM_COND_OUT
= 3
'
,
'like :attr:`SUM` but also add an ExecutionMask'
' to the readers of output vars, so they would be skipped if '
' no branch is taken'
)
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录