Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8494a152
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
8494a152
编写于
4月 02, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
chore(scripts): clarify and fix default value of bit combined enum
GitOrigin-RevId: 3716bf9bb566a23c6916df611dae563934e824cf
上级
da167cbc
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
177 addition
and
64 deletion
+177
-64
dnn/scripts/gen_flatbuffers_schema.py
dnn/scripts/gen_flatbuffers_schema.py
+14
-11
dnn/scripts/gen_param_defs.py
dnn/scripts/gen_param_defs.py
+106
-30
dnn/scripts/gen_tablegen.py
dnn/scripts/gen_tablegen.py
+13
-2
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+29
-12
imperative/tablegen/targets/python_c_extension.cpp
imperative/tablegen/targets/python_c_extension.cpp
+6
-5
src/opr/impl/dnn/dnn.oprdecl
src/opr/impl/dnn/dnn.oprdecl
+3
-3
tools/param_defs/mgb_opr_param_defs.py
tools/param_defs/mgb_opr_param_defs.py
+6
-1
未找到文件。
dnn/scripts/gen_flatbuffers_schema.py
浏览文件 @
8494a152
...
@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase):
...
@@ -52,14 +52,11 @@ class FlatBuffersWriter(IndentWriterBase):
name
=
p
+
e
name
=
p
+
e
e
=
self
.
_enums
[(
p
,
e
)]
e
=
self
.
_enums
[(
p
,
e
)]
self
.
_write_doc
(
e
.
name
)
self
.
_write_doc
(
e
.
name
)
self
.
_write
(
"enum %s%s : uint {"
,
p
,
e
.
name
,
indent
=
1
)
attribute
=
"(bit_flags)"
if
e
.
combined
else
""
self
.
_write
(
"enum %s%s : uint %s {"
,
p
,
e
.
name
,
attribute
,
indent
=
1
)
for
idx
,
member
in
enumerate
(
e
.
members
):
for
idx
,
member
in
enumerate
(
e
.
members
):
self
.
_write_doc
(
member
)
self
.
_write_doc
(
member
)
if
e
.
combined
:
self
.
_write
(
"%s,"
,
scramble_enum_member_name
(
str
(
member
)))
self
.
_write
(
"%s=%d,"
,
scramble_enum_member_name
(
str
(
member
)),
1
<<
idx
)
else
:
self
.
_write
(
"%s,"
,
scramble_enum_member_name
(
str
(
member
)))
self
.
_write
(
"}
\n
"
,
indent
=-
1
)
self
.
_write
(
"}
\n
"
,
indent
=-
1
)
def
_write_doc
(
self
,
doc
):
def
_write_doc
(
self
,
doc
):
...
@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase):
...
@@ -97,8 +94,11 @@ class FlatBuffersWriter(IndentWriterBase):
return
return
self
.
_write_doc
(
e
.
name
)
self
.
_write_doc
(
e
.
name
)
self
.
_used_enum
.
add
(
key
)
self
.
_used_enum
.
add
(
key
)
self
.
_write
(
"%s:%s%s = %s;"
,
e
.
name_field
,
p
.
name
,
e
.
name
,
if
e
.
combined
:
scramble_enum_member_name
(
str
(
e
.
members
[
e
.
default
])))
default
=
e
.
compose_combined_enum
(
e
.
default
)
else
:
default
=
scramble_enum_member_name
(
str
(
e
.
members
[
e
.
default
]))
self
.
_write
(
"%s:%s%s = %s;"
,
e
.
name_field
,
p
.
name
,
e
.
name
,
default
)
def
_resolve_const
(
self
,
v
):
def
_resolve_const
(
self
,
v
):
while
v
in
self
.
_cur_const_val
:
while
v
in
self
.
_cur_const_val
:
...
@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase):
...
@@ -120,9 +120,12 @@ class FlatBuffersWriter(IndentWriterBase):
return
return
self
.
_used_enum
.
add
((
e
.
src_class
,
e
.
src_name
))
self
.
_used_enum
.
add
((
e
.
src_class
,
e
.
src_name
))
enum_name
=
e
.
src_class
+
e
.
src_name
enum_name
=
e
.
src_class
+
e
.
src_name
self
.
_write
(
s
=
e
.
src_enum
"%s:%s = %s;"
,
e
.
name_field
,
enum_name
,
if
s
.
combined
:
scramble_enum_member_name
(
str
(
e
.
src_enum
.
members
[
e
.
get_default
()])))
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
else
:
default
=
scramble_enum_member_name
(
str
(
s
.
members
[
e
.
get_default
()]))
self
.
_write
(
"%s:%s = %s;"
,
e
.
name_field
,
enum_name
,
default
)
def
_get_fb_default
(
self
,
cppdefault
):
def
_get_fb_default
(
self
,
cppdefault
):
if
not
isinstance
(
cppdefault
,
str
):
if
not
isinstance
(
cppdefault
,
str
):
...
...
dnn/scripts/gen_param_defs.py
浏览文件 @
8494a152
...
@@ -73,11 +73,21 @@ class member_defs:
...
@@ -73,11 +73,21 @@ class member_defs:
"""define an enum; the result would contain both an enum class def and its
"""define an enum; the result would contain both an enum class def and its
corresponding data field
corresponding data field
:param default: index of default member value
:param default:
for normal enum class: index of default member value
for bit combined class: tuple of index of default member value
For example, following representations of the default value for bit
combined class are all equivalent:
Enum(members=('a', 'b', 'c'), default=('a', 'b'), ...)
Enum(members=('a', 'b', 'c'), default=(0, 1), ...)
Enum(members=('a', 'b', 'c'), default=(1 << 0) | (1 << 1), ...)
:attr name_field: name of the data field of this enum in the param
:attr name_field: name of the data field of this enum in the param
struct
struct
:attr member_alias: list of (member, alias) pairs
:attr member_alias:
for normal enum class: list of (member, alias) pairs
for bit combined class: list of (tuple of members, alias) paris
"""
"""
__slots__
=
[
'name'
,
'name_field'
,
'members'
,
'default'
,
__slots__
=
[
'name'
,
'name_field'
,
'members'
,
'default'
,
'member_alias'
,
'combined'
]
'member_alias'
,
'combined'
]
...
@@ -90,17 +100,11 @@ class member_defs:
...
@@ -90,17 +100,11 @@ class member_defs:
name
=
member_defs
.
Doc
.
make
(
name
)
name
=
member_defs
.
Doc
.
make
(
name
)
assert
name
.
id
[
0
].
isupper
()
assert
name
.
id
[
0
].
isupper
()
members
=
tuple
(
map
(
member_defs
.
Doc
.
make
,
members
))
members
=
tuple
(
map
(
member_defs
.
Doc
.
make
,
members
))
if
isinstance
(
default
,
str
):
if
default
not
in
name_field
:
raise
ValueError
(
"Default value '{}' does not exist."
.
format
(
default
))
default
=
name_field
.
index
(
default
)
assert
isinstance
(
default
,
int
)
self
.
name
=
name
self
.
name
=
name
self
.
combined
=
combined
self
.
combined
=
combined
self
.
name_field
=
self
.
get_name_field
(
name
.
id
,
name_field
)
self
.
name_field
=
self
.
get_name_field
(
name
.
id
,
name_field
)
self
.
members
=
members
self
.
members
=
members
self
.
default
=
default
self
.
default
=
self
.
normalize_enum_value
(
default
)
self
.
all_enums
[(
param_name
,
name
.
id
)]
=
self
self
.
all_enums
[(
param_name
,
name
.
id
)]
=
self
...
@@ -114,6 +118,43 @@ class member_defs:
...
@@ -114,6 +118,43 @@ class member_defs:
assert
isinstance
(
name_field
,
str
)
assert
isinstance
(
name_field
,
str
)
return
name_field
return
name_field
def
normalize_enum_value
(
self
,
value
):
def
normalize
(
v
):
if
isinstance
(
v
,
str
):
if
v
not
in
self
.
members
:
raise
ValueError
(
"enum member '{}' does not exist."
.
format
(
v
))
v
=
self
.
members
.
index
(
v
)
assert
isinstance
(
v
,
int
)
return
v
if
self
.
combined
:
if
isinstance
(
value
,
int
):
value
=
self
.
decompose_combined_enum
(
value
)
assert
isinstance
(
value
,
tuple
)
value
=
tuple
(
normalize
(
i
)
for
i
in
value
)
return
value
else
:
return
normalize
(
value
)
@
staticmethod
def
decompose_combined_enum
(
v
):
"""Integer => tuple of the indexes of the enum members"""
assert
isinstance
(
v
,
int
)
idx
=
0
members
=
[]
while
v
>
0
:
if
v
&
1
:
members
.
append
(
idx
)
idx
+=
1
v
>>=
1
return
tuple
(
members
)
def
compose_combined_enum
(
self
,
v
):
"""tuple of members => Integer"""
assert
self
.
combined
and
isinstance
(
v
,
tuple
)
norm_v
=
self
.
normalize_enum_value
(
v
)
return
sum
(
1
<<
i
for
i
in
norm_v
)
class
Field
(
Base
):
class
Field
(
Base
):
"""define a normal data field"""
"""define a normal data field"""
__slots__
=
[
'name'
,
'dtype'
,
'default'
]
__slots__
=
[
'name'
,
'dtype'
,
'default'
]
...
@@ -146,6 +187,10 @@ class member_defs:
...
@@ -146,6 +187,10 @@ class member_defs:
src_name
=
name
src_name
=
name
self
.
src_name
=
src_name
self
.
src_name
=
src_name
self
.
default
=
default
self
.
default
=
default
# TODO: remove this assertion if needed; adding mock param_defs in
# current testing framework is too complicated, and currently we
# only allow aliasing of normal enum
assert
not
self
.
src_enum
.
combined
@
property
@
property
def
src_enum
(
self
):
def
src_enum
(
self
):
...
@@ -157,7 +202,7 @@ class member_defs:
...
@@ -157,7 +202,7 @@ class member_defs:
set"""
set"""
if
self
.
default
is
None
:
if
self
.
default
is
None
:
return
self
.
src_enum
.
default
return
self
.
src_enum
.
default
return
self
.
default
return
self
.
src_enum
.
normalize_enum_value
(
self
.
default
)
class
ParamDef
:
class
ParamDef
:
...
@@ -198,7 +243,7 @@ class ParamDef:
...
@@ -198,7 +243,7 @@ class ParamDef:
self
.
name
.
id
,
name
,
name_field
,
members
,
default
,
member_alias
))
self
.
name
.
id
,
name
,
name_field
,
members
,
default
,
member_alias
))
return
self
return
self
def
add_bit_combination_enum
(
self
,
name
,
*
members
,
default
=
0
,
def
add_bit_combination_enum
(
self
,
name
,
*
members
,
default
=
tuple
()
,
name_field
=
None
,
member_alias
=
[]):
name_field
=
None
,
member_alias
=
[]):
self
.
members
.
append
(
member_defs
.
Enum
(
self
.
members
.
append
(
member_defs
.
Enum
(
self
.
name
.
id
,
name
,
name_field
,
members
,
default
,
member_alias
,
True
))
self
.
name
.
id
,
name
,
name_field
,
members
,
default
,
member_alias
,
True
))
...
@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase):
...
@@ -322,11 +367,13 @@ class PyWriter(IndentWriterBase):
' for idx, v in enumerate(pdata):
\n
'
' for idx, v in enumerate(pdata):
\n
'
' if isinstance(v, _EnumBase):
\n
'
' if isinstance(v, _EnumBase):
\n
'
' pdata[idx] = _enum_member2num[id(v)]
\n
'
' pdata[idx] = _enum_member2num[id(v)]
\n
'
' elif isinstance(v, _BitCombinedEnumBase):
\n
'
' pdata[idx] = v._value_
\n
'
' return tag + self._packer.pack(*pdata)
\n
'
' return tag + self._packer.pack(*pdata)
\n
'
'
\n
'
'
\n
'
)
)
self
.
_write
(
# it's hard to mix custom implemention into enum, just do copy-paste instead
'class _EnumBase(enum.Enum):
\n
'
classbody
=
(
' @classmethod
\n
'
' @classmethod
\n
'
' def __normalize(cls, val):
\n
'
' def __normalize(cls, val):
\n
'
' if isinstance(val, str):
\n
'
' if isinstance(val, str):
\n
'
...
@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase):
...
@@ -349,6 +396,12 @@ class PyWriter(IndentWriterBase):
' return super()._missing_(value)
\n
'
' return super()._missing_(value)
\n
'
'
\n
'
'
\n
'
)
)
self
.
_write
(
'class _EnumBase(enum.Enum):
\n
'
+
classbody
)
self
.
_write
(
'class _BitCombinedEnumBase(enum.Flag):
\n
'
+
classbody
)
if
not
self
.
_imperative
:
if
not
self
.
_imperative
:
self
.
_write
(
self
.
_write
(
'def _as_dtype_num(dtype):
\n
'
'def _as_dtype_num(dtype):
\n
'
...
@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase):
...
@@ -464,30 +517,42 @@ class SerializedDType(_ParamDefBase):
def
_on_member_enum
(
self
,
e
):
def
_on_member_enum
(
self
,
e
):
qualname
=
'{}.{}'
.
format
(
self
.
_cur_param_name
,
e
.
name
)
qualname
=
'{}.{}'
.
format
(
self
.
_cur_param_name
,
e
.
name
)
self
.
_write
(
'class %s(_EnumBase):'
,
e
.
name
,
indent
=
1
)
if
e
.
combined
:
self
.
_write
(
'class %s(_BitCombinedEnumBase):'
,
e
.
name
,
indent
=
1
)
else
:
self
.
_write
(
'class %s(_EnumBase):'
,
e
.
name
,
indent
=
1
)
self
.
_write_doc
(
e
.
name
)
self
.
_write_doc
(
e
.
name
)
for
idx
,
emem
in
enumerate
(
e
.
members
):
for
idx
,
emem
in
enumerate
(
e
.
members
):
self
.
_write
(
'%s = "%s"'
,
emem
,
emem
)
self
.
_write_doc
(
emem
)
if
e
.
combined
:
if
e
.
combined
:
self
.
_
enum_member2num
.
append
(
'id({}.{}):{}'
.
format
(
self
.
_
write
(
'%s = 1 << %d'
,
emem
,
idx
)
qualname
,
emem
,
1
<<
idx
)
)
self
.
_write_doc
(
emem
)
else
:
else
:
self
.
_write
(
'%s = "%s"'
,
emem
,
emem
)
self
.
_write_doc
(
emem
)
self
.
_enum_member2num
.
append
(
'id({}.{}):{}'
.
format
(
self
.
_enum_member2num
.
append
(
'id({}.{}):{}'
.
format
(
qualname
,
emem
,
idx
))
qualname
,
emem
,
idx
))
for
emem
,
emem_alis
in
e
.
member_alias
:
for
emem
,
emem_alias
in
e
.
member_alias
:
self
.
_write
(
'%s = %s'
,
emem_alis
,
emem
)
if
e
.
combined
:
self
.
_write
(
'%s = %s'
,
emem_alias
,
e
.
compose_combined_enum
(
emem
))
else
:
self
.
_write
(
'%s = %s'
,
emem_alias
,
emem
)
self
.
_unindent
()
self
.
_unindent
()
self
.
_write
(
''
)
self
.
_write
(
''
)
if
e
.
combined
:
default
=
e
.
compose_combined_enum
(
e
.
default
)
else
:
default
=
"'{}'"
.
format
(
e
.
members
[
e
.
default
])
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
name
=
e
.
name_field
,
name
=
e
.
name_field
,
cvt
=
'{}.convert({})'
.
format
(
qualname
,
e
.
name_field
),
cvt
=
'{}.convert({})'
.
format
(
qualname
,
e
.
name_field
),
fmt
=
'I'
,
fmt
=
'I'
,
default
=
"'{}'"
.
format
(
e
.
members
[
e
.
default
])
,
default
=
default
,
type
=
qualname
,
type
=
qualname
,
doc
=
None
))
doc
=
None
))
...
@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase):
...
@@ -495,11 +560,16 @@ class SerializedDType(_ParamDefBase):
self
.
_write
(
'%s = %s.%s'
,
e
.
name
,
e
.
src_class
,
e
.
src_name
)
self
.
_write
(
'%s = %s.%s'
,
e
.
name
,
e
.
src_class
,
e
.
src_name
)
s
=
e
.
src_enum
s
=
e
.
src_enum
qualname
=
'{}.{}'
.
format
(
e
.
src_class
,
e
.
src_name
)
qualname
=
'{}.{}'
.
format
(
e
.
src_class
,
e
.
src_name
)
if
s
.
combined
:
default
=
s
.
compose_combined_enum
(
e
.
get_default
())
else
:
default
=
"'{}'"
.
format
(
s
.
members
[
e
.
get_default
()])
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
self
.
_cur_fields
.
append
(
self
.
FieldDef
(
name
=
e
.
name_field
,
name
=
e
.
name_field
,
cvt
=
'{}.convert({})'
.
format
(
qualname
,
e
.
name_field
),
cvt
=
'{}.convert({})'
.
format
(
qualname
,
e
.
name_field
),
fmt
=
'I'
,
fmt
=
'I'
,
default
=
"'{}'"
.
format
(
s
.
members
[
e
.
get_default
()])
,
default
=
default
,
type
=
qualname
,
type
=
qualname
,
doc
=
None
))
doc
=
None
))
...
@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase):
...
@@ -639,14 +709,19 @@ class CPPWriter(IndentWriterBase):
v
+=
','
v
+=
','
self
.
_write
(
v
)
self
.
_write
(
v
)
for
mem
,
alias
in
e
.
member_alias
:
for
mem
,
alias
in
e
.
member_alias
:
self
.
_write
(
'%s = %s,'
,
alias
,
mem
)
if
e
.
combined
:
self
.
_write
(
'%s = %s,'
,
alias
,
e
.
compose_combined_enum
(
mem
))
else
:
self
.
_write
(
'%s = %s,'
,
alias
,
mem
)
self
.
_write
(
'};'
,
indent
=-
1
)
self
.
_write
(
'};'
,
indent
=-
1
)
self
.
_non_static_members
.
append
(
e
)
self
.
_non_static_members
.
append
(
e
)
self
.
_write
(
'static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;'
,
self
.
_write
(
'static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;'
,
str
(
e
.
name
).
upper
(),
len
(
e
.
members
))
str
(
e
.
name
).
upper
(),
len
(
e
.
members
))
self
.
_add_ctor_args
(
e
.
name
,
if
e
.
combined
:
'{}::{}'
.
format
(
e
.
name
,
e
.
members
[
e
.
default
]),
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
e
.
name_field
)
else
:
default
=
'{}::{}'
.
format
(
e
.
name
,
e
.
members
[
e
.
default
])
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
def
_on_member_enum_alias
(
self
,
e
):
def
_on_member_enum_alias
(
self
,
e
):
s
=
e
.
src_enum
s
=
e
.
src_enum
...
@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase):
...
@@ -654,10 +729,11 @@ class CPPWriter(IndentWriterBase):
self
.
_non_static_members
.
append
(
e
)
self
.
_non_static_members
.
append
(
e
)
self
.
_write
(
'static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;'
,
self
.
_write
(
'static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;'
,
str
(
e
.
name
).
upper
(),
len
(
s
.
members
))
str
(
e
.
name
).
upper
(),
len
(
s
.
members
))
self
.
_add_ctor_args
(
e
.
name
,
if
s
.
combined
:
'{}::{}'
.
format
(
e
.
name
,
default
=
'static_cast<{}>({})'
.
format
(
e
.
name
,
s
.
compose_combined_enum
(
e
.
default
))
s
.
members
[
e
.
get_default
()]),
else
:
e
.
name_field
)
default
=
'{}::{}'
.
format
(
e
.
name
,
s
.
members
[
e
.
get_default
()])
self
.
_add_ctor_args
(
e
.
name
,
default
,
e
.
name_field
)
def
_on_member_field
(
self
,
f
):
def
_on_member_field
(
self
,
f
):
self
.
_non_static_members
.
append
(
f
)
self
.
_non_static_members
.
append
(
f
)
...
...
dnn/scripts/gen_tablegen.py
浏览文件 @
8494a152
...
@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase):
...
@@ -106,7 +106,12 @@ class ConverterWriter(IndentWriterBase):
return
return
# wrapped with default value
# wrapped with default value
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
e
.
default
)
if
e
.
combined
:
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
e
.
compose_combined_enum
(
e
.
default
))
else
:
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
e
.
members
[
e
.
default
])
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
self
.
_current_tparams
.
append
(
"{}:${}"
.
format
(
wrapped
,
e
.
name_field
))
self
.
_current_tparams
.
append
(
"{}:${}"
.
format
(
wrapped
,
e
.
name_field
))
...
@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase):
...
@@ -124,7 +129,13 @@ class ConverterWriter(IndentWriterBase):
self
.
_write
(
"def {} : {};"
.
format
(
td_class
,
enum_def
))
self
.
_write
(
"def {} : {};"
.
format
(
td_class
,
enum_def
))
# wrapped with default value
# wrapped with default value
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
e
.
get_default
())
s
=
e
.
src_enum
if
s
.
combined
:
default_val
=
"static_cast<{}::{}>({})"
.
format
(
fullname
,
e
.
name
,
s
.
compose_combined_enum
(
e
.
get_default
()))
else
:
default_val
=
"{}::{}::{}"
.
format
(
fullname
,
e
.
name
,
s
.
members
[
e
.
get_default
()])
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
wrapped
=
self
.
_wrapped_with_default_value
(
td_class
,
default_val
)
self
.
_current_tparams
.
append
(
"{}:${}"
.
format
(
wrapped
,
e
.
name_field
))
self
.
_current_tparams
.
append
(
"{}:${}"
.
format
(
wrapped
,
e
.
name_field
))
...
...
imperative/python/src/ops.cpp
浏览文件 @
8494a152
...
@@ -87,9 +87,13 @@ struct pyobj_convert_generic {
...
@@ -87,9 +87,13 @@ struct pyobj_convert_generic {
}
}
};
};
template
<
typename
T
,
typename
SFINAE
=
void
>
struct
EnumTrait
;
template
<
typename
T
>
template
<
typename
T
>
struct
EnumTrait
{
struct
EnumTrait
<
T
,
std
::
enable_if_t
<
std
::
is_enum_v
<
T
>>>
{
static
constexpr
bool
is_bit_combined
=
false
;
static
constexpr
bool
is_bit_combined
=
false
;
static
constexpr
std
::
underlying_type_t
<
T
>
max
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper {
...
@@ -264,18 +268,25 @@ struct BitCombinedEnumWrapper {
return
ret
;
return
ret
;
}
}
}
}
static
PyObject
*
py_new_combined_enum
(
PyTypeObject
*
type
,
PyObject
*
,
PyObject
*
)
{
static
PyObject
*
py_new_combined_enum
(
PyTypeObject
*
type
,
PyObject
*
args
,
PyObject
*
)
{
PyObject
*
obj
=
type
->
tp_alloc
(
type
,
0
);
if
(
!
PyTuple_Size
(
args
))
{
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
static_cast
<
T
>
(
1
);
PyObject
*
obj
=
type
->
tp_alloc
(
type
,
0
);
return
obj
;
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
T
();
}
return
obj
;
static
int
py_init
(
PyObject
*
self
,
PyObject
*
args
,
PyObject
*
)
{
}
int
input
=
1
;
else
{
if
(
PyArg_ParseTuple
(
args
,
"|i"
,
&
input
)){
PyObject
*
input
;
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
self
)
->
value
=
if
(
!
PyArg_ParseTuple
(
args
,
"|O"
,
&
input
))
{
static_cast
<
T
>
(
input
);
return
nullptr
;
}
T
value
;
try
{
value
=
pyobj_convert_generic
<
T
>::
from
(
input
);
}
CATCH_ALL
(
nullptr
);
PyObject
*
obj
=
type
->
tp_alloc
(
type
,
0
);
reinterpret_cast
<
BitCombinedEnumWrapper
*>
(
obj
)
->
value
=
value
;
return
obj
;
}
}
return
0
;
}
}
static
PyObject
*
py_repr
(
PyObject
*
self
)
{
static
PyObject
*
py_repr
(
PyObject
*
self
)
{
return
pyobj_convert_generic
<
std
::
string
>::
to
(
return
pyobj_convert_generic
<
std
::
string
>::
to
(
...
@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T,
...
@@ -325,6 +336,12 @@ struct pyobj_convert_generic<T,
static
T
from
(
PyObject
*
obj
)
{
static
T
from
(
PyObject
*
obj
)
{
if
(
PyObject_TypeCheck
(
obj
,
&
Wrapper
::
type
))
{
if
(
PyObject_TypeCheck
(
obj
,
&
Wrapper
::
type
))
{
return
reinterpret_cast
<
Wrapper
*>
(
obj
)
->
value
;
return
reinterpret_cast
<
Wrapper
*>
(
obj
)
->
value
;
}
else
if
(
PyLong_Check
(
obj
))
{
auto
value
=
pyobj_convert_generic
<
std
::
underlying_type_t
<
T
>>::
from
(
obj
);
mgb_throw_if
(
value
>
EnumTrait
<
T
>::
max
,
mgb
::
MegBrainError
,
"out of range, cannot convert %zu to %s"
,
static_cast
<
uint32_t
>
(
value
),
Wrapper
::
name
);
return
static_cast
<
T
>
(
value
);
}
}
// try as string
// try as string
// TODO: type checkcd
// TODO: type checkcd
...
...
imperative/tablegen/targets/python_c_extension.cpp
浏览文件 @
8494a152
...
@@ -90,10 +90,12 @@ void EnumAttrEmitter::emit_tpl_spl() {
...
@@ -90,10 +90,12 @@ void EnumAttrEmitter::emit_tpl_spl() {
"template<> PyNumberMethods "
"template<> PyNumberMethods "
"$enumTpl<$opClass::$enumClass>::number_methods={};
\n
"
,
"$enumTpl<$opClass::$enumClass>::number_methods={};
\n
"
,
&
ctx
);
&
ctx
);
os
<<
tgfmt
(
os
<<
tgfmt
(
R"(
"template<> struct EnumTrait<$opClass::$enumClass> { static constexpr "
template<> struct EnumTrait<$opClass::$enumClass> {
"bool is_bit_combined = true;};
\n
"
,
static constexpr bool is_bit_combined = true;
&
ctx
);
static constexpr std::underlying_type_t<$opClass::$enumClass> max = (1llu << $0) - 1;
};
)"
,
&
ctx
,
attr
->
getEnumMembers
().
size
());
}
}
auto
str2type
=
[
&
](
auto
&&
i
)
->
std
::
string
{
auto
str2type
=
[
&
](
auto
&&
i
)
->
std
::
string
{
...
@@ -138,7 +140,6 @@ void $0(PyTypeObject& py_type) {
...
@@ -138,7 +140,6 @@ void $0(PyTypeObject& py_type) {
// others should always use singleton
// others should always use singleton
os
<<
tgfmt
(
R"(
os
<<
tgfmt
(
R"(
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
e_type.tp_new = $enumTpl<$opClass::$enumClass>::py_new_combined_enum;
e_type.tp_init = $enumTpl<$opClass::$enumClass>::py_init;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
auto& number_method = $enumTpl<$opClass::$enumClass>::number_methods;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_or = $enumTpl<$opClass::$enumClass>::py_or;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
number_method.nb_and = $enumTpl<$opClass::$enumClass>::py_and;
...
...
src/opr/impl/dnn/dnn.oprdecl
浏览文件 @
8494a152
...
@@ -6,7 +6,7 @@ decl_opr('Convolution',
...
@@ -6,7 +6,7 @@ decl_opr('Convolution',
'convolution kernel in '
'convolution kernel in '
'(out channel, in channel, kern row, kern col) format'
)],
'(out channel, in channel, kern row, kern col) format'
)],
params
=
[(
'param'
,
'ConvolutionV0'
),
params
=
[(
'param'
,
'ConvolutionV0'
),
(
'execution_polity'
,
'ExecutionPolicy'
)],
(
'execution_polity'
,
'ExecutionPolicy
V0
'
)],
desc
=
'batched convolution on channeled 2D images'
)
desc
=
'batched convolution on channeled 2D images'
)
decl_opr
(
'Convolution'
,
decl_opr
(
'Convolution'
,
...
@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData',
...
@@ -28,7 +28,7 @@ decl_opr('ConvolutionBackwardData',
'convolution kernel in '
'convolution kernel in '
'(out channel, in channel, kern row, kern col) format'
)],
'(out channel, in channel, kern row, kern col) format'
)],
params
=
[(
'param'
,
'ConvolutionV0'
),
params
=
[(
'param'
,
'ConvolutionV0'
),
(
'execution_polity'
,
'ExecutionPolicy'
)],
(
'execution_polity'
,
'ExecutionPolicy
V0
'
)],
body
=
[
body
=
[
'a, b = all_inputs'
,
'a, b = all_inputs'
,
'all_inputs = [b, a]'
'all_inputs = [b, a]'
...
@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward',
...
@@ -201,7 +201,7 @@ decl_opr('ConvBiasForward',
Doc
(
'bias'
,
'bias'
),
Doc
(
'bias'
,
'bias'
),
],
],
params
=
[(
'param'
,
'ConvBiasV1'
),
params
=
[(
'param'
,
'ConvBiasV1'
),
(
'execution_policy'
,
'ExecutionPolicy'
)],
(
'execution_policy'
,
'ExecutionPolicy
V0
'
)],
desc
=
(
'activation(convolution(src, filter) + bias) with specified '
desc
=
(
'activation(convolution(src, filter) + bias) with specified '
'dtype'
),
'dtype'
),
has_out_dtype
=
True
)
has_out_dtype
=
True
)
...
...
tools/param_defs/mgb_opr_param_defs.py
浏览文件 @
8494a152
...
@@ -42,7 +42,12 @@ pdef('PersistentOutputStorage').add_fields(
...
@@ -42,7 +42,12 @@ pdef('PersistentOutputStorage').add_fields(
'when profile or heuristic algo selection it require the algos'
'when profile or heuristic algo selection it require the algos'
'must be reproducible'
),
'must be reproducible'
),
Doc
(
'OPTMIZED'
,
Doc
(
'OPTMIZED'
,
'profile require algos are optmized to achieve fast-profile'
)).
'profile require algos are optmized to achieve fast-profile'
),
default
=
(
'HEURISTIC'
,),
member_alias
=
[((
'HEURISTIC'
,
'REPRODUCIBLE'
),
'HEURISTIC_REPRODUCIBLE'
),
((
'PROFILE'
,
'REPRODUCIBLE'
),
'PROFILE_REPRODUCIBLE'
),
((
'PROFILE'
,
'HEURISTIC'
),
'PROFILE_HEURISTIC'
),
]).
add_fields
(
'uint64'
,
add_fields
(
'uint64'
,
Doc
(
'workspace_limit'
,
'workspace limit in bytes'
),
Doc
(
'workspace_limit'
,
'workspace limit in bytes'
),
str
(
2
**
64
-
1
)
+
'ull'
))
str
(
2
**
64
-
1
)
+
'ull'
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录