Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9be8de60
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9be8de60
编写于
5月 06, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(midout): formatting midout tools
GitOrigin-RevId: 9aa6a9ec575d2343fed3dd8f1a0e4241bc99b539
上级
8182af6e
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
110 addition
and
79 deletion
+110
-79
tools/gen_header_for_bin_reduce.py
tools/gen_header_for_bin_reduce.py
+110
-79
未找到文件。
tools/gen_header_for_bin_reduce.py
浏览文件 @
9be8de60
...
...
@@ -8,21 +8,22 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
sys
import
re
if
sys
.
version_info
[
0
]
!=
3
or
sys
.
version_info
[
1
]
<
5
:
print
(
'This script requires Python version 3.5'
)
sys
.
exit
(
1
)
import
argparse
import
json
import
os
import
re
import
subprocess
import
sys
import
tempfile
from
pathlib
import
Path
MIDOUT_TRACE_MAGIC
=
'midout_trace v1
\n
'
if
sys
.
version_info
[
0
]
!=
3
or
sys
.
version_info
[
1
]
<
5
:
print
(
"This script requires Python version 3.5"
)
sys
.
exit
(
1
)
MIDOUT_TRACE_MAGIC
=
"midout_trace v1
\n
"
class
HeaderGen
:
_dtypes
=
None
...
...
@@ -42,13 +43,14 @@ class HeaderGen:
self
.
_midout_files
=
[]
_megvii3_root_cache
=
None
@
classmethod
def
get_megvii3_root
(
cls
):
if
cls
.
_megvii3_root_cache
is
not
None
:
return
cls
.
_megvii3_root_cache
wd
=
Path
(
__file__
).
resolve
().
parent
while
wd
.
parent
!=
wd
:
workspace_file
=
wd
/
'WORKSPACE'
workspace_file
=
wd
/
"WORKSPACE"
if
workspace_file
.
is_file
():
cls
.
_megvii3_root_cache
=
str
(
wd
)
return
cls
.
_megvii3_root_cache
...
...
@@ -56,6 +58,7 @@ class HeaderGen:
return
None
_megengine_root_cache
=
None
@
classmethod
def
get_megengine_root
(
cls
):
if
cls
.
_megengine_root_cache
is
not
None
:
...
...
@@ -66,15 +69,15 @@ class HeaderGen:
def
extend_netinfo
(
self
,
data
):
self
.
_has_netinfo
=
True
if
'hash'
not
in
data
:
if
"hash"
not
in
data
:
self
.
_file_without_hash
=
True
else
:
self
.
_graph_hashes
.
add
(
str
(
data
[
'hash'
]))
for
i
in
data
[
'dtypes'
]:
self
.
_graph_hashes
.
add
(
str
(
data
[
"hash"
]))
for
i
in
data
[
"dtypes"
]:
self
.
_dtypes
.
add
(
i
)
for
i
in
data
[
'opr_types'
]:
for
i
in
data
[
"opr_types"
]:
self
.
_oprs
.
add
(
i
)
for
i
in
data
[
'elemwise_modes'
]:
for
i
in
data
[
"elemwise_modes"
]:
self
.
_elemwise_modes
.
add
(
i
)
def
extend_midout
(
self
,
fname
):
...
...
@@ -82,7 +85,7 @@ class HeaderGen:
def
generate
(
self
,
fout
):
self
.
_fout
=
fout
self
.
_write_def
(
'MGB_BINREDUCE_VERSION'
,
'20190219'
)
self
.
_write_def
(
"MGB_BINREDUCE_VERSION"
,
"20190219"
)
if
self
.
_has_netinfo
:
self
.
_write_dtype
()
self
.
_write_elemwise_modes
()
...
...
@@ -93,13 +96,13 @@ class HeaderGen:
def
strip_opr_name_with_version
(
self
,
name
):
pos
=
len
(
name
)
t
=
re
.
search
(
r
'V\d+$'
,
name
)
t
=
re
.
search
(
r
"V\d+$"
,
name
)
if
t
:
pos
=
t
.
start
()
return
name
[:
pos
]
def
_write_oprs
(
self
):
defs
=
[
'}'
,
'namespace opr {'
]
defs
=
[
"}"
,
"namespace opr {"
]
already_declare
=
set
()
already_instance
=
set
()
for
i
in
self
.
_oprs
:
...
...
@@ -109,13 +112,15 @@ class HeaderGen:
else
:
already_declare
.
add
(
i
)
defs
.
append
(
'class {};'
.
format
(
i
))
defs
.
append
(
'}'
)
defs
.
append
(
'namespace serialization {'
)
defs
.
append
(
"""
defs
.
append
(
"class {};"
.
format
(
i
))
defs
.
append
(
"}"
)
defs
.
append
(
"namespace serialization {"
)
defs
.
append
(
"""
template<class Opr, class Callee>
struct OprRegistryCaller {
}; """
)
}; """
)
for
i
in
sorted
(
self
.
_oprs
):
i
=
self
.
strip_opr_name_with_version
(
i
)
if
i
in
already_instance
:
...
...
@@ -123,40 +128,53 @@ class HeaderGen:
else
:
already_instance
.
add
(
i
)
defs
.
append
(
"""
defs
.
append
(
"""
template<class Callee>
struct OprRegistryCaller<opr::{}, Callee>: public
OprRegistryCallerDefaultImpl<Callee> {{
}}; """
.
format
(
i
))
self
.
_write_def
(
'MGB_OPR_REGISTRY_CALLER_SPECIALIZE'
,
defs
)
}}; """
.
format
(
i
)
)
self
.
_write_def
(
"MGB_OPR_REGISTRY_CALLER_SPECIALIZE"
,
defs
)
def
_write_elemwise_modes
(
self
):
with
tempfile
.
NamedTemporaryFile
()
as
ftmp
:
fpath
=
os
.
path
.
realpath
(
ftmp
.
name
)
subprocess
.
check_call
(
[
'./dnn/scripts/gen_param_defs.py'
,
'--write-enum-items'
,
'Elemwise:Mode'
,
'./dnn/scripts/opr_param_defs.py'
,
fpath
],
cwd
=
self
.
get_megengine_root
()
[
"./dnn/scripts/gen_param_defs.py"
,
"--write-enum-items"
,
"Elemwise:Mode"
,
"./dnn/scripts/opr_param_defs.py"
,
fpath
,
],
cwd
=
self
.
get_megengine_root
(),
)
with
open
(
fpath
)
as
fin
:
mode_list
=
[
i
.
strip
()
for
i
in
fin
]
for
i
in
mode_list
:
i
=
i
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]
i
=
i
.
split
(
" "
)[
0
].
split
(
"="
)[
0
]
if
i
in
self
.
_elemwise_modes
:
content
=
'_cb({})'
.
format
(
i
)
content
=
"_cb({})"
.
format
(
i
)
else
:
content
=
''
content
=
""
self
.
_write_def
(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'
.
format
(
i
.
split
(
' '
)[
0
].
split
(
'='
)[
0
]),
content
)
self
.
_write_def
(
'MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)'
,
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)'
)
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)"
.
format
(
i
.
split
(
" "
)[
0
].
split
(
"="
)[
0
]
),
content
,
)
self
.
_write_def
(
"MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)"
,
"_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)"
,
)
def
_write_dtype
(
self
):
if
'Float16'
not
in
self
.
_dtypes
:
if
"Float16"
not
in
self
.
_dtypes
:
# MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
# support in the past; however `FLOT16' is really a typo. We plan to
# change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
...
...
@@ -166,36 +184,41 @@ class HeaderGen:
# In the future when the situation is settled and no one would ever
# use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
# safely deleted.
self
.
_write_def
(
'MEGDNN_DISABLE_FLOT16'
,
1
)
self
.
_write_def
(
'MEGDNN_DISABLE_FLOAT16'
,
1
)
self
.
_write_def
(
"MEGDNN_DISABLE_FLOT16"
,
1
)
self
.
_write_def
(
"MEGDNN_DISABLE_FLOAT16"
,
1
)
def
_write_hash
(
self
):
if
self
.
_file_without_hash
:
print
(
'WARNING: network info has no graph hash. Using json file '
'generated by MegBrain >= 7.28.0 is recommended'
)
print
(
"WARNING: network info has no graph hash. Using json file "
"generated by MegBrain >= 7.28.0 is recommended"
)
else
:
defs
=
'ULL,'
.
join
(
self
.
_graph_hashes
)
+
'ULL'
self
.
_write_def
(
'MGB_BINREDUCE_GRAPH_HASHES'
,
defs
)
defs
=
"ULL,"
.
join
(
self
.
_graph_hashes
)
+
"ULL"
self
.
_write_def
(
"MGB_BINREDUCE_GRAPH_HASHES"
,
defs
)
def
_write_def
(
self
,
name
,
val
):
if
isinstance
(
val
,
list
):
val
=
'
\n
'
.
join
(
val
)
val
=
str
(
val
).
strip
().
replace
(
'
\n
'
,
'
\\\n
'
)
self
.
_fout
.
write
(
'#define {} {}
\n
'
.
format
(
name
,
val
))
val
=
"
\n
"
.
join
(
val
)
val
=
str
(
val
).
strip
().
replace
(
"
\n
"
,
"
\\\n
"
)
self
.
_fout
.
write
(
"#define {} {}
\n
"
.
format
(
name
,
val
))
def
_write_midout
(
self
):
if
not
self
.
_midout_files
:
return
gen
=
os
.
path
.
join
(
self
.
get_megengine_root
(),
'third_party'
,
'midout'
,
'gen_header.py'
)
gen
=
os
.
path
.
join
(
self
.
get_megengine_root
(),
"third_party"
,
"midout"
,
"gen_header.py"
)
if
self
.
get_megvii3_root
():
gen
=
os
.
path
.
join
(
self
.
get_megvii3_root
(),
'brain'
,
'midout'
,
'gen_header.py'
)
print
(
'use {} to gen bin_reduce header'
.
format
(
gen
))
gen
=
os
.
path
.
join
(
self
.
get_megvii3_root
(),
"brain"
,
"midout"
,
"gen_header.py"
)
print
(
"use {} to gen bin_reduce header"
.
format
(
gen
))
cvt
=
subprocess
.
run
(
[
gen
]
+
self
.
_midout_files
,
stdout
=
subprocess
.
PIPE
,
check
=
True
,
).
stdout
.
decode
(
'utf-8'
)
self
.
_fout
.
write
(
'// midout
\n
'
)
[
gen
]
+
self
.
_midout_files
,
stdout
=
subprocess
.
PIPE
,
check
=
True
,
).
stdout
.
decode
(
"utf-8"
)
self
.
_fout
.
write
(
"// midout
\n
"
)
self
.
_fout
.
write
(
cvt
)
if
cvt
.
find
(
" half,"
)
>
0
:
change
=
open
(
self
.
_fout
.
name
).
read
().
replace
(
" half,"
,
" __fp16,"
)
...
...
@@ -212,28 +235,35 @@ class HeaderGen:
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
'generate header file for reducing binary size by '
'stripping unused oprs in a particular network; output file would '
'be written to bin_reduce.h'
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
description
=
"generate header file for reducing binary size by "
"stripping unused oprs in a particular network; output file would "
"be written to bin_reduce.h"
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
,
)
parser
.
add_argument
(
'inputs'
,
nargs
=
'+'
,
help
=
'input files that describe specific traits of the network; '
'can be one of the following:'
' 1. json files generated by '
'megbrain.serialize_comp_graph_to_file() in python; '
' 2. trace files generated by midout library'
)
default_file
=
os
.
path
.
join
(
HeaderGen
.
get_megengine_root
(),
'src'
,
'bin_reduce_cmake.h'
)
"inputs"
,
nargs
=
"+"
,
help
=
"input files that describe specific traits of the network; "
"can be one of the following:"
" 1. json files generated by "
"megbrain.serialize_comp_graph_to_file() in python; "
" 2. trace files generated by midout library"
,
)
default_file
=
os
.
path
.
join
(
HeaderGen
.
get_megengine_root
(),
"src"
,
"bin_reduce_cmake.h"
)
is_megvii3
=
HeaderGen
.
get_megvii3_root
()
if
is_megvii3
:
default_file
=
os
.
path
.
join
(
HeaderGen
.
get_megvii3_root
(),
'utils'
,
'bin_reduce.h'
)
parser
.
add_argument
(
'-o'
,
'--output'
,
help
=
'output file'
,
default
=
default_file
)
default_file
=
os
.
path
.
join
(
HeaderGen
.
get_megvii3_root
(),
"utils"
,
"bin_reduce.h"
)
parser
.
add_argument
(
"-o"
,
"--output"
,
help
=
"output file"
,
default
=
default_file
)
args
=
parser
.
parse_args
()
print
(
'config output file: {}'
.
format
(
args
.
output
))
print
(
"config output file: {}"
.
format
(
args
.
output
))
gen
=
HeaderGen
()
for
i
in
args
.
inputs
:
print
(
'==== processing {}'
.
format
(
i
))
print
(
"==== processing {}"
.
format
(
i
))
with
open
(
i
)
as
fin
:
if
fin
.
read
(
len
(
MIDOUT_TRACE_MAGIC
))
==
MIDOUT_TRACE_MAGIC
:
gen
.
extend_midout
(
i
)
...
...
@@ -241,8 +271,9 @@ def main():
fin
.
seek
(
0
)
gen
.
extend_netinfo
(
json
.
loads
(
fin
.
read
()))
with
open
(
args
.
output
,
'w'
)
as
fout
:
with
open
(
args
.
output
,
"w"
)
as
fout
:
gen
.
generate
(
fout
)
if
__name__
==
'__main__'
:
if
__name__
==
"__main__"
:
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录