Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fb49a283
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fb49a283
编写于
9月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mgb/dnn): refactor enum used in serializing
GitOrigin-RevId: e57af4a59c9b4e090f3972b4d0cf01a2737f8355
上级
d69b5903
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
410 addition
and
370 deletion
+410
-370
dnn/scripts/gen_flatbuffers_schema.py
dnn/scripts/gen_flatbuffers_schema.py
+10
-2
dnn/scripts/gen_param_defs.py
dnn/scripts/gen_param_defs.py
+35
-24
dnn/scripts/gen_tablegen.py
dnn/scripts/gen_tablegen.py
+5
-3
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+281
-281
imperative/tablegen/helper.h
imperative/tablegen/helper.h
+11
-8
imperative/tablegen/targets/cpp_class.cpp
imperative/tablegen/targets/cpp_class.cpp
+7
-3
imperative/tablegen/targets/pybind11.cpp
imperative/tablegen/targets/pybind11.cpp
+8
-7
imperative/tablegen/targets/python_c_extension.cpp
imperative/tablegen/targets/python_c_extension.cpp
+13
-3
tools/gen_header_for_bin_reduce.py
tools/gen_header_for_bin_reduce.py
+2
-1
tools/param_defs/mgb_opr_param_defs.py
tools/param_defs/mgb_opr_param_defs.py
+38
-38
未找到文件。
dnn/scripts/gen_flatbuffers_schema.py
浏览文件 @
fb49a283
...
...
@@ -23,8 +23,14 @@ def _cname_to_fbname(cname):
}[
cname
]
def
scramble_enum_member_name
(
name
):
s
=
name
.
find
(
'<<'
)
if
s
!=
-
1
:
name
=
name
[
0
:
name
.
find
(
'='
)
+
1
]
+
' '
+
name
[
s
+
2
:]
if
name
in
(
"MIN"
,
"MAX"
):
return
name
+
"_"
o_name
=
name
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
o_name
in
(
"MIN"
,
"MAX"
):
return
name
.
replace
(
o_name
,
o_name
+
"_"
)
return
name
class
FlatBuffersWriter
(
IndentWriterBase
):
...
...
@@ -97,7 +103,8 @@ class FlatBuffersWriter(IndentWriterBase):
if
e
.
combined
:
default
=
e
.
compose_combined_enum
(
e
.
default
)
else
:
default
=
scramble_enum_member_name
(
str
(
e
.
members
[
e
.
default
]))
default
=
scramble_enum_member_name
(
str
(
e
.
members
[
e
.
default
]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_write
(
"%s:%s%s = %s;"
,
e
.
name_field
,
p
.
name
,
e
.
name
,
default
)
def
_resolve_const
(
self
,
v
):
...
...
@@ -124,7 +131,8 @@ class FlatBuffersWriter(IndentWriterBase):
if
s
.
combined
:
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
else
:
default
=
scramble_enum_member_name
(
str
(
s
.
members
[
e
.
get_default
()]))
default
=
scramble_enum_member_name
(
str
(
s
.
members
[
e
.
get_default
()]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_write
(
"%s:%s = %s;"
,
e
.
name_field
,
enum_name
,
default
)
def
_get_fb_default
(
self
,
cppdefault
):
...
...
dnn/scripts/gen_param_defs.py
浏览文件 @
fb49a283
...
...
@@ -121,10 +121,12 @@ class member_defs:
def
normalize_enum_value
(
self
,
value
):
def
normalize
(
v
):
if
isinstance
(
v
,
str
):
if
v
not
in
self
.
members
:
raise
ValueError
(
"enum member '{}' does not exist."
.
format
(
v
))
v
=
self
.
members
.
index
(
v
)
for
idx
,
m
in
enumerate
(
self
.
members
):
m
=
str
(
m
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
v
==
m
:
return
idx
raise
ValueError
(
"enum member '{}' does not exist."
.
format
(
v
))
assert
isinstance
(
v
,
int
)
return
v
if
self
.
combined
:
...
...
@@ -524,21 +526,25 @@ class SerializedDType(_ParamDefBase):
self
.
_write_doc
(
e
.
name
)
for
idx
,
emem
in
enumerate
(
e
.
members
)
:
for
emem
in
e
.
members
:
if
e
.
combined
:
self
.
_write
(
'%s
= 1 << %d'
,
emem
,
idx
)
self
.
_write
(
'%s
'
,
emem
)
self
.
_write_doc
(
emem
)
else
:
self
.
_write
(
'%s = "%s"'
,
emem
,
emem
)
v
=
str
(
emem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
n
=
int
(
str
(
emem
).
split
(
'='
)[
1
])
self
.
_write
(
'%s = "%s"'
,
v
,
v
)
self
.
_write_doc
(
emem
)
self
.
_enum_member2num
.
append
(
'id({}.{}):{}'
.
format
(
qualname
,
emem
,
idx
))
qualname
,
v
,
n
))
for
emem
,
emem_alias
in
e
.
member_alias
:
em_a
=
emem_alias
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
e
.
combined
:
self
.
_write
(
'%s = %s'
,
em
em_alias
,
e
.
compose_combined_enum
(
emem
))
self
.
_write
(
'%s = %s'
,
em
_a
,
e
.
compose_combined_enum
(
emem
))
else
:
self
.
_write
(
'%s = %s'
,
emem_alias
,
emem
)
em
=
str
(
emem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
self
.
_write
(
'%s = %s'
,
em_a
,
em
)
self
.
_unindent
()
self
.
_write
(
''
)
...
...
@@ -546,7 +552,7 @@ class SerializedDType(_ParamDefBase):
if
e
.
combined
:
default
=
e
.
compose_combined_enum
(
e
.
default
)
else
:
default
=
"'{}'"
.
format
(
e
.
members
[
e
.
default
])
default
=
"'{}'"
.
format
(
str
(
e
.
members
[
e
.
default
]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
name
=
e
.
name_field
,
...
...
@@ -564,7 +570,7 @@ class SerializedDType(_ParamDefBase):
if
s
.
combined
:
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
else
:
default
=
"'{}'"
.
format
(
s
.
members
[
e
.
get_default
()
])
default
=
"'{}'"
.
format
(
s
tr
(
s
.
members
[
e
.
get_default
()]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
name
=
e
.
name_field
,
cvt
=
'{}.convert({})'
.
format
(
qualname
,
e
.
name_field
),
...
...
@@ -700,11 +706,9 @@ class CPPWriter(IndentWriterBase):
def
_on_member_enum
(
self
,
e
):
self
.
_write_doc
(
e
.
name
)
self
.
_write
(
'enum class %s: uint32_t {'
,
e
.
name
,
indent
=
1
)
for
i
dx
,
i
in
enumerate
(
e
.
members
)
:
for
i
in
e
.
members
:
self
.
_write_doc
(
i
)
v
=
'{} = {}'
.
format
(
i
,
idx
)
if
e
.
combined
:
v
=
'{} = 1 << {}'
.
format
(
i
,
idx
)
v
=
str
(
i
)
if
i
is
not
e
.
members
[
-
1
]
or
e
.
member_alias
:
v
+=
','
self
.
_write
(
v
)
...
...
@@ -712,7 +716,7 @@ class CPPWriter(IndentWriterBase):
if
e
.
combined
:
self
.
_write
(
'%s = %s,'
,
alias
,
e
.
compose_combined_enum
(
mem
))
else
:
self
.
_write
(
'%s = %s,'
,
alias
,
mem
)
self
.
_write
(
'%s = %s,'
,
str
(
alias
).
split
(
' '
)[
0
].
split
(
'='
)[
0
],
str
(
mem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
)
self
.
_write
(
'};'
,
indent
=-
1
)
self
.
_non_static_members
.
append
(
e
)
self
.
_write
(
'static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;'
,
...
...
@@ -720,7 +724,9 @@ class CPPWriter(IndentWriterBase):
if
e
.
combined
:
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
else
:
default
=
'{}::{}'
.
format
(
e
.
name
,
e
.
members
[
e
.
default
])
value
=
str
(
e
.
members
[
e
.
default
])
value
=
value
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
default
=
'{}::{}'
.
format
(
e
.
name
,
value
)
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
def
_on_member_enum_alias
(
self
,
e
):
...
...
@@ -732,7 +738,9 @@ class CPPWriter(IndentWriterBase):
if
s
.
combined
:
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
s
.
compose_combined_enum
(
e
.
default
))
else
:
default
=
'{}::{}'
.
format
(
e
.
name
,
s
.
members
[
e
.
get_default
()])
value
=
str
(
s
.
members
[
e
.
get_default
()])
value
=
value
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
default
=
'{}::{}'
.
format
(
e
.
name
,
value
)
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
def
_on_member_field
(
self
,
f
):
...
...
@@ -754,11 +762,12 @@ class CPPEnumValueWriter(CPPWriter):
def
_on_member_enum
(
self
,
e
):
self
.
_write_doc
(
e
.
name
)
self
.
_write
(
'struct %s {'
,
e
.
name
,
indent
=
1
)
for
idx
,
val
in
enumerate
(
e
.
members
)
:
for
val
in
e
.
members
:
self
.
_write_doc
(
val
)
self
.
_write
(
'static const uint32_t %s = %d;'
,
val
,
idx
)
v
=
str
(
val
)
self
.
_write
(
'static const uint32_t %s;'
,
v
)
for
mem
,
alias
in
e
.
member_alias
:
self
.
_write
(
'static const uint32_t %s = %s;'
,
alias
,
mem
)
self
.
_write
(
'static const uint32_t %s = %s;'
,
str
(
alias
).
split
(
' '
)[
0
].
split
(
'='
)[
0
],
str
(
mem
).
split
(
' '
)[
0
].
split
(
'='
)[
0
]
)
self
.
_write
(
'};'
,
indent
=-
1
)
def
_on_member_enum_alias
(
self
,
e
):
...
...
@@ -848,9 +857,11 @@ class CPPParamJsonFuncWriter(IndentWriterBase):
members
=
e
.
src_enum
.
members
else
:
members
=
e
.
members
for
idx
,
i
in
enumerate
(
members
):
for
i
in
members
:
v
=
str
(
i
)
v
=
v
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
self
.
_write
(
'case %s::%s::%s: return "%s";'
,
self
.
_param_name
,
e
.
name
,
i
,
i
,
indent
=
0
)
self
.
_param_name
,
e
.
name
,
v
,
v
,
indent
=
0
)
self
.
_write
(
'default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));'
,
self
.
_param_name
,
e
.
name
,
indent
=
0
)
self
.
_write
(
'}'
,
indent
=-
1
)
...
...
dnn/scripts/gen_tablegen.py
浏览文件 @
fb49a283
...
...
@@ -89,7 +89,7 @@ class ConverterWriter(IndentWriterBase):
fullname
=
"::megdnn::param::{}"
.
format
(
p
.
name
)
enum_def
=
"MgbEnumAttr<
\"
{}
\"
,
\"
{}
\"
, ["
.
format
(
fullname
,
e
.
name
)
def
format
(
v
):
return
'
\"
{}
\"
'
.
format
(
str
(
v
))
return
'
\"
{}
\"
'
.
format
(
str
(
v
)
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
)
enum_def
+=
','
.
join
(
format
(
i
)
for
i
in
e
.
members
)
if
e
.
combined
:
...
...
@@ -110,7 +110,8 @@ class ConverterWriter(IndentWriterBase):
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
else
:
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
e
.
members
[
e
.
default
])
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
str
(
e
.
members
[
e
.
default
]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
...
...
@@ -134,7 +135,8 @@ class ConverterWriter(IndentWriterBase):
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
s
.
compose_combined_enum
(
e
.
get_default
()))
else
:
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
s
.
members
[
e
.
get_default
()])
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
str
(
s
.
members
[
e
.
get_default
()]).
split
(
' '
)[
0
].
split
(
'='
)[
0
])
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
...
...
dnn/scripts/opr_param_defs.py
浏览文件 @
fb49a283
...
...
@@ -3,7 +3,7 @@ pdef('Empty')
pdef
(
'Axis'
).
add_fields
(
'int32'
,
'axis'
,
0
)
(
pdef
(
'Convolution'
,
version
=
0
,
is_legacy
=
True
).
add_enum
(
'Mode'
,
'CROSS_CORRELATION
'
,
'CONVOLUTION
'
).
add_enum
(
'Mode'
,
'CROSS_CORRELATION
= 0'
,
'CONVOLUTION = 1
'
).
add_fields
(
'uint32'
,
Doc
(
'pad_h'
,
'padding on one side on the first dimension'
),
0
,
...
...
@@ -16,41 +16,41 @@ pdef('Axis').add_fields('int32', 'axis', 0)
'on the second dimension'
),
1
).
add_enum
(
'DataType'
,
Doc
(
'FLOAT'
,
'input/output both float32/float16'
),
'INT8x8x16'
,
'INT8x8x32'
,
Doc
(
'FLOAT_IO16xC32'
,
'input/output both float16, the internal '
Doc
(
'FLOAT
= 0
'
,
'input/output both float32/float16'
),
'INT8x8x16
= 1
'
,
'INT8x8x32
= 2
'
,
Doc
(
'FLOAT_IO16xC32
= 3
'
,
'input/output both float16, the internal '
'compute is float32'
),
Doc
(
'QUINT8x8x32'
,
'input QuantizedAsymm8, output QuantizedS32'
),
Doc
(
'INT8x8xX'
,
'input int8, output specified by tensor DType'
),
Doc
(
'QUINT4x4x32'
,
'input QuantizedAsymm4, output QuantizedS32'
),
Doc
(
'QUINT8x8x32
= 4
'
,
'input QuantizedAsymm8, output QuantizedS32'
),
Doc
(
'INT8x8xX
= 5
'
,
'input int8, output specified by tensor DType'
),
Doc
(
'QUINT4x4x32
= 6
'
,
'input QuantizedAsymm4, output QuantizedS32'
),
name_field
=
'data_type'
).
add_enum
(
'Sparse'
,
Doc
(
'DENSE'
,
'dense convolution: filter shape should be '
Doc
(
'DENSE
= 0
'
,
'dense convolution: filter shape should be '
'[oc, ic, spatial...] if format is NCHW, '
'[oc, spatial..., ic] if format is NHWC'
),
Doc
(
'GROUP'
,
'group convolution: filter shape should be '
Doc
(
'GROUP
= 1
'
,
'group convolution: filter shape should be '
'[group, oc_per_group, ic_per_group, spatial...] if format is NCHW, '
'[group, oc_per_group, spatial..., ic_per_group] if format is NHWC'
)
).
add_enum
(
Doc
(
'Format'
,
'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'
),
'NCHW
'
,
'NHWC'
,
'NHWCD4'
,
'NCHW4'
,
'NCHW8'
,
'NCHW32'
,
'NCHW88
'
,
'NCHW44
'
,
'NCHW44_DOT
'
,
Doc
(
'NCHW_WINOGRAD'
,
'NCHW layout with weights tranformed by winograd'
),
Doc
(
'NCHW88_WINOGRAD'
,
'NCHW88 layout with weights tranformed by winograd'
),
Doc
(
'NCHW44_WINOGRAD'
,
'NCHW44 layout with weights tranformed by winograd'
),
Doc
(
'NCHW4_NCHW32'
,
'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'
),
Doc
(
'NCHW32_NCHW4'
,
'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'
),
Doc
(
'NCHW4_NCHW'
,
'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'
),
Doc
(
'NCHW4_NHWC
'
,
'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'
),
Doc
(
'NHWC_NCHW'
,
'NHWC_NCHW means input tensors are nhwc layout, '
'NCHW
= 0'
,
'NHWC = 1'
,
'NHWCD4 = 2'
,
'NCHW4 = 3'
,
'NCHW8 = 4'
,
'NCHW32 = 5'
,
'NCHW88 = 6
'
,
'NCHW44
= 7'
,
'NCHW44_DOT = 8
'
,
Doc
(
'NCHW_WINOGRAD
= 9
'
,
'NCHW layout with weights tranformed by winograd'
),
Doc
(
'NCHW88_WINOGRAD
= 10
'
,
'NCHW88 layout with weights tranformed by winograd'
),
Doc
(
'NCHW44_WINOGRAD
= 11
'
,
'NCHW44 layout with weights tranformed by winograd'
),
Doc
(
'NCHW4_NCHW32
= 12
'
,
'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'
),
Doc
(
'NCHW32_NCHW4
= 13
'
,
'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'
),
Doc
(
'NCHW4_NCHW
= 14
'
,
'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'
),
Doc
(
'NCHW4_NHWC
= 15'
,
'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'
),
Doc
(
'NHWC_NCHW
= 16
'
,
'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'
),
Doc
(
'NHWC_NCHW4_IC_SMALL'
,
'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
Doc
(
'NHWC_NCHW4_IC_SMALL
= 17
'
,
'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'
),
Doc
(
'NCHW_NCHW4_IC_SMALL'
,
'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
Doc
(
'NCHW_NCHW4_IC_SMALL
= 18
'
,
'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'
),
Doc
(
'CHWN4'
,
'CHWN4 is currently only used on Nvidia platform for fast implementation '
Doc
(
'CHWN4
= 19
'
,
'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'
))
)
...
...
@@ -72,9 +72,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum
(
Doc
(
'ComputeMode'
,
'Specifies special computation modes, e.g. '
'different combinations of intermediate result '
'data types.'
),
Doc
(
'DEFAULT'
,
'No special requirements on the precision of '
Doc
(
'DEFAULT
= 0
'
,
'No special requirements on the precision of '
'intermediate results.'
),
Doc
(
'FLOAT32'
,
'Use Float32 accumulator and intermediate result. '
Doc
(
'FLOAT32
= 1
'
,
'Use Float32 accumulator and intermediate result. '
'Only supported when input and output is Float16.'
),
name_field
=
'compute_mode'
)
)
...
...
@@ -95,21 +95,21 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum_alias
(
'Sparse'
,
'ConvolutionV0'
).
add_enum
(
Doc
(
'Format'
,
'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'
),
'NCHW
'
,
'NHWC'
,
'NHWCD4'
,
'NCHW4'
,
'NCHW8'
,
'NCHW32'
,
'NCHW88
'
,
'NCHW44
'
,
'NCHW44_DOT
'
,
Doc
(
'NCHW4_NCHW32'
,
'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'
),
Doc
(
'NCHW32_NCHW4'
,
'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'
),
Doc
(
'NCHW4_NCHW
'
,
'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'
),
Doc
(
'NCHW4_NHWC
'
,
'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'
),
Doc
(
'NHWC_NCHW'
,
'NHWC_NCHW means input tensors are nhwc layout, '
'NCHW
= 0'
,
'NHWC = 1'
,
'NHWCD4 = 2'
,
'NCHW4 = 3'
,
'NCHW8 = 4'
,
'NCHW32 = 5'
,
'NCHW88 = 6
'
,
'NCHW44
= 7'
,
'NCHW44_DOT = 8
'
,
Doc
(
'NCHW4_NCHW32
= 9
'
,
'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'
),
Doc
(
'NCHW32_NCHW4
= 10
'
,
'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'
),
Doc
(
'NCHW4_NCHW
= 11'
,
'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'
),
Doc
(
'NCHW4_NHWC
= 12'
,
'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'
),
Doc
(
'NHWC_NCHW
= 13
'
,
'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'
),
Doc
(
'NHWC_NCHW4_IC_SMALL'
,
'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
Doc
(
'NHWC_NCHW4_IC_SMALL
= 14
'
,
'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'
),
Doc
(
'NCHW_NCHW4_IC_SMALL'
,
'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
Doc
(
'NCHW_NCHW4_IC_SMALL
= 15
'
,
'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'
),
Doc
(
'CHWN4'
,
'CHWN4 is currently only used on Nvidia platform for fast implementation '
Doc
(
'CHWN4
= 16
'
,
'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'
),
Doc
(
'NCHW64'
,
'NCHW64 is designed for convolution implementation to utilizing TensorCore '
Doc
(
'NCHW64
= 17
'
,
'NCHW64 is designed for convolution implementation to utilizing TensorCore '
'instructions for 4-bit integers on Nvidia platforms'
)).
add_enum_alias
(
'ComputeMode'
,
'ConvolutionV1'
,
name_field
=
'compute_mode'
)
)
...
...
@@ -129,15 +129,15 @@ pdef('Axis').add_fields('int32', 'axis', 0)
)
(
pdef
(
'ConvPooling'
).
add_enum
(
'Method'
,
'WITH_TEXTURE_OBJ
'
,
'WITH_SHARED_MEM
'
).
add_enum
(
'Method'
,
'WITH_TEXTURE_OBJ
= 0'
,
'WITH_SHARED_MEM = 1
'
).
add_enum_alias
(
'ConvMode'
,
'ConvolutionV0'
,
'Mode'
).
add_enum
(
'PoolMode'
,
'AVERAGE
'
,
'MAX
'
).
add_enum
(
'NonlineMode'
,
'IDENTITY
'
,
'RELU'
,
'SIGMOID
'
).
add_enum
(
'PoolMode'
,
'AVERAGE
= 0'
,
'MAX = 1
'
).
add_enum
(
'NonlineMode'
,
'IDENTITY
= 0'
,
'RELU = 1'
,
'SIGMOID = 2
'
).
add_fields
(
'uint32'
,
'pool_shape_h'
,
1
,
'pool_shape_w'
,
1
,
'pool_stride_h'
,
1
,
'pool_stride_w'
,
1
,
\
'pool_pad_h'
,
0
,
'pool_pad_w'
,
0
,
'conv_stride_h'
,
1
,
'conv_stride_w'
,
1
,
'conv_pad_h'
,
0
,
'conv_pad_w'
,
0
))
(
pdef
(
'ConvBias'
,
'legacy conv_bias'
,
version
=
0
,
is_legacy
=
True
).
add_enum
(
'NonlineMode'
,
'IDENTITY
'
,
'RELU'
,
'SIGMOID'
,
'H_SWISH
'
).
add_enum
(
'NonlineMode'
,
'IDENTITY
= 0'
,
'RELU = 1'
,
'SIGMOID = 2'
,
'H_SWISH = 3
'
).
add_enum_alias
(
'Mode'
,
'ConvolutionV0'
).
add_fields
(
'uint32'
,
'pad_h'
,
0
,
'pad_w'
,
0
,
'stride_h'
,
1
,
'stride_w'
,
1
))
...
...
@@ -215,9 +215,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
)
(
pdef
(
'SeparableConv'
).
add_enum_alias
(
'Mode'
,
'ConvolutionV0'
).
add_enum
(
'BorderMode'
,
'BORDER_REPLICATE
'
,
'BORDER_REFLECT
'
,
'BORDER_REFLECT_101
'
,
'BORDER_WRAP
'
,
'BORDER_CONSTANT
'
,
'BORDER_TRANSPARENT'
,
'BORDER_ISOLATED
'
).
add_enum
(
'BorderMode'
,
'BORDER_REPLICATE
= 0'
,
'BORDER_REFLECT = 1
'
,
'BORDER_REFLECT_101
= 2'
,
'BORDER_WRAP = 3
'
,
'BORDER_CONSTANT
= 4'
,
'BORDER_TRANSPARENT = 5'
,
'BORDER_ISOLATED = 6
'
).
add_fields
(
'bool'
,
'is_symm_kernel'
,
'true'
).
add_fields
(
'uint32'
,
'pad_h'
,
0
,
'pad_w'
,
0
,
'stride_h'
,
1
,
'stride_w'
,
1
,
'ksize_h'
,
3
,
'ksize_w'
,
3
,
'anchor_h'
,
1
,
'anchor_w'
,
1
))
...
...
@@ -233,11 +233,11 @@ pdef('Axis').add_fields('int32', 'axis', 0)
(
pdef
(
'Pooling'
,
version
=
0
,
is_legacy
=
True
).
add_enum
(
'Mode'
,
Doc
(
'MAX'
,
'maximum value inside pooling window'
),
Doc
(
'AVERAGE'
,
Doc
(
'MAX
= 0
'
,
'maximum value inside pooling window'
),
Doc
(
'AVERAGE
= 1
'
,
'arithmetic mean of all values inside pooling window. Padding values '
'are taken into account and are viewed as zero'
),
Doc
(
'AVERAGE_COUNT_EXCLUDE_PADDING'
,
Doc
(
'AVERAGE_COUNT_EXCLUDE_PADDING
= 2
'
,
'arithmetic mean of all values inside pooling window. No padding is'
'used.'
)
).
...
...
@@ -273,15 +273,15 @@ pdef('Axis').add_fields('int32', 'axis', 0)
(
pdef
(
'BN'
).
add_enum
(
'ParamDim'
,
Doc
(
'DIM_11HW'
,
'Dim of params (Sigma, Mu) is 1 x 1 x H x W'
),
Doc
(
'DIM_1CHW'
,
'Dim of params (Sigma, Mu) is 1 x C x H x W'
),
Doc
(
'DIM_1C11'
,
'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'
),
Doc
(
'DIM_11HW
= 0
'
,
'Dim of params (Sigma, Mu) is 1 x 1 x H x W'
),
Doc
(
'DIM_1CHW
= 1
'
,
'Dim of params (Sigma, Mu) is 1 x C x H x W'
),
Doc
(
'DIM_1C11
= 2
'
,
'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'
),
name_field
=
'param_dim'
).
add_enum
(
'FwdMode'
,
Doc
(
'TRAINING'
,
'Training phase.'
),
Doc
(
'INFERENCE'
,
'Inference phase.'
),
Doc
(
'TRAINING
= 0
'
,
'Training phase.'
),
Doc
(
'INFERENCE
= 1
'
,
'Inference phase.'
),
name_field
=
'fwd_mode'
).
add_fields
(
'float64'
,
'epsilon'
,
'1e-4f'
).
...
...
@@ -293,22 +293,22 @@ pdef('Axis').add_fields('int32', 'axis', 0)
(
pdef
(
'ROIPooling'
).
add_enum
(
'Mode'
,
Doc
(
'MAX'
,
'maximum value inside pooling window; pooling result would '
Doc
(
'MAX
= 0
'
,
'maximum value inside pooling window; pooling result would '
'be 0 if pooling window is empty'
),
Doc
(
'AVERAGE'
,
Doc
(
'AVERAGE
= 1
'
,
'arithmetic mean of all values inside pooling window; pooling result '
'would be 0 if pooling window is empty'
)
).
add_fields
(
'float32'
,
'scale'
,
'1.f'
))
INTERP_MODES
=
[
'NEAREST
'
,
'LINEAR'
,
'AREA'
,
'CUBIC'
,
'LANCZOS
4'
]
BORDER_MODES
=
[
Doc
(
'REPLICATE'
,
'aaaaaa|abcdefgh|hhhhhhh'
),
Doc
(
'REFLECT'
,
'fedcba|abcdefgh|hgfedcb'
),
Doc
(
'REFLECT_101'
,
'gfedcb|abcdefgh|gfedcba'
),
Doc
(
'WRAP'
,
'cdefgh|abcdefgh|abcdefg'
),
Doc
(
'CONSTANT'
,
'iiiiii|abcdefgh|iiiiiii'
),
Doc
(
'TRANSPARENT'
,
''
),
Doc
(
'ISOLATED'
,
''
)]
INTERP_MODES
=
[
'NEAREST
= 0'
,
'LINEAR = 1'
,
'AREA = 2'
,
'CUBIC = 3'
,
'LANCZOS4 =
4'
]
BORDER_MODES
=
[
Doc
(
'REPLICATE
= 0
'
,
'aaaaaa|abcdefgh|hhhhhhh'
),
Doc
(
'REFLECT
= 1
'
,
'fedcba|abcdefgh|hgfedcb'
),
Doc
(
'REFLECT_101
= 2
'
,
'gfedcb|abcdefgh|gfedcba'
),
Doc
(
'WRAP
= 3
'
,
'cdefgh|abcdefgh|abcdefg'
),
Doc
(
'CONSTANT
= 4
'
,
'iiiiii|abcdefgh|iiiiiii'
),
Doc
(
'TRANSPARENT
= 5
'
,
''
),
Doc
(
'ISOLATED
= 6
'
,
''
)]
(
pdef
(
'WarpPerspective'
,
version
=
1
,
is_legacy
=
True
).
add_enum
(
'InterpolationMode'
,
*
INTERP_MODES
,
name_field
=
'imode'
,
default
=
1
,
...
...
@@ -328,181 +328,181 @@ BORDER_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'),
add_fields
(
'float32'
,
Doc
(
'border_val'
,
'used for CONSTANT bmode'
),
'.0f'
))
pdef
(
'SpatialTfGridGenerator'
).
add_enum
(
'Mode'
,
'AFFINE'
)
pdef
(
'SpatialTfSampler'
).
add_enum
(
'Mode'
,
'BILINEAR'
)
pdef
(
'SpatialTfGridGenerator'
).
add_enum
(
'Mode'
,
'AFFINE
= 0
'
)
pdef
(
'SpatialTfSampler'
).
add_enum
(
'Mode'
,
'BILINEAR
= 0
'
)
pdef
(
'AddUpdate'
).
add_fields
(
'float32'
,
'alpha'
,
'1.f'
,
'beta'
,
'1.f'
,
'bias'
,
'0.f'
)
pdef
(
'Elemwise'
).
add_enum
(
'Mode'
,
Doc
(
'RELU'
,
'unary: max(x, 0)'
),
Doc
(
'ABS'
,
'unary: abs(x)'
),
Doc
(
'ACOS'
,
'unary: acos(x)'
),
Doc
(
'ASIN'
,
'unary: asin(x)'
),
Doc
(
'CEIL'
,
'unary: ceil(x)'
),
Doc
(
'COS'
,
'unary: cos(x)'
),
Doc
(
'EXP'
,
'unary: exp(x)'
),
Doc
(
'EXPM1'
,
'unary: numerically stable exp(x)-1'
),
Doc
(
'FLOOR'
,
'unary: floor(x)'
),
Doc
(
'LOG'
,
'unary: natural logarithm, log(x)'
),
Doc
(
'LOG1P'
,
'unary: numerically stable log(x+1)'
),
Doc
(
'NEGATE'
,
'unary: -x'
),
Doc
(
'SIGMOID'
,
'unary: 1/(1+exp(-x))'
),
Doc
(
'SIN'
,
'unary: sin(x)'
),
Doc
(
'TANH'
,
'unary: tanh(x)'
),
Doc
(
'ABS_GRAD'
,
'binary: x > 0 ? y : -y'
),
Doc
(
'ADD'
,
'binary: x + y'
),
Doc
(
'FLOOR_DIV'
,
'binary: floor(x / y)'
),
Doc
(
'MAX'
,
'binary: max(x, y)'
),
Doc
(
'MIN'
,
'binary: min(x, y)'
),
Doc
(
'MOD'
,
'binary: x % y or fmodf(x, y)'
),
Doc
(
'MUL'
,
'binary: x * y'
),
Doc
(
'POW'
,
'binary: pow(x, y)'
),
Doc
(
'SIGMOID_GRAD'
,
'binary: x * (1 - x) * y'
),
Doc
(
'SUB'
,
'binary: x - y'
),
Doc
(
'SWITCH_GT0'
,
'binary: (x > 0) * y'
),
Doc
(
'TANH_GRAD'
,
'binary: (1 - x * x) * y'
),
Doc
(
'TRUE_DIV'
,
'binary: x / y'
),
Doc
(
'LOG_SUM_EXP'
,
'binary: numerically stable log(exp(x) + exp(y))'
),
Doc
(
'LT'
,
'binary: x < y'
),
Doc
(
'LEQ'
,
'binary: x <= y'
),
Doc
(
'EQ'
,
'binary: x == y'
),
Doc
(
'SHL'
,
'bitwise binary: x << y. '
Doc
(
'RELU
= 0
'
,
'unary: max(x, 0)'
),
Doc
(
'ABS
= 1
'
,
'unary: abs(x)'
),
Doc
(
'ACOS
= 2
'
,
'unary: acos(x)'
),
Doc
(
'ASIN
= 3
'
,
'unary: asin(x)'
),
Doc
(
'CEIL
= 4
'
,
'unary: ceil(x)'
),
Doc
(
'COS
= 5
'
,
'unary: cos(x)'
),
Doc
(
'EXP
= 6
'
,
'unary: exp(x)'
),
Doc
(
'EXPM1
= 7
'
,
'unary: numerically stable exp(x)-1'
),
Doc
(
'FLOOR
= 8
'
,
'unary: floor(x)'
),
Doc
(
'LOG
= 9
'
,
'unary: natural logarithm, log(x)'
),
Doc
(
'LOG1P
= 10
'
,
'unary: numerically stable log(x+1)'
),
Doc
(
'NEGATE
= 11
'
,
'unary: -x'
),
Doc
(
'SIGMOID
= 12
'
,
'unary: 1/(1+exp(-x))'
),
Doc
(
'SIN
= 13
'
,
'unary: sin(x)'
),
Doc
(
'TANH
= 14
'
,
'unary: tanh(x)'
),
Doc
(
'ABS_GRAD
= 15
'
,
'binary: x > 0 ? y : -y'
),
Doc
(
'ADD
= 16
'
,
'binary: x + y'
),
Doc
(
'FLOOR_DIV
= 17
'
,
'binary: floor(x / y)'
),
Doc
(
'MAX
= 18
'
,
'binary: max(x, y)'
),
Doc
(
'MIN
= 19
'
,
'binary: min(x, y)'
),
Doc
(
'MOD
= 20
'
,
'binary: x % y or fmodf(x, y)'
),
Doc
(
'MUL
= 21
'
,
'binary: x * y'
),
Doc
(
'POW
= 22
'
,
'binary: pow(x, y)'
),
Doc
(
'SIGMOID_GRAD
= 23
'
,
'binary: x * (1 - x) * y'
),
Doc
(
'SUB
= 24
'
,
'binary: x - y'
),
Doc
(
'SWITCH_GT0
= 25
'
,
'binary: (x > 0) * y'
),
Doc
(
'TANH_GRAD
= 26
'
,
'binary: (1 - x * x) * y'
),
Doc
(
'TRUE_DIV
= 27
'
,
'binary: x / y'
),
Doc
(
'LOG_SUM_EXP
= 28
'
,
'binary: numerically stable log(exp(x) + exp(y))'
),
Doc
(
'LT
= 29
'
,
'binary: x < y'
),
Doc
(
'LEQ
= 30
'
,
'binary: x <= y'
),
Doc
(
'EQ
= 31
'
,
'binary: x == y'
),
Doc
(
'SHL
= 32
'
,
'bitwise binary: x << y. '
'Note that result is undefined if y < 0 or y >= bitwidth. Logical '
'shift is performed for unsigned intergers, and arithmetic shift for '
'signed ones.'
),
Doc
(
'SHR'
,
'bitwise binary: x >> y; see SHL mode for more details'
),
Doc
(
'SHR
= 33
'
,
'bitwise binary: x >> y; see SHL mode for more details'
),
Doc
(
'COND_LEQ_MOV'
,
'ternary: x <= y ? z : 0'
),
Doc
(
'FUSE_MUL_ADD3'
,
Doc
(
'COND_LEQ_MOV
= 34
'
,
'ternary: x <= y ? z : 0'
),
Doc
(
'FUSE_MUL_ADD3
= 35
'
,
'compute ``a * b + c`` where c must either have same layout as '
'a or b, or be a scalar'
),
Doc
(
'FUSE_MUL_ADD4'
,
Doc
(
'FUSE_MUL_ADD4
= 36
'
,
'compute ``a * A + b * B`` where a and b must have equal layout, '
'and A and B must have equal layout. In the inputs ``b`` and ``B`` '
'can be swapped'
),
Doc
(
'FUSE_ADD_RELU'
,
'binary: max(x+y, 0)'
),
Doc
(
'FUSE_ADD_SIGMOID'
,
'binary: 1/(1+exp(-(x+y)))'
),
Doc
(
'FUSE_ADD_TANH'
,
'binary: tanh(x+y)'
),
Doc
(
'FAST_TANH'
,
'unary: rational approximation of tanh(x)'
),
Doc
(
'FAST_TANH_GRAD'
,
'binary: grad of the rational approximation of tanh(x)'
),
Doc
(
'FUSE_ADD_RELU
= 37
'
,
'binary: max(x+y, 0)'
),
Doc
(
'FUSE_ADD_SIGMOID
= 38
'
,
'binary: 1/(1+exp(-(x+y)))'
),
Doc
(
'FUSE_ADD_TANH
= 39
'
,
'binary: tanh(x+y)'
),
Doc
(
'FAST_TANH
= 40
'
,
'unary: rational approximation of tanh(x)'
),
Doc
(
'FAST_TANH_GRAD
= 41
'
,
'binary: grad of the rational approximation of tanh(x)'
),
Doc
(
'ROUND'
,
'unary: round(x), the nearest integer value to x, rounding '
Doc
(
'ROUND
= 42
'
,
'unary: round(x), the nearest integer value to x, rounding '
'halfway cases away from zero. Float only.'
),
Doc
(
'RMULH'
,
'binary: rounded higher l bits of x * y, where l is the bit '
Doc
(
'RMULH
= 43
'
,
'binary: rounded higher l bits of x * y, where l is the bit '
'length of x.'
),
Doc
(
'ATAN2'
,
'binary: atan2(y,x)'
),
Doc
(
'ERF'
,
'unary: erf(x)'
),
Doc
(
'ERFINV'
,
'unary: inverse function of erf(x)'
),
Doc
(
'ERFC'
,
'unary: erfc(x)'
),
Doc
(
'ERFCINV'
,
'unary: inverse function of erfc(x)'
),
Doc
(
'H_SWISH'
,
'unary: x * clip(x + 3, 0, 6) / 6'
),
Doc
(
'H_SWISH_GRAD'
,
'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'
),
Doc
(
'FUSE_ADD_H_SWISH'
,
'binary: hswish(x+y)'
),
Doc
(
'NOT'
,
'unary: !x'
),
Doc
(
'AND'
,
'binary: x && y'
),
Doc
(
'OR'
,
'binary: x || y'
),
Doc
(
'XOR'
,
'binary: x ^ y'
),
Doc
(
'SILU'
,
'unary: x / (1 + exp(-x))'
),
Doc
(
'SILU_GRAD'
,
'binary: grad(x / (1 + exp(-x))'
),
Doc
(
'GELU'
,
'unary: x Phi(x)'
),
Doc
(
'GELU_GRAD'
,
'binary: grad(x Phi(x))'
),
Doc
(
'ATAN2
= 44
'
,
'binary: atan2(y,x)'
),
Doc
(
'ERF
= 45
'
,
'unary: erf(x)'
),
Doc
(
'ERFINV
= 46
'
,
'unary: inverse function of erf(x)'
),
Doc
(
'ERFC
= 47
'
,
'unary: erfc(x)'
),
Doc
(
'ERFCINV
= 48
'
,
'unary: inverse function of erfc(x)'
),
Doc
(
'H_SWISH
= 49
'
,
'unary: x * clip(x + 3, 0, 6) / 6'
),
Doc
(
'H_SWISH_GRAD
= 50
'
,
'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'
),
Doc
(
'FUSE_ADD_H_SWISH
= 51
'
,
'binary: hswish(x+y)'
),
Doc
(
'NOT
= 52
'
,
'unary: !x'
),
Doc
(
'AND
= 53
'
,
'binary: x && y'
),
Doc
(
'OR
= 54
'
,
'binary: x || y'
),
Doc
(
'XOR
= 55
'
,
'binary: x ^ y'
),
Doc
(
'SILU
= 56
'
,
'unary: x / (1 + exp(-x))'
),
Doc
(
'SILU_GRAD
= 57
'
,
'binary: grad(x / (1 + exp(-x))'
),
Doc
(
'GELU
= 58
'
,
'unary: x Phi(x)'
),
Doc
(
'GELU_GRAD
= 59
'
,
'binary: grad(x Phi(x))'
),
)
pdef
(
'ElemwiseMultiType'
).
add_enum
(
'Mode'
,
Doc
(
'FUSE_MUL_ADD3_INT16x32x32x32'
,
Doc
(
'FUSE_MUL_ADD3_INT16x32x32x32
= 0
'
,
'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and '
'``c`` int32, and the result is int32. This mode is optimized for '
'the channel-broadacsted case, i.e. ``a`` has shape (A, B, C) and '
'``b`` and ``c`` have shape (1, C, 1)'
),
Doc
(
'FUSE_MUL_ADD3_IXxF32xF32xI8'
,
Doc
(
'FUSE_MUL_ADD3_IXxF32xF32xI8
= 1
'
,
'compuate ``a * b + c`` where the inputs ``a`` is an integer type '
'``b`` and ``c`` are both ``float32``, the result is '
'``int8``. This is currently only optimized for ``(1, x)`` '
'broadcast for ``b`` and ``c``. Computation is carried in floating '
'points and results are rounded towards zero with saturated cast to '
'int.'
),
Doc
(
'ROUND_SHR_SATURATE_IXxI8xI8'
,
Doc
(
'ROUND_SHR_SATURATE_IXxI8xI8
= 2
'
,
'Compute ``a >> b``, round the result according to lower ``b`` bits '
'of ``a``` and make a saturating conversion to int8. Where ``a`` should'
' be an integer tensor and ``b`` should be an int8 scalar.'
),
Doc
(
'FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8'
,
Doc
(
'FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8
= 3
'
,
'Fused operation of an int16 elemwise add, an int16 rounding multiply '
'high and an int16 to int8 rounding right shift with saturation.'
),
Doc
(
'FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8'
,
Doc
(
'FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8
= 4
'
,
'Fused operation of an int32 elemwise add, an int32 rounding multiply '
'high and an int32 to int8 rounding right shift with saturation.'
),
Doc
(
'ROUND_SHR_SATURATE_IXxI8xI16'
,
Doc
(
'ROUND_SHR_SATURATE_IXxI8xI16
= 5
'
,
'Compute ``a >> b``, round the result according to lower ``b`` bits of '
'``a``` and make a saturating conversion to int16. Where ``a`` should'
' be an integer tensor and ``b`` should be an int8 scalar.'
),
Doc
(
'QADD'
,
'Fused elemwise add two quantized int8 with specified'
Doc
(
'QADD
= 6
'
,
'Fused elemwise add two quantized int8 with specified'
'output quantized dtype'
),
Doc
(
'QFUSE_ADD_RELU'
,
'Fused elemwise add two quantized int8 followed'
Doc
(
'QFUSE_ADD_RELU
= 7
'
,
'Fused elemwise add two quantized int8 followed'
' by ReLU and typecvt to specified dtype'
),
Doc
(
'QMUL'
,
'Fused elemwise multiply two quantized int8 with specified'
Doc
(
'QMUL
= 8
'
,
'Fused elemwise multiply two quantized int8 with specified'
'output quantized dtype'
),
Doc
(
'QMIN'
,
'Fused elemwise min two quantized int8 with specified'
Doc
(
'QMIN
= 9
'
,
'Fused elemwise min two quantized int8 with specified'
'output quantized dtype'
),
Doc
(
'QMAX'
,
'quantized: max(x, y), with specified output quantized dtype'
),
Doc
(
'QSUB'
,
'quantized: x - y'
),
Doc
(
'QTRUE_DIV'
,
'quantized: x / y'
),
Doc
(
'QFUSE_ADD_SIGMOID'
,
'quantized: sigmoid(x + y)'
),
Doc
(
'QFUSE_ADD_TANH'
,
'quantized: tanh(x + y)'
),
Doc
(
'QRELU'
,
'quantized: x > 0 ? x : 0'
),
Doc
(
'QABS'
,
'quantized: x > 0 ? x : -x'
),
Doc
(
'QSIGMOID'
,
'quantized: sigmoid(x)'
),
Doc
(
'QEXP'
,
'quantized: exp(x)'
),
Doc
(
'QTANH'
,
'quantized: tanh(x)'
),
Doc
(
'QFUSE_MUL_ADD3'
,
'quantized: x * y + z'
),
Doc
(
'QFAST_TANH'
,
'quantized: fast_tanh(x)'
),
Doc
(
'QNEGATE'
,
'quantized: -x'
),
Doc
(
'QACOS'
,
'quantized: acos(x)'
),
Doc
(
'QASIN'
,
'quantized: asin(x)'
),
Doc
(
'QCEIL'
,
'quantized: ceil(x)'
),
Doc
(
'QCOS'
,
'quantized: cos(x)'
),
Doc
(
'QEXPM1'
,
'quantized: expm1(x)'
),
Doc
(
'QFLOOR'
,
'quantized: floor(x)'
),
Doc
(
'QLOG'
,
'quantized: log(x)'
),
Doc
(
'QLOG1P'
,
'quantized: log1p(x)'
),
Doc
(
'QSIN'
,
'quantized: sin(x)'
),
Doc
(
'QROUND'
,
'quantized: round(x)'
),
Doc
(
'QERF'
,
'quantized: erf(x)'
),
Doc
(
'QERFINV'
,
'quantized: erfinv(x)'
),
Doc
(
'QERFC'
,
'quantized: erfc(x)'
),
Doc
(
'QERFCINV'
,
'quantized: erfcinv(x)'
),
Doc
(
'QABS_GRAD'
,
'quantized: abs_grad'
),
Doc
(
'QFLOOR_DIV'
,
'quantized floor_div'
),
Doc
(
'QMOD'
,
'quantized mod'
),
Doc
(
'QSIGMOID_GRAD'
,
'quantized sigmoid_grad'
),
Doc
(
'QSWITCH_GT0'
,
'quantized switch_gt0'
),
Doc
(
'QTANH_GRAD'
,
'quantized tanh_grad'
),
Doc
(
'QLT'
,
'quantized lt'
),
Doc
(
'QLEQ'
,
'quantized leq'
),
Doc
(
'QEQ'
,
'quantized eq'
),
Doc
(
'QPOW'
,
'quantized pow'
),
Doc
(
'QLOG_SUM_EXP'
,
'quantized log_sum_exp'
),
Doc
(
'QFAST_TANH_GRAD'
,
'quantized fast_tanh_grad'
),
Doc
(
'QATAN2'
,
'quantized atan2'
),
Doc
(
'QCOND_LEQ_MOV'
,
'quantized cond_leq_mov'
),
Doc
(
'QH_SWISH'
,
'quantized h_swish'
),
Doc
(
'QFUSE_ADD_H_SWISH'
,
'quantized h_swish(x+y)'
),
Doc
(
'QH_SWISH_GRAD'
,
'quantized h_swish_grad'
)
Doc
(
'QMAX
= 10
'
,
'quantized: max(x, y), with specified output quantized dtype'
),
Doc
(
'QSUB
= 11
'
,
'quantized: x - y'
),
Doc
(
'QTRUE_DIV
= 12
'
,
'quantized: x / y'
),
Doc
(
'QFUSE_ADD_SIGMOID
= 13
'
,
'quantized: sigmoid(x + y)'
),
Doc
(
'QFUSE_ADD_TANH
= 14
'
,
'quantized: tanh(x + y)'
),
Doc
(
'QRELU
= 15
'
,
'quantized: x > 0 ? x : 0'
),
Doc
(
'QABS
= 16
'
,
'quantized: x > 0 ? x : -x'
),
Doc
(
'QSIGMOID
= 17
'
,
'quantized: sigmoid(x)'
),
Doc
(
'QEXP
= 18
'
,
'quantized: exp(x)'
),
Doc
(
'QTANH
= 19
'
,
'quantized: tanh(x)'
),
Doc
(
'QFUSE_MUL_ADD3
= 20
'
,
'quantized: x * y + z'
),
Doc
(
'QFAST_TANH
= 21
'
,
'quantized: fast_tanh(x)'
),
Doc
(
'QNEGATE
= 22
'
,
'quantized: -x'
),
Doc
(
'QACOS
= 23
'
,
'quantized: acos(x)'
),
Doc
(
'QASIN
= 24
'
,
'quantized: asin(x)'
),
Doc
(
'QCEIL
= 25
'
,
'quantized: ceil(x)'
),
Doc
(
'QCOS
= 26
'
,
'quantized: cos(x)'
),
Doc
(
'QEXPM1
= 27
'
,
'quantized: expm1(x)'
),
Doc
(
'QFLOOR
= 28
'
,
'quantized: floor(x)'
),
Doc
(
'QLOG
= 29
'
,
'quantized: log(x)'
),
Doc
(
'QLOG1P
= 30
'
,
'quantized: log1p(x)'
),
Doc
(
'QSIN
= 31
'
,
'quantized: sin(x)'
),
Doc
(
'QROUND
= 32
'
,
'quantized: round(x)'
),
Doc
(
'QERF
= 33
'
,
'quantized: erf(x)'
),
Doc
(
'QERFINV
= 34
'
,
'quantized: erfinv(x)'
),
Doc
(
'QERFC
= 35
'
,
'quantized: erfc(x)'
),
Doc
(
'QERFCINV
= 36
'
,
'quantized: erfcinv(x)'
),
Doc
(
'QABS_GRAD
= 37
'
,
'quantized: abs_grad'
),
Doc
(
'QFLOOR_DIV
= 38
'
,
'quantized floor_div'
),
Doc
(
'QMOD
= 39
'
,
'quantized mod'
),
Doc
(
'QSIGMOID_GRAD
= 40
'
,
'quantized sigmoid_grad'
),
Doc
(
'QSWITCH_GT0
= 41
'
,
'quantized switch_gt0'
),
Doc
(
'QTANH_GRAD
= 42
'
,
'quantized tanh_grad'
),
Doc
(
'QLT
= 43
'
,
'quantized lt'
),
Doc
(
'QLEQ
= 44
'
,
'quantized leq'
),
Doc
(
'QEQ
= 45
'
,
'quantized eq'
),
Doc
(
'QPOW
= 46
'
,
'quantized pow'
),
Doc
(
'QLOG_SUM_EXP
= 47
'
,
'quantized log_sum_exp'
),
Doc
(
'QFAST_TANH_GRAD
= 48
'
,
'quantized fast_tanh_grad'
),
Doc
(
'QATAN2
= 49
'
,
'quantized atan2'
),
Doc
(
'QCOND_LEQ_MOV
= 50
'
,
'quantized cond_leq_mov'
),
Doc
(
'QH_SWISH
= 51
'
,
'quantized h_swish'
),
Doc
(
'QFUSE_ADD_H_SWISH
= 52
'
,
'quantized h_swish(x+y)'
),
Doc
(
'QH_SWISH_GRAD
= 53
'
,
'quantized h_swish_grad'
)
)
pdef
(
'PowC'
,
'power with constant exponent'
).
add_fields
(
'float32'
,
'exp'
,
0
)
(
pdef
(
'DctChannelSelect'
,
'2d discrete cosine transform'
,
version
=
0
,
is_legacy
=
True
).
add_enum_alias
(
'Format'
,
'ConvolutionV0'
).
add_enum
(
'FastImpl'
,
'NONE
'
,
'FIX_32_MASK
'
).
add_fields
(
'int32'
,
'dct_block_size'
,
8
))
add_enum
(
'FastImpl'
,
'NONE
= 0'
,
'FIX_32_MASK = 1
'
).
add_fields
(
'int32'
,
'dct_block_size'
,
8
))
(
pdef
(
'DctChannelSelect'
,
'2d discrete cosine transform'
,
version
=
1
).
add_enum_alias
(
'Format'
,
'Convolution'
).
add_enum_alias
(
'FastImpl'
,
'DctChannelSelectV0'
).
add_fields
(
'int32'
,
'dct_block_size'
,
8
))
...
...
@@ -510,13 +510,13 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
(
pdef
(
'MatrixMul'
,
version
=
0
,
is_legacy
=
True
).
add_fields
(
'bool'
,
'transposeA'
,
'false'
,
'transposeB'
,
'false'
).
add_enum
(
'DataType'
,
Doc
(
'FLOAT'
,
'input/output both float32/float16'
),
'INT8x8x16'
,
'INT8x8x32'
,
Doc
(
'FLOAT_IO16xC32'
,
'input/output both float16, the internal compute is '
Doc
(
'FLOAT
= 0
'
,
'input/output both float32/float16'
),
'INT8x8x16
= 1
'
,
'INT8x8x32
= 2
'
,
Doc
(
'FLOAT_IO16xC32
= 3
'
,
'input/output both float16, the internal compute is '
'float32'
),
Doc
(
'QUINT8x8x32'
,
'input QuantizedAsymm8, output QuantizedS32'
),
Doc
(
'QUINT4x4x32'
,
'input QuantizedAsymm4, output QuantizedS32'
),
Doc
(
'QUINT8x8x32
= 4
'
,
'input QuantizedAsymm8, output QuantizedS32'
),
Doc
(
'QUINT4x4x32
= 5
'
,
'input QuantizedAsymm4, output QuantizedS32'
),
name_field
=
'data_type'
))
(
pdef
(
'MatrixMul'
,
version
=
1
,
is_legacy
=
True
).
...
...
@@ -524,9 +524,9 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
add_enum
(
Doc
(
'ComputeMode'
,
'Specifies special computation modes, e.g. '
'different combinations of intermediate result '
'data types.'
),
Doc
(
'DEFAULT'
,
'No special requirements on the precision of '
Doc
(
'DEFAULT
= 0
'
,
'No special requirements on the precision of '
'intermediate results.'
),
Doc
(
'FLOAT32'
,
'Use Float32 accumulator and intermediate result. '
Doc
(
'FLOAT32
= 1
'
,
'Use Float32 accumulator and intermediate result. '
'Only supported when input and output is Float16.'
),
name_field
=
'compute_mode'
))
...
...
@@ -534,14 +534,14 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
add_fields
(
'bool'
,
'transposeA'
,
'false'
,
'transposeB'
,
'false'
).
add_enum_alias
(
'ComputeMode'
,
'MatrixMulV1'
,
name_field
=
'compute_mode'
).
add_enum
(
'Format'
,
Doc
(
'DEFAULT'
,
'Normal matrix mul: (M, K) x (K, N) = (M, N)'
),
Doc
(
'MK4'
,
'Split 4 from M and K, better for neon compute:'
Doc
(
'DEFAULT
= 0
'
,
'Normal matrix mul: (M, K) x (K, N) = (M, N)'
),
Doc
(
'MK4
= 1
'
,
'Split 4 from M and K, better for neon compute:'
'(M/4, K/4, 4(k), 4(m)) x (K/4, N, 4(k)). if transposeA the '
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'
),
Doc
(
'MK8'
,
'Split 8 from M and K, better for neon compute:'
Doc
(
'MK8
= 2
'
,
'Split 8 from M and K, better for neon compute:'
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the '
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'
),
Doc
(
'MK4_DOT'
,
'Split 4 from M and K, better for neon dotprod:'
Doc
(
'MK4_DOT
= 3
'
,
'Split 4 from M and K, better for neon dotprod:'
'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the '
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'
))
)
...
...
@@ -560,9 +560,9 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
(
pdef
(
'Reduce'
,
'legacy reduce'
,
version
=
0
,
is_legacy
=
True
).
add_enum
(
'Mode'
,
'SUM'
,
Doc
(
'SUM_SQR'
,
'sum of x * x for each element x'
),
'PRODUCT
'
,
'MIN'
,
'MAX
'
).
'SUM
= 0
'
,
Doc
(
'SUM_SQR
= 1
'
,
'sum of x * x for each element x'
),
'PRODUCT
= 2'
,
'MIN = 3'
,
'MAX = 4
'
).
add_fields
(
'int32'
,
Doc
(
'axis'
,
'axis along which reduction is performed; if -1 is given, '
...
...
@@ -571,16 +571,16 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
(
pdef
(
'Reduce'
,
'reduce along given axis'
,
version
=
1
,
is_legacy
=
True
).
add_enum
(
'Mode'
,
'SUM'
,
Doc
(
'SUM_SQR'
,
'sum of x * x for each element x'
),
'PRODUCT
'
,
'MIN'
,
'MAX'
,
'MEAN
'
).
'SUM
= 0
'
,
Doc
(
'SUM_SQR
= 1
'
,
'sum of x * x for each element x'
),
'PRODUCT
= 2'
,
'MIN = 3'
,
'MAX = 4'
,
'MEAN = 5
'
).
add_fields
(
'int32'
,
Doc
(
'axis'
,
'axis along which reduction is performed; if -1 is given, '
'reduce to given target shape (only used in megbrain)'
),
-
1
).
add_enum
(
'DataType'
,
Doc
(
'DEFAULT'
,
Doc
(
'DEFAULT
= 0
'
,
'''
input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode.
Currently, ```DEFAULT``` mode means:
...
...
@@ -607,26 +607,26 @@ Currently, ```DEFAULT``` mode means:
'''
),
Doc
(
'FLOAT_IO16xC32'
,
'Deprecated. This was replaced by '
Doc
(
'FLOAT_IO16xC32
= 1
'
,
'Deprecated. This was replaced by '
'FLOAT_O16xC32, and input
\'
s dtype decided by actual input tensor.'
),
Doc
(
'FLOAT_O32xC32'
,
'compute/output both are float32'
),
Doc
(
'FLOAT_O16xC32'
,
'compute are float32, output float16'
),
Doc
(
'QUINT_I8xO32'
,
'input quint8, compute and output are qint32'
),
Doc
(
'QINT_I8xO32'
,
'input qint8, compute and output are qint32'
),
Doc
(
'FLOAT_O32xC32
= 2
'
,
'compute/output both are float32'
),
Doc
(
'FLOAT_O16xC32
= 3
'
,
'compute are float32, output float16'
),
Doc
(
'QUINT_I8xO32
= 4
'
,
'input quint8, compute and output are qint32'
),
Doc
(
'QINT_I8xO32
= 5
'
,
'input qint8, compute and output are qint32'
),
name_field
=
'data_type'
))
(
pdef
(
'Reduce'
,
'reduce along given axis'
,
version
=
2
).
add_enum
(
'Mode'
,
'SUM'
,
Doc
(
'SUM_SQR'
,
'sum of x * x for each element x'
),
'PRODUCT
'
,
'MIN'
,
'MAX'
,
'MEAN
'
).
'SUM
= 0
'
,
Doc
(
'SUM_SQR
= 1
'
,
'sum of x * x for each element x'
),
'PRODUCT
= 2'
,
'MIN = 3'
,
'MAX = 4'
,
'MEAN = 5
'
).
add_fields
(
'int32'
,
Doc
(
'axis'
,
'axis along which reduction is performed; if INT_MAX is given, '
'reduce to given target shape (only used in megbrain)'
),
(
1
<<
31
)
-
1
).
add_enum
(
'DataType'
,
Doc
(
'DEFAULT'
,
Doc
(
'DEFAULT
= 0
'
,
'''
input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode.
Currently, ```DEFAULT``` mode means:
...
...
@@ -653,12 +653,12 @@ Currently, ```DEFAULT``` mode means:
'''
),
Doc
(
'FLOAT_IO16xC32'
,
'Deprecated. This was replaced by '
Doc
(
'FLOAT_IO16xC32
= 1
'
,
'Deprecated. This was replaced by '
'FLOAT_O16xC32, and input
\'
s dtype decided by actual input tensor.'
),
Doc
(
'FLOAT_O32xC32'
,
'compute/output both are float32'
),
Doc
(
'FLOAT_O16xC32'
,
'compute are float32, output float16'
),
Doc
(
'QUINT_I8xO32'
,
'input quint8, compute and output are qint32'
),
Doc
(
'QINT_I8xO32'
,
'input qint8, compute and output are qint32'
),
Doc
(
'FLOAT_O32xC32
= 2
'
,
'compute/output both are float32'
),
Doc
(
'FLOAT_O16xC32
= 3
'
,
'compute are float32, output float16'
),
Doc
(
'QUINT_I8xO32
= 4
'
,
'input quint8, compute and output are qint32'
),
Doc
(
'QINT_I8xO32
= 5
'
,
'input qint8, compute and output are qint32'
),
name_field
=
'data_type'
))
(
pdef
(
'Cumsum'
,
'calculate accumulated sum along given axis'
,
version
=
0
,
is_legacy
=
True
).
...
...
@@ -691,12 +691,12 @@ Currently, ```DEFAULT``` mode means:
(
pdef
(
'CondTake'
).
add_enum
(
'Mode'
,
Doc
(
'EQ'
,
'take if ``abs(data-val)<eps``'
),
Doc
(
'NEQ'
,
'take if ``abs(data-val)>=eps``'
),
Doc
(
'LT'
,
'take if ``data<val``'
),
Doc
(
'LEQ'
,
'take if ``data<=val``'
),
Doc
(
'GT'
,
'take if ``data>val``'
),
Doc
(
'GEQ'
,
'take if ``data>=val``'
)).
Doc
(
'EQ
= 0
'
,
'take if ``abs(data-val)<eps``'
),
Doc
(
'NEQ
= 1
'
,
'take if ``abs(data-val)>=eps``'
),
Doc
(
'LT
= 2
'
,
'take if ``data<val``'
),
Doc
(
'LEQ
= 3
'
,
'take if ``data<=val``'
),
Doc
(
'GT
= 4
'
,
'take if ``data>val``'
),
Doc
(
'GEQ
= 5
'
,
'take if ``data>=val``'
)).
add_fields
(
'float32'
,
Doc
(
'val'
,
'the value to be compared with; note that for integer '
'data, val is also converted to int'
),
0
).
...
...
@@ -704,7 +704,7 @@ Currently, ```DEFAULT``` mode means:
1e-6
))
pdef
(
'Argsort'
).
add_enum
(
'Order'
,
'ASCENDING
'
,
'DESCENDING
'
)
pdef
(
'Argsort'
).
add_enum
(
'Order'
,
'ASCENDING
= 0'
,
'DESCENDING = 1
'
)
(
pdef
(
'IndexingRemap'
).
add_fields
(
'bool'
,
...
...
@@ -791,17 +791,17 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
.
add_fields
(
'uint32'
,
'row_from'
,
0
,
'row_to'
,
0
,
'col_from'
,
0
,
'col_to'
,
0
))
(
pdef
(
'CvtColor'
)
.
add_enum
(
'Mode'
,
'RGB2GRAY
'
,
'RGB2YUV'
,
'YUV2RGB'
,
'GRAY2RGB'
,
'RGBA2RGB
'
,
'RGBA2BGR
'
,
'RGBA2GRAY'
,
'RGB2BGR'
,
'BGR2GRAY'
,
'BGR2RGB
'
,
Doc
(
'YUV2GRAY_NV21'
,
'For historical reasons, referred to as YCC by opencv'
),
'YUV2RGB_NV21
'
,
'YUV2BGR_NV21'
,
'YUV2GRAY_NV12'
,
'YUV2RGB_NV12
'
,
'YUV2BGR_NV12
'
,
'YUV2GRAY_YV12'
,
'YUV2RGB_YV12'
,
'YUV2BGR_YV12
'
,
'YUV2GRAY_YU12
'
,
'YUV2RGB_YU12'
,
'YUV2BGR_YU12
'
,
'YCrCb2RGB
'
,
'YCrCb2BGR
'
,
Doc
(
'BT601_YUV2RGB_NV21'
,
'BT601 yuv format, referred to as YUV by opencv'
),
'BT601_YUV2BGR_NV21
'
,
'BT601_YUV2RGB_NV12'
,
'BT601_YUV2BGR_NV12
'
,
'BT601_YUV2RGB_YV12
'
,
'BT601_YUV2BGR_YV12'
,
'BT601_YUV2RGB_YU12
'
,
'BT601_YUV2BGR_YU12'
,
.
add_enum
(
'Mode'
,
'RGB2GRAY
= 0'
,
'RGB2YUV = 1'
,
'YUV2RGB = 2'
,
'GRAY2RGB = 3'
,
'RGBA2RGB = 4
'
,
'RGBA2BGR
= 5'
,
'RGBA2GRAY = 6'
,
'RGB2BGR = 7'
,
'BGR2GRAY = 8'
,
'BGR2RGB = 9
'
,
Doc
(
'YUV2GRAY_NV21
= 10
'
,
'For historical reasons, referred to as YCC by opencv'
),
'YUV2RGB_NV21
= 11'
,
'YUV2BGR_NV21 = 12'
,
'YUV2GRAY_NV12 = 13'
,
'YUV2RGB_NV12 = 14
'
,
'YUV2BGR_NV12
= 15'
,
'YUV2GRAY_YV12 = 16'
,
'YUV2RGB_YV12 = 17'
,
'YUV2BGR_YV12 = 18
'
,
'YUV2GRAY_YU12
= 19'
,
'YUV2RGB_YU12 = 20'
,
'YUV2BGR_YU12 = 21
'
,
'YCrCb2RGB
= 22'
,
'YCrCb2BGR = 23
'
,
Doc
(
'BT601_YUV2RGB_NV21
= 24
'
,
'BT601 yuv format, referred to as YUV by opencv'
),
'BT601_YUV2BGR_NV21
= 25'
,
'BT601_YUV2RGB_NV12 = 26'
,
'BT601_YUV2BGR_NV12 = 27
'
,
'BT601_YUV2RGB_YV12
= 28'
,
'BT601_YUV2BGR_YV12 = 29'
,
'BT601_YUV2RGB_YU12 = 30
'
,
'BT601_YUV2BGR_YU12
= 31
'
,
member_alias
=
[(
'YUV2GRAY_NV21'
,
'BT601_YUV2GRAY_NV21'
),
(
'YUV2GRAY_NV12'
,
'BT601_YUV2GRAY_NV12'
),
(
'YUV2GRAY_YV12'
,
'BT601_YUV2GRAY_YV12'
),
...
...
@@ -855,7 +855,7 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
.
add_fields
(
'float32'
,
'scalar'
,
'0.f'
))
(
pdef
(
'Convolution3D'
).
add_enum
(
'Mode'
,
'CROSS_CORRELATION
'
,
'CONVOLUTION
'
).
add_enum
(
'Mode'
,
'CROSS_CORRELATION
= 0'
,
'CONVOLUTION = 1
'
).
add_fields
(
'uint32'
,
Doc
(
'pad_d'
,
'padding on one side on the first dimension'
),
0
,
...
...
@@ -872,32 +872,32 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'on the third dimension'
),
1
).
add_enum
(
'Sparse'
,
Doc
(
'DENSE'
,
'dense convolution: filter shape should be '
Doc
(
'DENSE
= 0
'
,
'dense convolution: filter shape should be '
'[oc, ic, spatial...] if format is NCDHW, '
'[oc, spatial..., ic] if format is NDHWC'
),
Doc
(
'GROUP'
,
'group convolution: filter shape should be '
Doc
(
'GROUP
= 1
'
,
'group convolution: filter shape should be '
'[group, oc_per_group, ic_per_group, spatial...] if format is NCDHW, '
'[group, oc_per_group, spatial..., ic_per_group] if format is NDHWC'
)
).
add_enum
(
'DataType'
,
Doc
(
'FLOAT'
,
'input/output both float32/float16'
),
Doc
(
'FLOAT_IO16xC32'
,
'input/output both float16, the internal '
Doc
(
'FLOAT
= 0
'
,
'input/output both float32/float16'
),
Doc
(
'FLOAT_IO16xC32
= 1
'
,
'input/output both float16, the internal '
'compute is float32'
),
name_field
=
'data_type'
).
add_enum
(
'Format'
,
'NCDHW
'
,
'NDHWC
'
)
add_enum
(
'Format'
,
'NCDHW
= 0'
,
'NDHWC = 1
'
)
)
(
pdef
(
'Conv3DBias'
).
add_enum
(
'NonlineMode'
,
'IDENTITY
'
,
'RELU'
,
'SIGMOID
'
).
add_enum
(
'NonlineMode'
,
'IDENTITY
= 0'
,
'RELU = 1'
,
'SIGMOID = 2
'
).
add_enum_alias
(
'Mode'
,
'Convolution3D'
).
add_fields
(
'uint32'
,
'pad_d'
,
0
,
'pad_h'
,
0
,
'pad_w'
,
0
,
'stride_d'
,
1
,
'stride_h'
,
1
,
'stride_w'
,
0
))
(
pdef
(
'SeparableConv3D'
).
add_enum_alias
(
'Mode'
,
'Convolution3D'
).
add_enum
(
'BorderMode'
,
'BORDER_REPLICATE
'
,
'BORDER_REFLECT
'
,
'BORDER_REFLECT_101
'
,
'BORDER_WRAP
'
,
'BORDER_CONSTANT
'
,
'BORDER_TRANSPARENT'
,
'BORDER_ISOLATED
'
).
add_enum
(
'BorderMode'
,
'BORDER_REPLICATE
= 0'
,
'BORDER_REFLECT = 1
'
,
'BORDER_REFLECT_101
= 2'
,
'BORDER_WRAP = 3
'
,
'BORDER_CONSTANT
= 4'
,
'BORDER_TRANSPARENT = 5'
,
'BORDER_ISOLATED = 6
'
).
add_fields
(
'bool'
,
'is_symm_kernel'
,
'true'
).
add_fields
(
'uint32'
,
'pad_d'
,
0
,
'pad_h'
,
0
,
'pad_w'
,
0
,
'stride_d'
,
0
,
'stride_h'
,
1
,
'stride_w'
,
1
,
...
...
@@ -907,11 +907,11 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
(
pdef
(
'TopK'
).
add_enum
(
'Mode'
,
Doc
(
'KTH_ONLY'
,
"only the value of the k'th element would be computed"
),
Doc
(
'VALUE_IDX_NOSORT'
,
Doc
(
'KTH_ONLY
= 0
'
,
"only the value of the k'th element would be computed"
),
Doc
(
'VALUE_IDX_NOSORT
= 1
'
,
'all the top-k values and corresponding indices would be computed; '
'no order is guaranteed'
),
Doc
(
'VALUE_IDX_SORTED'
,
Doc
(
'VALUE_IDX_SORTED
= 2
'
,
'all the top-k values and corresponding indices sorted'
))
)
...
...
@@ -983,37 +983,37 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
(
pdef
(
'RelayoutFormat'
,
'Change the tensor layout format'
,
version
=
0
,
is_legacy
=
True
).
add_enum
(
Doc
(
'Mode'
,
RELAYOUT_FORMAT_MODE_DOC
),
'NHWC_NHWCD4'
,
'NHWCD4_NHWC'
,
'NHWC_NHWCD4I'
,
'NCHW_NHWCD4'
,
'NCHW_NHWCD4I'
,
'NHWCD4I_NCHW'
,
'NHWCD4_NCHW'
,
'INTER_WEIGHT_DENSE'
,
'INTER_WEIGHT_DENSEI'
,
'INTER_WEIGHT_GROUP'
,
'INTER_WEIGHT_GROUPI'
,
'INTER_WEIGHT_CHAN'
,
'INTER_WEIGHT_CHANI'
,
'INTER_WEIGHT_DENSEI_DOT'
,
'INTER_WEIGHT_GROUPI_DOT'
,
'NCHW4_CHWN4'
,
'CHWN4_NCHW4'
,
'NCHW_NCHW88_CONV_DENSE_WEIGHT'
,
'NCHW_NCHW88_CONV_CHAN_WEIGHT'
,
'NCHW_NCHW88_CONV_GROUP_WEIGHT'
,
'NCHW_NCHW88'
,
'NCHW88_NCHW'
,
'NCHW_NCHW4_IC_SMALL'
,
'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT'
,
'NCHW_NCHW4'
,
'NCHW4_NCHW'
,
'NCHW_NCHW4_WEIGHT'
,
'NCHW_NCHW64'
,
'NCHW64_NCHW'
,
'NCHW_NHWC'
,
'NHWC_NCHW
'
,
'NHWC_NHWCD4
= 0
'
,
'NHWCD4_NHWC
= 1
'
,
'NHWC_NHWCD4I
= 2
'
,
'NCHW_NHWCD4
= 3
'
,
'NCHW_NHWCD4I
= 4
'
,
'NHWCD4I_NCHW
= 5
'
,
'NHWCD4_NCHW
= 6
'
,
'INTER_WEIGHT_DENSE
= 7
'
,
'INTER_WEIGHT_DENSEI
= 8
'
,
'INTER_WEIGHT_GROUP
= 9
'
,
'INTER_WEIGHT_GROUPI
= 10
'
,
'INTER_WEIGHT_CHAN
= 11
'
,
'INTER_WEIGHT_CHANI
= 12
'
,
'INTER_WEIGHT_DENSEI_DOT
= 13
'
,
'INTER_WEIGHT_GROUPI_DOT
= 14
'
,
'NCHW4_CHWN4
= 15
'
,
'CHWN4_NCHW4
= 16
'
,
'NCHW_NCHW88_CONV_DENSE_WEIGHT
= 17
'
,
'NCHW_NCHW88_CONV_CHAN_WEIGHT
= 18
'
,
'NCHW_NCHW88_CONV_GROUP_WEIGHT
= 19
'
,
'NCHW_NCHW88
= 20
'
,
'NCHW88_NCHW
= 21
'
,
'NCHW_NCHW4_IC_SMALL
= 22
'
,
'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT
= 23
'
,
'NCHW_NCHW4
= 24
'
,
'NCHW4_NCHW
= 25
'
,
'NCHW_NCHW4_WEIGHT
= 26
'
,
'NCHW_NCHW64
= 27
'
,
'NCHW64_NCHW
= 28
'
,
'NCHW_NHWC
= 29
'
,
'NHWC_NCHW
= 30'
,
)
)
...
...
@@ -1077,7 +1077,7 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
(
pdef
(
'ROIAlign'
,
version
=
0
,
is_legacy
=
True
).
add_enum
(
'Mode'
,
'MAX
'
,
'AVERAGE
'
,
name_field
=
'mode'
).
add_enum
(
'Mode'
,
'MAX
= 0'
,
'AVERAGE = 1
'
,
name_field
=
'mode'
).
add_enum_alias
(
'Format'
,
'ConvolutionV0'
).
add_fields
(
'float32'
,
'spatial_scale'
,
'1.0'
).
add_fields
(
'float32'
,
'offset'
,
'0.0'
).
...
...
@@ -1173,9 +1173,9 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
pdef
(
'Fill'
).
add_fields
(
'float32'
,
'value'
,
'0'
)
PADDING_MODES
=
[
Doc
(
'REPLICATE'
,
'aaaaaa|abcdefgh|hhhhhhh'
),
Doc
(
'REFLECT'
,
'fedcba|abcdefgh|hgfedcb'
),
Doc
(
'CONSTANT'
,
'iiiiii|abcdefgh|iiiiiii'
)]
PADDING_MODES
=
[
Doc
(
'REPLICATE
= 0
'
,
'aaaaaa|abcdefgh|hhhhhhh'
),
Doc
(
'REFLECT
= 1
'
,
'fedcba|abcdefgh|hgfedcb'
),
Doc
(
'CONSTANT
= 2
'
,
'iiiiii|abcdefgh|iiiiiii'
)]
(
pdef
(
'Padding'
).
add_fields
(
'uint32'
,
Doc
(
'front_offset_dim0'
,
'offset in dim 0'
),
0
).
add_fields
(
'uint32'
,
Doc
(
'front_offset_dim1'
,
'offset in dim 1'
),
0
).
...
...
imperative/tablegen/helper.h
浏览文件 @
fb49a283
...
...
@@ -241,14 +241,17 @@ private:
if
(
auto
*
enumAttr
=
llvm
::
dyn_cast
<
MgbEnumAttrMixin
>
(
&
it
.
attr
))
{
body
+=
formatv
(
" switch ({0}){{
\n
"
,
"$_self."
+
it
.
name
);
for
(
auto
&&
enumMember
:
enumAttr
->
getEnumMembers
())
{
body
+=
formatv
(
" case {0}::{1}::{2}:
\n
"
,
getCppClassName
(),
enumAttr
->
getEnumName
(),
enumMember
);
body
+=
formatv
(
" props_.emplace_back(
\"
{0}
\"
,
\"
{1}
\"
);
\n
"
,
it
.
name
,
enumMember
);
size_t
d1
=
enumMember
.
find
(
' '
);
size_t
d2
=
enumMember
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
body
+=
formatv
(
" case {0}::{1}::{2}:
\n
"
,
getCppClassName
(),
enumAttr
->
getEnumName
(),
enumMember
.
substr
(
0
,
d
));
body
+=
formatv
(
" props_.emplace_back(
\"
{0}
\"
, "
"
\"
{1}
\"
);
\n
"
,
it
.
name
,
enumMember
.
substr
(
0
,
d
));
body
+=
" break;
\n
"
;
}
body
+=
" default: break;
\n
"
;
...
...
imperative/tablegen/targets/cpp_class.cpp
浏览文件 @
fb49a283
...
...
@@ -177,9 +177,13 @@ void OpDefEmitter::emit_tpl_spl() {
std
::
vector
<
std
::
string
>
case_body
;
std
::
string
ename
=
formatv
(
"{0}::{1}"
,
op
.
getCppClassName
(),
attr
->
getEnumName
());
llvm
::
for_each
(
attr
->
getEnumMembers
(),
[
&
](
auto
&&
v
){
case_body
.
push_back
(
formatv
(
"case {0}::{1}: return
\"
{1}
\"
;"
,
ename
,
v
));
llvm
::
for_each
(
attr
->
getEnumMembers
(),
[
&
](
auto
&&
v
)
{
size_t
d1
=
v
.
find
(
' '
);
size_t
d2
=
v
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
case_body
.
push_back
(
formatv
(
"case {0}::{1}: return
\"
{1}
\"
;"
,
ename
,
v
.
substr
(
0
,
d
)));
});
os
<<
formatv
(
R"(
template <>
...
...
imperative/tablegen/targets/pybind11.cpp
浏览文件 @
fb49a283
...
...
@@ -50,14 +50,15 @@ void OpDefEmitter::emit() {
);
std
::
vector
<
std
::
string
>
body
;
for
(
auto
&&
i
:
attr
->
getEnumMembers
())
{
os
<<
formatv
(
"
\n
.value(
\"
{2}
\"
, {0}::{1}::{2})"
,
className
,
attr
->
getEnumName
(),
i
);
size_t
d1
=
i
.
find
(
' '
);
size_t
d2
=
i
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
os
<<
formatv
(
"
\n
.value(
\"
{2}
\"
, {0}::{1}::{2})"
,
className
,
attr
->
getEnumName
(),
i
.
substr
(
0
,
d
));
body
.
push_back
(
formatv
(
"if (str ==
\"
{2}
\"
) return {0}::{1}::{2};"
,
className
,
attr
->
getEnumName
(),
i
));
"if (str ==
\"
{2}
\"
) return {0}::{1}::{2};"
,
className
,
attr
->
getEnumName
(),
i
.
substr
(
0
,
d
)));
}
if
(
attr
->
getEnumCombinedFlag
())
{
//! define operator |
...
...
imperative/tablegen/targets/python_c_extension.cpp
浏览文件 @
fb49a283
...
...
@@ -102,7 +102,10 @@ void EnumAttrEmitter::emit_tpl_spl() {
&
ctx
);
auto
quote
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
formatv
(
"
\"
{0}
\"
"
,
i
);
size_t
d1
=
i
.
find
(
' '
);
size_t
d2
=
i
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
return
formatv
(
"
\"
{0}
\"
"
,
i
.
substr
(
0
,
d
));
};
os
<<
tgfmt
(
R"(
template<> const char*
...
...
@@ -110,7 +113,11 @@ $enumTpl<$opClass::$enumClass>::members[] = {$0};
)"
,
&
ctx
,
llvm
::
join
(
llvm
::
map_range
(
attr
->
getEnumMembers
(),
quote
),
", "
));
auto
mem2value
=
[
&
](
auto
&&
i
)
->
std
::
string
{
return
tgfmt
(
"{normalize_enum(
\"
$0
\"
), $opClass::$enumClass::$0}"
,
&
ctx
,
i
);
size_t
d1
=
i
.
find
(
' '
);
size_t
d2
=
i
.
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
return
tgfmt
(
"{normalize_enum(
\"
$0
\"
), $opClass::$enumClass::$0}"
,
&
ctx
,
i
.
substr
(
0
,
d
));
};
os
<<
tgfmt
(
R"(
template<> std::unordered_map<std::string, $opClass::$enumClass>
...
...
@@ -192,12 +199,15 @@ os << tgfmt(R"(
auto
&&
members
=
attr
->
getEnumMembers
();
for
(
size_t
idx
=
0
;
idx
<
members
.
size
();
++
idx
)
{
size_t
d1
=
members
[
idx
].
find
(
' '
);
size_t
d2
=
members
[
idx
].
find
(
'='
);
size_t
d
=
d1
<=
d2
?
d1
:
d2
;
os
<<
tgfmt
(
R"({
PyObject* inst = e_type->tp_alloc(e_type, 0);
reinterpret_cast<$enumTpl<$opClass::$enumClass>*>(inst)->value = $opClass::$enumClass::$0;
mgb_assert(PyDict_SetItemString(e_type->tp_dict, "$0", inst) >= 0);
$enumTpl<$opClass::$enumClass>::pyobj_insts[$1] = inst;
})"
,
&
ctx
,
members
[
idx
],
idx
);
})"
,
&
ctx
,
members
[
idx
]
.
substr
(
0
,
d
)
,
idx
);
}
}
...
...
tools/gen_header_for_bin_reduce.py
浏览文件 @
fb49a283
...
...
@@ -136,12 +136,13 @@ class HeaderGen:
mode_list
=
[
i
.
strip
()
for
i
in
fin
]
for
i
in
mode_list
:
i
=
i
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
if
i
in
self
.
_elemwise_modes
:
content
=
'_cb({})'
.
format
(
i
)
else
:
content
=
''
self
.
_write_def
(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'
.
format
(
i
),
content
)
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'
.
format
(
i
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
),
content
)
self
.
_write_def
(
'MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)'
,
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)'
)
...
...
tools/param_defs/mgb_opr_param_defs.py
浏览文件 @
fb49a283
...
...
@@ -20,14 +20,14 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'ExecutionPolicy'
,
version
=
0
,
is_legacy
=
True
).
add_enum
(
'Strategy'
,
Doc
(
'HEURISTIC'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'HEURISTIC_REPRODUCIBLE'
,
'use heuristic to choose the fastest algorithm, '
Doc
(
'HEURISTIC
= 0
'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'HEURISTIC_REPRODUCIBLE
= 1
'
,
'use heuristic to choose the fastest algorithm, '
'and the chosen algorithm is reproducible'
),
Doc
(
'PROFILE'
,
Doc
(
'PROFILE
= 2
'
,
'run possible algorithms on real device to find the best'
),
Doc
(
'PROFILE_REPRODUCIBLE'
,
Doc
(
'PROFILE_REPRODUCIBLE
= 3
'
,
'the fastest of profile result that is also reproducible'
),
Doc
(
'PROFILE_HEURISTIC'
,
Doc
(
'PROFILE_HEURISTIC
= 4
'
,
'use profile result and heuristic to choose the fastest algorithm'
)).
add_fields
(
'uint64'
,
Doc
(
'workspace_limit'
,
'workspace limit in bytes'
),
...
...
@@ -35,13 +35,13 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'ExecutionPolicy'
,
'specify how to select an algorithm for an operator'
,
version
=
1
).
add_bit_combination_enum
(
'Strategy'
,
Doc
(
'HEURISTIC'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'PROFILE'
,
Doc
(
'HEURISTIC
= 1 << 0
'
,
'use heuristic to choose the fastest algorithm'
),
Doc
(
'PROFILE
= 1 << 1
'
,
'run possible algorithms on real device to find the best'
),
Doc
(
'REPRODUCIBLE'
,
Doc
(
'REPRODUCIBLE
= 1 << 2
'
,
'when profile or heuristic algo selection it require the algos'
'must be reproducible'
),
Doc
(
'OPTIMIZED'
,
Doc
(
'OPTIMIZED
= 1 << 3
'
,
'profile require algos are optmized to achieve fast-profile'
),
default
=
(
'HEURISTIC'
,),
member_alias
=
[((
'HEURISTIC'
,
'REPRODUCIBLE'
),
'HEURISTIC_REPRODUCIBLE'
),
...
...
@@ -66,19 +66,19 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'CollectiveComm'
,
'collective communication between multiple computing '
'nodes on localhost'
)
.
add_enum
(
Doc
(
'Mode'
,
'mode of collective communication'
),
Doc
(
'REDUCE_SUM'
,
'reduce by sum to output computing node'
),
Doc
(
'BROADCAST'
,
'copy input value to each output computing node'
),
Doc
(
'ALL_GATHER'
,
'each output comp node gets the concatenated '
Doc
(
'REDUCE_SUM
= 0
'
,
'reduce by sum to output computing node'
),
Doc
(
'BROADCAST
= 1
'
,
'copy input value to each output computing node'
),
Doc
(
'ALL_GATHER
= 2
'
,
'each output comp node gets the concatenated '
'value of all inputs'
),
Doc
(
'REDUCE_SCATTER_SUM'
,
Doc
(
'REDUCE_SCATTER_SUM
= 3
'
,
'reduce inputs by sum and each output gets one part of it'
),
Doc
(
'ALL_REDUCE_SUM'
,
'every output gets the sum of all inputs'
),
Doc
(
'ALL_REDUCE_MAX'
,
'every output gets the max of all inputs'
),
Doc
(
'ALL_REDUCE_MIN'
,
'every output gets the min of all inputs'
),
Doc
(
'ALL_REDUCE_PROD'
,
'every output gets the prod of all inputs'
),
Doc
(
'GATHER'
,
'concat inputs to one node'
),
Doc
(
'SCATTER'
,
'scatter input to each output computing node'
),
Doc
(
'ALL_TO_ALL'
,
'scatter inputs and gather them on each computing node'
),
Doc
(
'ALL_REDUCE_SUM
= 4
'
,
'every output gets the sum of all inputs'
),
Doc
(
'ALL_REDUCE_MAX
= 5
'
,
'every output gets the max of all inputs'
),
Doc
(
'ALL_REDUCE_MIN
= 6
'
,
'every output gets the min of all inputs'
),
Doc
(
'ALL_REDUCE_PROD
= 7
'
,
'every output gets the prod of all inputs'
),
Doc
(
'GATHER
= 8
'
,
'concat inputs to one node'
),
Doc
(
'SCATTER
= 9
'
,
'scatter input to each output computing node'
),
Doc
(
'ALL_TO_ALL
= 10
'
,
'scatter inputs and gather them on each computing node'
),
name_field
=
'mode'
))
(
pdef
(
'FakeSerializedDType'
,
...
...
@@ -91,13 +91,13 @@ pdef('PersistentOutputStorage').add_fields(
'evaluate a predicate and branch keys to setup ExecutionMask objects '
'with associated predicate proxy vars (PPVs)'
)
.
add_enum
(
Doc
(
'Mode'
,
'how to compare predicate var with branch keys'
),
Doc
(
'CASE'
,
Doc
(
'CASE
= 0
'
,
'The outputs correspond to branch keys, '
'and the one which equals predicate would be activated. '
'This behaves like a case-statement in many languages.'
),
Doc
(
'CASE_FALLBACK'
,
'like :attr:`CASE`, but add an extra output '
Doc
(
'CASE_FALLBACK
= 1
'
,
'like :attr:`CASE`, but add an extra output '
'that would be activated if no branch is matched'
),
Doc
(
'PIECEWISE'
,
'One more outputs would be produced than the '
Doc
(
'PIECEWISE
= 2
'
,
'One more outputs would be produced than the '
'number of branch keys, representing the interval in which the '
'predicate var fits in. The intervals are defined as '
r
':math:`(-\\infty, k_0), [k_0, k_1), \\ldots, '
...
...
@@ -112,20 +112,20 @@ pdef('PersistentOutputStorage').add_fields(
(
pdef
(
'CondExecPredLogical'
,
'compute a logical function over a set of PPVs'
)
.
add_enum
(
'Mode'
,
Doc
(
'OR'
,
'logical or'
),
Doc
(
'AND'
,
'logical and'
),
Doc
(
'XOR'
,
'exclusive-or'
),
Doc
(
'NOR'
,
'not or(inputs)'
),
Doc
(
'NAND'
,
'not and(inputs)'
),
Doc
(
'XNOR'
,
'not xor(inputs)'
))
.
add_enum
(
'Mode'
,
Doc
(
'OR
= 0
'
,
'logical or'
),
Doc
(
'AND
= 1
'
,
'logical and'
),
Doc
(
'XOR
= 2
'
,
'exclusive-or'
),
Doc
(
'NOR
= 3
'
,
'not or(inputs)'
),
Doc
(
'NAND
= 4
'
,
'not and(inputs)'
),
Doc
(
'XNOR
= 5
'
,
'not xor(inputs)'
))
)
(
pdef
(
'CondExecMark'
,
'add ExecutionMask of the input PPV to this opr and readers of the '
'outputs of this opr'
)
.
add_enum
(
Doc
(
'GradMode'
,
'mode for computing the gradient'
),
Doc
(
'SUM'
,
'normal gradient mode: sum all the activated components'
),
Doc
(
'SUM_COND_OUT'
,
'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
Doc
(
'SUM
= 0
'
,
'normal gradient mode: sum all the activated components'
),
Doc
(
'SUM_COND_OUT
= 1
'
,
'use :attr:`CondExecMerge.SUM_COND_OUT` mode so '
'oprs that depend on the gradient opr would not be executed '
'if the forward var is not used.'
),
name_field
=
'grad_mode'
)
...
...
@@ -135,10 +135,10 @@ pdef('PersistentOutputStorage').add_fields(
execution into account, this option can be used to bypass static
inference errors. This is currently only used by automatically
generated gradient oprs."""
),
Doc
(
'SHAPE_VALUE'
,
'enable both shape and value inference'
),
Doc
(
'SHAPE_ONLY'
,
Doc
(
'SHAPE_VALUE
= 0
'
,
'enable both shape and value inference'
),
Doc
(
'SHAPE_ONLY
= 1
'
,
'only enable shape inference (disable value inference)'
),
Doc
(
'NONE'
,
'disable both shape and value inference'
),
Doc
(
'NONE
= 2
'
,
'disable both shape and value inference'
),
name_field
=
'static_infer'
)
)
...
...
@@ -147,17 +147,17 @@ pdef('PersistentOutputStorage').add_fields(
'number of output vars (i.e. vars per branch)'
),
1
)
.
add_enum
(
'Mode'
,
Doc
(
'EXACT_ONE'
,
'copy the var whose mask is activated to the output'
Doc
(
'EXACT_ONE
= 0
'
,
'copy the var whose mask is activated to the output'
', requiring that exactly one branch is active'
),
Doc
(
'EXACT_ONE_SAME_SHAPE'
,
'like :attr:`EXACT_ONE` with the '
Doc
(
'EXACT_ONE_SAME_SHAPE
= 1
'
,
'like :attr:`EXACT_ONE` with the '
'requirement that all branches have the same shape, so shape '
'inference can be easier'
),
Doc
(
'SUM'
,
'sum all the active branches into output var; require '
Doc
(
'SUM
= 2
'
,
'sum all the active branches into output var; require '
'all branches to have the same shape. Extra shape vars are '
'needed in this mod, so the outputs can be initialized to zero '
'when no input is active (and their shapes are probably '
'unknown).'
),
Doc
(
'SUM_COND_OUT'
,
'like :attr:`SUM` but also add an ExecutionMask'
Doc
(
'SUM_COND_OUT
= 3
'
,
'like :attr:`SUM` but also add an ExecutionMask'
' to the readers of output vars, so they would be skipped if '
' no branch is taken'
)
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录