Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
53075cd3
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
53075cd3
编写于
2月 25, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/experimental): add visualization and net stats for python graph
GitOrigin-RevId: a1ab77c20aff8b9205fb3b34532e8f86a2733d69
上级
ae3123b3
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
402 addition
and
172 deletion
+402
-172
imperative/python/megengine/tools/README.md
imperative/python/megengine/tools/README.md
+8
-0
imperative/python/megengine/tools/__init__.py
imperative/python/megengine/tools/__init__.py
+0
-0
imperative/python/megengine/tools/compare_binary_iodump.py
imperative/python/megengine/tools/compare_binary_iodump.py
+50
-6
imperative/python/megengine/tools/network_visualize.py
imperative/python/megengine/tools/network_visualize.py
+176
-0
imperative/python/megengine/tools/profile_analyze.py
imperative/python/megengine/tools/profile_analyze.py
+1
-1
imperative/python/megengine/utils/module_stats.py
imperative/python/megengine/utils/module_stats.py
+121
-103
imperative/python/megengine/utils/network.py
imperative/python/megengine/utils/network.py
+2
-4
imperative/python/megengine/utils/network_node.py
imperative/python/megengine/utils/network_node.py
+44
-1
imperative/python/megengine/utils/plugin.py
imperative/python/megengine/utils/plugin.py
+0
-57
未找到文件。
imperative/python/megengine/tools/README.md
0 → 100644
浏览文件 @
53075cd3
# MegEngine Tools
This directory contains executable python files.
Use these files in the following way (replace
`xxx`
to specific file name, like
`network_visualize`
):
```
python -m megengine.tools.xxx
```
imperative/python/megengine/tools/__init__.py
0 → 100644
浏览文件 @
53075cd3
imperative/python/megengine/
uti
ls/compare_binary_iodump.py
→
imperative/python/megengine/
too
ls/compare_binary_iodump.py
浏览文件 @
53075cd3
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...
...
@@ -7,12 +8,55 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
argparse
import
os
import
struct
import
textwrap
from
pathlib
import
Path
import
numpy
as
np
from
megengine.utils
import
plugin
def
load_tensor_binary
(
fobj
):
"""
Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
tensor value dump is implemented by ``mgb::debug::dump_tensor``.
:param fobj: file object, or a string that contains the file name.
:return: tuple ``(tensor_value, tensor_name)``.
"""
if
isinstance
(
fobj
,
str
):
with
open
(
fobj
,
"rb"
)
as
fin
:
return
load_tensor_binary
(
fin
)
DTYPE_LIST
=
{
0
:
np
.
float32
,
1
:
np
.
uint8
,
2
:
np
.
int8
,
3
:
np
.
int16
,
4
:
np
.
int32
,
# 5: _mgb.intb1,
# 6: _mgb.intb2,
# 7: _mgb.intb4,
8
:
None
,
9
:
np
.
float16
,
# quantized dtype start from 100000
# see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
# dnn/include/megdnn/dtype.h
100000
:
np
.
uint8
,
100001
:
np
.
int32
,
100002
:
np
.
int8
,
}
header_fmt
=
struct
.
Struct
(
"III"
)
name_len
,
dtype
,
max_ndim
=
header_fmt
.
unpack
(
fobj
.
read
(
header_fmt
.
size
))
assert
(
DTYPE_LIST
[
dtype
]
is
not
None
),
"Cannot load this tensor: dtype Byte is unsupported."
shape
=
list
(
struct
.
unpack
(
"I"
*
max_ndim
,
fobj
.
read
(
max_ndim
*
4
)))
while
shape
[
-
1
]
==
0
:
shape
.
pop
(
-
1
)
name
=
fobj
.
read
(
name_len
).
decode
(
"ascii"
)
return
np
.
fromfile
(
fobj
,
dtype
=
DTYPE_LIST
[
dtype
]).
reshape
(
shape
),
name
def
check
(
v0
,
v1
,
name
,
max_err
):
...
...
@@ -26,9 +70,9 @@ def check(v0, v1, name, max_err):
)
vdiv
=
np
.
max
([
np
.
abs
(
v0
),
np
.
abs
(
v1
),
np
.
ones_like
(
v0
)],
axis
=
0
)
err
=
np
.
abs
(
v0
-
v1
)
/
vdiv
check
=
err
>
max_err
if
check
.
sum
():
idx
=
tuple
(
i
[
0
]
for
i
in
np
.
nonzero
(
check
))
rst
=
err
>
max_err
if
rst
.
sum
():
idx
=
tuple
(
i
[
0
]
for
i
in
np
.
nonzero
(
rst
))
raise
AssertionError
(
"{} not equal: "
"shape={} nonequal_idx={} v0={} v1={} err={}"
.
format
(
...
...
@@ -79,8 +123,8 @@ def main():
files1
=
sorted
(
files1
)
for
i
,
j
in
zip
(
files0
,
files1
):
val0
,
name0
=
plugin
.
load_tensor_binary
(
i
)
val1
,
name1
=
plugin
.
load_tensor_binary
(
j
)
val0
,
name0
=
load_tensor_binary
(
i
)
val1
,
name1
=
load_tensor_binary
(
j
)
name
=
"{}:
\n
{}
\n
{}
\n
"
.
format
(
i
,
"
\n
"
.
join
(
textwrap
.
wrap
(
name0
)),
"
\n
"
.
join
(
textwrap
.
wrap
(
name1
))
)
...
...
imperative/python/megengine/tools/network_visualize.py
0 → 100755
浏览文件 @
53075cd3
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
argparse
import
numpy
as
np
from
megengine.core.tensor.dtype
import
is_quantize
from
megengine.logger
import
get_logger
from
megengine.utils.module_stats
import
(
print_flops_stats
,
print_params_stats
,
sizeof_fmt
,
)
from
megengine.utils.network
import
Network
logger
=
get_logger
(
__name__
)
def
visualize
(
model_path
:
str
,
log_path
:
str
,
bar_length_max
:
int
=
20
,
log_params
:
bool
=
True
,
log_flops
:
bool
=
True
,
):
r
"""
Load megengine dumped model and visualize graph structure with tensorboard log files.
Can also record and print model's statistics like :func:`~.net_stats`
:param model_path: dir path for megengine dumped model.
:param log_path: dir path for tensorboard graph log.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""
try
:
from
tensorboard.compat.proto.attr_value_pb2
import
AttrValue
from
tensorboard.compat.proto.config_pb2
import
RunMetadata
from
tensorboard.compat.proto.graph_pb2
import
GraphDef
from
tensorboard.compat.proto.node_def_pb2
import
NodeDef
from
tensorboard.compat.proto.step_stats_pb2
import
(
AllocatorMemoryUsed
,
DeviceStepStats
,
NodeExecStats
,
StepStats
,
)
from
tensorboard.compat.proto.tensor_shape_pb2
import
TensorShapeProto
from
tensorboard.compat.proto.versions_pb2
import
VersionDef
from
tensorboardX
import
SummaryWriter
except
ImportError
:
logger
.
error
(
"TensorBoard and TensorboardX are required for visualize."
,
exc_info
=
True
)
return
graph
=
Network
.
load
(
model_path
)
writer
=
SummaryWriter
(
log_path
)
def
process_name
(
name
):
return
name
.
replace
(
"."
,
"/"
).
encode
(
encoding
=
"utf-8"
)
node_list
=
[]
flops_list
=
[]
params_list
=
[]
for
node
in
graph
.
all_oprs
:
if
hasattr
(
node
,
"output_idx"
):
node_oup
=
node
.
outputs
[
node
.
output_idx
]
else
:
if
len
(
node
.
outputs
)
!=
1
:
logger
.
warning
(
"OpNode {} has more than one output and not has 'output_idx' attr."
.
format
(
node
)
)
node_oup
=
node
.
outputs
[
0
]
inp_list
=
[
process_name
(
var
.
owner
.
name
)
for
var
in
node
.
inputs
]
attr
=
{
"_output_shapes"
:
AttrValue
(
list
=
AttrValue
.
ListValue
(
shape
=
[
TensorShapeProto
(
dim
=
[
TensorShapeProto
.
Dim
(
size
=
d
)
for
d
in
node_oup
.
shape
]
)
]
)
),
}
if
hasattr
(
node
,
"calc_flops"
):
flops_num
=
node
.
calc_flops
()
# add op flops attr
attr
[
"flops"
]
=
AttrValue
(
s
=
sizeof_fmt
(
flops_num
).
encode
(
encoding
=
"utf-8"
))
flops_list
.
append
(
dict
(
name
=
node
.
name
,
class_name
=
node
.
type
,
input_shapes
=
[
i
.
shape
for
i
in
node
.
inputs
],
output_shapes
=
[
o
.
shape
for
o
in
node
.
outputs
],
flops_num
=
flops_num
,
flops_cum
=
0
,
)
)
if
node
.
type
==
"ImmutableTensor"
:
param_dim
=
np
.
prod
(
node_oup
.
shape
)
# TODO: consider other quantize dtypes
param_bytes
=
1
if
is_quantize
(
node_oup
.
dtype
)
else
4
# add tensor size attr
attr
[
"size"
]
=
AttrValue
(
s
=
sizeof_fmt
(
param_dim
*
param_bytes
).
encode
(
encoding
=
"utf-8"
)
)
params_list
.
append
(
dict
(
name
=
node
.
name
,
shape
=
node_oup
.
shape
,
param_dim
=
param_dim
,
bits
=
param_bytes
*
8
,
size
=
param_dim
*
param_bytes
,
size_cum
=
0
,
mean
=
"{:.2g}"
.
format
(
node
.
numpy
().
mean
()),
std
=
"{:.2g}"
.
format
(
node
.
numpy
().
std
()),
)
)
node_list
.
append
(
NodeDef
(
name
=
process_name
(
node
.
name
),
op
=
node
.
type
,
input
=
inp_list
,
attr
=
attr
,
)
)
total_flops
,
total_params
=
0
,
0
if
log_params
:
total_params
=
print_params_stats
(
params_list
,
bar_length_max
)
if
log_flops
:
total_flops
=
print_flops_stats
(
flops_list
,
bar_length_max
)
graph_def
=
GraphDef
(
node
=
node_list
,
versions
=
VersionDef
(
producer
=
22
))
device
=
"/device:CPU:0"
stepstats
=
RunMetadata
(
step_stats
=
StepStats
(
dev_stats
=
[
DeviceStepStats
(
device
=
device
)])
)
writer
.
_get_file_writer
().
add_graph
((
graph_def
,
stepstats
))
return
total_params
,
total_flops
def
main
():
parser
=
argparse
.
ArgumentParser
(
description
=
"load a megengine dumped model and export log file for tensorboard visualization."
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
,
)
parser
.
add_argument
(
"model_path"
,
help
=
"dumped model path."
)
parser
.
add_argument
(
"log_path"
,
help
=
"tensorboard log path."
)
parser
.
add_argument
(
"--bar_length_max"
,
type
=
int
,
default
=
20
,
help
=
"size of bar indicating max flops or parameter size in net stats."
,
)
parser
.
add_argument
(
"--log_params"
,
action
=
"store_true"
,
help
=
"whether print and record params size."
,
)
parser
.
add_argument
(
"--log_flops"
,
action
=
"store_true"
,
help
=
"whether print and record op flops."
,
)
visualize
(
**
vars
(
parser
.
parse_args
()))
if
__name__
==
"__main__"
:
main
()
imperative/python/megengine/
uti
ls/profile_analyze.py
→
imperative/python/megengine/
too
ls/profile_analyze.py
浏览文件 @
53075cd3
#
-*- coding: utf-8 -*-
#
! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...
...
imperative/python/megengine/utils/
net
_stats.py
→
imperative/python/megengine/utils/
module
_stats.py
浏览文件 @
53075cd3
...
...
@@ -84,26 +84,125 @@ hook_modules = (
)
def
net_stats
(
model
,
input_size
,
bar_length_max
=
20
,
log_params
=
True
,
log_flops
=
True
):
def
dict2table
(
list_of_dict
,
header
):
table_data
=
[
header
]
for
d
in
list_of_dict
:
row
=
[]
for
h
in
header
:
v
=
""
if
h
in
d
:
v
=
d
[
h
]
row
.
append
(
v
)
table_data
.
append
(
row
)
return
table_data
def
sizeof_fmt
(
num
,
suffix
=
"B"
):
for
unit
in
[
""
,
"Ki"
,
"Mi"
,
"Gi"
,
"Ti"
,
"Pi"
,
"Ei"
,
"Zi"
]:
if
abs
(
num
)
<
1024.0
:
return
"{:3.3f} {}{}"
.
format
(
num
,
unit
,
suffix
)
num
/=
1024.0
sign_str
=
"-"
if
num
<
0
else
""
return
"{}{:.1f} {}{}"
.
format
(
sign_str
,
num
,
"Yi"
,
suffix
)
def
dict2table
(
list_of_dict
,
header
):
table_data
=
[
header
]
for
d
in
list_of_dict
:
row
=
[]
for
h
in
header
:
v
=
""
if
h
in
d
:
v
=
d
[
h
]
row
.
append
(
v
)
table_data
.
append
(
row
)
return
table_data
def
sizeof_fmt
(
num
,
suffix
=
"B"
):
for
unit
in
[
""
,
"Ki"
,
"Mi"
,
"Gi"
,
"Ti"
,
"Pi"
,
"Ei"
,
"Zi"
]:
if
abs
(
num
)
<
1024.0
:
return
"{:3.3f} {}{}"
.
format
(
num
,
unit
,
suffix
)
num
/=
1024.0
sign_str
=
"-"
if
num
<
0
else
""
return
"{}{:.1f} {}{}"
.
format
(
sign_str
,
num
,
"Yi"
,
suffix
)
def
print_flops_stats
(
flops
,
bar_length_max
=
20
):
flops_list
=
[
i
[
"flops_num"
]
for
i
in
flops
]
max_flops_num
=
max
(
flops_list
+
[
0
])
# calc total flops and set flops_cum
total_flops_num
=
0
for
d
in
flops
:
total_flops_num
+=
int
(
d
[
"flops_num"
])
d
[
"flops_cum"
]
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
for
i
in
flops
:
f
=
i
[
"flops_num"
]
i
[
"flops"
]
=
sizeof_fmt
(
f
,
suffix
=
"OPs"
)
r
=
i
[
"ratio"
]
=
f
/
total_flops_num
i
[
"percentage"
]
=
"{:.2f}%"
.
format
(
r
*
100
)
bar_length
=
int
(
f
/
max_flops_num
*
bar_length_max
)
i
[
"bar"
]
=
"#"
*
bar_length
header
=
[
"name"
,
"class_name"
,
"input_shapes"
,
"output_shapes"
,
"flops"
,
"flops_cum"
,
"percentage"
,
"bar"
,
]
total_flops_str
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
total_var_size
=
sum
(
sum
(
s
[
1
]
for
s
in
i
[
"output_shapes"
])
for
i
in
flops
)
flops
.
append
(
dict
(
name
=
"total"
,
flops
=
total_flops_str
,
output_shapes
=
total_var_size
)
)
logger
.
info
(
"flops stats:
\n
"
+
tabulate
.
tabulate
(
dict2table
(
flops
,
header
=
header
)))
return
total_flops_num
def
print_params_stats
(
params
,
bar_length_max
=
20
):
total_param_dims
,
total_param_size
=
0
,
0
for
d
in
params
:
total_param_dims
+=
int
(
d
[
"param_dim"
])
total_param_size
+=
int
(
d
[
"size"
])
d
[
"size"
]
=
sizeof_fmt
(
d
[
"size"
])
d
[
"size_cum"
]
=
sizeof_fmt
(
total_param_size
)
for
d
in
params
:
ratio
=
d
[
"param_dim"
]
/
total_param_dims
d
[
"ratio"
]
=
ratio
d
[
"percentage"
]
=
"{:.2f}%"
.
format
(
ratio
*
100
)
# construct bar
max_ratio
=
max
([
d
[
"ratio"
]
for
d
in
params
])
for
d
in
params
:
bar_length
=
int
(
d
[
"ratio"
]
/
max_ratio
*
bar_length_max
)
d
[
"size_bar"
]
=
"#"
*
bar_length
param_size
=
sizeof_fmt
(
total_param_size
)
params
.
append
(
dict
(
name
=
"total"
,
param_dim
=
total_param_dims
,
size
=
param_size
,))
header
=
[
"name"
,
"shape"
,
"mean"
,
"std"
,
"param_dim"
,
"bits"
,
"size"
,
"size_cum"
,
"percentage"
,
"size_bar"
,
]
logger
.
info
(
"param stats:
\n
"
+
tabulate
.
tabulate
(
dict2table
(
params
,
header
=
header
))
)
return
total_param_size
def
net_stats
(
model
:
m
.
Module
,
input_size
:
int
,
bar_length_max
:
int
=
20
,
log_params
:
bool
=
True
,
log_flops
:
bool
=
True
,
):
r
"""
Calculate and print ``model``'s statistics by adding hook and record Module's inputs outputs size.
:param model: model that need to get stats info.
:param input_size: size of input for running model and calculating stats.
:param bar_length_max: size of bar indicating max flops or parameter size in net stats.
:param log_params: whether print and record params size.
:param log_flops: whether print and record op flops.
"""
def
get_byteswidth
(
tensor
):
if
dtype
.
is_quantize
(
tensor
.
dtype
):
...
...
@@ -113,87 +212,6 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
else
:
return
4
def
print_flops_stats
(
flops
):
flops_list
=
[
i
[
"flops_num"
]
for
i
in
flops
]
max_flops_num
=
max
(
flops_list
+
[
0
])
# calc total flops and set flops_cum
total_flops_num
=
0
for
d
in
flops
:
total_flops_num
+=
int
(
d
[
"flops_num"
])
d
[
"flops_cum"
]
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
for
i
in
flops
:
f
=
i
[
"flops_num"
]
i
[
"flops"
]
=
sizeof_fmt
(
f
,
suffix
=
"OPs"
)
r
=
i
[
"ratio"
]
=
f
/
total_flops_num
i
[
"percentage"
]
=
"{:.2f}%"
.
format
(
r
*
100
)
bar_length
=
int
(
f
/
max_flops_num
*
bar_length_max
)
i
[
"bar"
]
=
"#"
*
bar_length
header
=
[
"name"
,
"class_name"
,
"input_shapes"
,
"output_shapes"
,
"flops"
,
"flops_cum"
,
"percentage"
,
"bar"
,
]
total_flops_str
=
sizeof_fmt
(
total_flops_num
,
suffix
=
"OPs"
)
total_var_size
=
sum
(
sum
(
s
[
1
]
for
s
in
i
[
"output_shapes"
])
for
i
in
flops
)
flops
.
append
(
dict
(
name
=
"total"
,
flops
=
total_flops_str
,
output_shapes
=
total_var_size
)
)
logger
.
info
(
"flops stats:
\n
"
+
tabulate
.
tabulate
(
dict2table
(
flops
,
header
=
header
))
)
return
total_flops_num
def
print_params_stats
(
params
):
total_param_dims
,
total_param_size
=
0
,
0
for
d
in
params
:
total_param_dims
+=
int
(
d
[
"param_dim"
])
total_param_size
+=
int
(
d
[
"size"
])
d
[
"size"
]
=
sizeof_fmt
(
d
[
"size"
])
d
[
"size_cum"
]
=
sizeof_fmt
(
total_param_size
)
for
d
in
params
:
ratio
=
d
[
"param_dim"
]
/
total_param_dims
d
[
"ratio"
]
=
ratio
d
[
"percentage"
]
=
"{:.2f}%"
.
format
(
ratio
*
100
)
# construct bar
max_ratio
=
max
([
d
[
"ratio"
]
for
d
in
params
])
for
d
in
params
:
bar_length
=
int
(
d
[
"ratio"
]
/
max_ratio
*
bar_length_max
)
d
[
"size_bar"
]
=
"#"
*
bar_length
param_size
=
sizeof_fmt
(
total_param_size
)
params
.
append
(
dict
(
name
=
"total"
,
param_dim
=
total_param_dims
,
size
=
param_size
,))
header
=
[
"name"
,
"shape"
,
"mean"
,
"std"
,
"param_dim"
,
"bits"
,
"size"
,
"size_cum"
,
"percentage"
,
"size_bar"
,
]
logger
.
info
(
"param stats:
\n
"
+
tabulate
.
tabulate
(
dict2table
(
params
,
header
=
header
))
)
return
total_param_size
def
net_stats_hook
(
module
,
input
,
output
,
name
=
""
):
class_name
=
str
(
module
.
__class__
).
split
(
"."
)[
-
1
].
split
(
"'"
)[
0
]
...
...
@@ -273,8 +291,8 @@ def net_stats(model, input_size, bar_length_max=20, log_params=True, log_flops=T
total_flops
,
total_params
=
0
,
0
if
log_params
:
total_params
=
print_params_stats
(
params
)
total_params
=
print_params_stats
(
params
,
bar_length_max
)
if
log_flops
:
total_flops
=
print_flops_stats
(
flops
)
total_flops
=
print_flops_stats
(
flops
,
bar_length_max
)
return
total_params
,
total_flops
imperative/python/megengine/utils/network.py
浏览文件 @
53075cd3
...
...
@@ -19,9 +19,9 @@ from ..core._imperative_rt import ComputingGraph
from
..core.tensor
import
megbrain_graph
as
G
from
.comp_graph_tools
import
get_dep_vars
,
get_opr_type
,
get_oprs_seq
from
.network_node
import
(
NetworkNode
,
Host2DeviceCopy
,
ImmutableTensor
,
NetworkNode
,
OpNode
,
VarNode
,
str_to_mge_class
,
...
...
@@ -606,9 +606,7 @@ class NodeFilterType(NodeFilter):
_node_type
=
None
def
__init__
(
self
,
node_iter
,
node_type
):
assert
issubclass
(
node_type
,
NetworkNode
),
"bad opr type: {}"
.
format
(
node_type
)
assert
issubclass
(
node_type
,
NetworkNode
),
"bad opr type: {}"
.
format
(
node_type
)
super
().
__init__
(
node_iter
)
self
.
_node_type
=
node_type
...
...
imperative/python/megengine/utils/network_node.py
浏览文件 @
53075cd3
...
...
@@ -10,6 +10,8 @@ import json
import
sys
from
typing
import
Callable
import
numpy
as
np
from
..core
import
_imperative_rt
as
rt
from
..core._wrap
import
Device
from
..core.ops
import
builtin
...
...
@@ -52,7 +54,7 @@ class VarNode(NetworkNode):
return
self
.
var
.
dtype
if
self
.
var
else
None
def
set_owner_opr
(
self
,
owner_opr
):
self
.
owner
_opr
=
owner_opr
self
.
owner
=
owner_opr
class
OpNode
(
NetworkNode
):
...
...
@@ -223,6 +225,9 @@ class Elemwise(OpNode):
type
=
"Elemwise"
opdef
=
builtin
.
Elemwise
def
calc_flops
(
self
):
return
np
.
prod
(
self
.
outputs
[
0
].
shape
)
class
Reduce
(
OpNode
):
type
=
"Reduce"
...
...
@@ -250,11 +255,21 @@ class MatrixMul(OpNode):
type
=
"MatrixMul"
opdef
=
builtin
.
MatrixMul
def
calc_flops
(
self
):
assert
len
(
self
.
inputs
[
0
].
shape
)
==
2
and
len
(
self
.
outputs
[
0
].
shape
)
==
2
mid_shape
=
self
.
inputs
[
0
].
shape
[
1
]
return
np
.
prod
(
self
.
outputs
[
0
].
shape
)
*
mid_shape
class
BatchedMatrixMul
(
OpNode
):
type
=
"BatchedMatmul"
opdef
=
builtin
.
BatchedMatrixMul
def
calc_flops
(
self
):
assert
len
(
self
.
inputs
[
0
].
shape
)
==
3
and
len
(
self
.
outputs
[
0
].
shape
)
==
3
mid_shape
=
self
.
inputs
[
0
].
shape
[
2
]
return
np
.
prod
(
self
.
outputs
[
0
].
shape
)
*
mid_shape
class
Dot
(
OpNode
):
type
=
"Dot"
...
...
@@ -270,6 +285,18 @@ class ConvolutionForward(OpNode):
type
=
"Convolution"
opdef
=
builtin
.
Convolution
def
calc_flops
(
self
):
param_W_shape
=
self
.
inputs
[
1
].
shape
kh
=
param_W_shape
[
-
2
]
kw
=
param_W_shape
[
-
1
]
if
len
(
param_W_shape
)
==
5
:
num_input
=
param_W_shape
[
2
]
else
:
num_input
=
param_W_shape
[
1
]
NCHW
=
np
.
prod
(
self
.
outputs
[
0
].
shape
)
# N x Cout x H x W x (Cin x Kw x Kh)
return
NCHW
*
(
num_input
*
kw
*
kh
)
class
ConvolutionBackwardData
(
OpNode
):
type
=
"ConvTranspose"
...
...
@@ -316,6 +343,18 @@ class ConvBiasForward(OpNode):
obj
.
params
[
"dtype"
]
=
opr
.
outputs
[
0
].
dtype
return
obj
def
calc_flops
(
self
):
param_W_shape
=
self
.
inputs
[
1
].
shape
kh
=
param_W_shape
[
-
2
]
kw
=
param_W_shape
[
-
1
]
if
len
(
param_W_shape
)
==
5
:
num_input
=
param_W_shape
[
2
]
else
:
num_input
=
param_W_shape
[
1
]
NCHW
=
np
.
prod
(
self
.
outputs
[
0
].
shape
)
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return
NCHW
*
(
num_input
*
kw
*
kh
+
1
)
class
BatchConvBiasForward
(
OpNode
):
type
=
"BatchConvBias"
...
...
@@ -331,6 +370,7 @@ class BatchConvBiasForward(OpNode):
class
BatchNormForward
(
OpNode
):
type
=
"BatchNorm"
opdef
=
builtin
.
BatchNorm
output_idx
=
-
1
class
ROIAlignForward
(
OpNode
):
...
...
@@ -622,6 +662,9 @@ class ElemwiseMultiType(OpNode):
obj
.
params
[
"dtype"
]
=
opr
.
outputs
[
0
].
dtype
return
obj
def
calc_flops
(
self
):
return
np
.
prod
(
self
.
outputs
[
0
].
shape
)
class
CvtColorForward
(
OpNode
):
type
=
"CvtColor"
...
...
imperative/python/megengine/utils/plugin.py
已删除
100644 → 0
浏览文件 @
ae3123b3
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
struct
import
numpy
as
np
def
load_tensor_binary
(
fobj
):
"""
Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
tensor value dump is implemented by ``mgb::debug::dump_tensor``.
Multiple values can be compared by ``tools/compare_binary_iodump.py``.
:param fobj: file object, or a string that contains the file name.
:return: tuple ``(tensor_value, tensor_name)``.
"""
if
isinstance
(
fobj
,
str
):
with
open
(
fobj
,
"rb"
)
as
fin
:
return
load_tensor_binary
(
fin
)
DTYPE_LIST
=
{
0
:
np
.
float32
,
1
:
np
.
uint8
,
2
:
np
.
int8
,
3
:
np
.
int16
,
4
:
np
.
int32
,
# 5: _mgb.intb1,
# 6: _mgb.intb2,
# 7: _mgb.intb4,
8
:
None
,
9
:
np
.
float16
,
# quantized dtype start from 100000
# see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
# dnn/include/megdnn/dtype.h
100000
:
np
.
uint8
,
100001
:
np
.
int32
,
100002
:
np
.
int8
,
}
header_fmt
=
struct
.
Struct
(
"III"
)
name_len
,
dtype
,
max_ndim
=
header_fmt
.
unpack
(
fobj
.
read
(
header_fmt
.
size
))
assert
(
DTYPE_LIST
[
dtype
]
is
not
None
),
"Cannot load this tensor: dtype Byte is unsupported."
shape
=
list
(
struct
.
unpack
(
"I"
*
max_ndim
,
fobj
.
read
(
max_ndim
*
4
)))
while
shape
[
-
1
]
==
0
:
shape
.
pop
(
-
1
)
name
=
fobj
.
read
(
name_len
).
decode
(
"ascii"
)
return
np
.
fromfile
(
fobj
,
dtype
=
DTYPE_LIST
[
dtype
]).
reshape
(
shape
),
name
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录