MegEngine / MegEngine — commit 4f6c5d8f

Authored Oct 28, 2021 by Megvii Engine Team

feat(mge/dump): enable jit.dump to dump with testcase

GitOrigin-RevId: 5dce3564529c9a04f118a599637237f68a101e77
Parent: 182ca25d

Showing 3 changed files with 384 additions and 535 deletions (+384, -535):

imperative/python/megengine/jit/tracing.py        +374  -0
imperative/python/test/unit/jit/test_tracing.py   +10   -0
sdk/load-and-run/dump_with_testcase_mge.py        +0    -535
imperative/python/megengine/jit/tracing.py (view file @ 4f6c5d8f)
```python
@@ -13,10 +13,17 @@ import itertools
import json
import os
import pickle
import re
import struct
from typing import Any

import cv2
import numpy as np

from megengine.logger import get_logger

from .. import tensor
from ..core import _imperative_rt as rt
from ..core._imperative_rt import GraphProfiler, GraphProfiler2, SerializationMetadata
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
@@ -38,12 +45,15 @@ from ..core._wrap import as_device
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
from ..utils import comp_graph_tools as cgtools
from ..utils.naming import AutoNaming
from ..utils.profiler import is_profiling
from .dtr_config import DTRConfig
from .graph_opt_config import GraphOptimizationConfig
from .sublinear_memory_config import SublinearMemoryConfig

logger = get_logger(__name__)


def _input_node_use_static_shape():
    return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
```
```python
@@ -692,6 +702,289 @@ class trace:
        self._process_outputs(outputs)
        return outputs

    def _make_feed(
        self,
        graph,
        outputs,
        input_data,
        repeat,
        silent,
        no_assert,
        maxerr,
        resize_input,
        input_transform,
    ):
        def auto_reformat_image(path, data, dst_shape):
            """reformat image to target shape

            :param data: image data as numpy array
            :param dst_shape: target shape
            """
            dim3_format = False  # required input format does not contain batch
            hwc_format = False  # required input format is NHWC

            if not dst_shape:  # input tensor shape is not predefined
                if len(data.shape) == 2:
                    chl = 1
                    h = data.shape[0]
                    w = data.shape[1]
                else:
                    assert (
                        len(data.shape) == 3
                    ), "Input image must be of dimension 2 or 3"
                    h, w, chl = data.shape
                dst_shape = (1, chl, h, w)

            if len(dst_shape) == 3:
                dst_shape = (1,) + dst_shape
                dim3_format = True

            assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
            chl = dst_shape[1]
            if chl in [1, 3]:
                n, c, h, w = dst_shape
                dst_shape = (n, h, w, c)
            else:
                chl = dst_shape[3]
                assert chl in [
                    1,
                    3,
                ], "can not infer input format from shape: {}".format(dst_shape)
                hwc_format = True

            # dst_shape has now been normalized to NHWC format

            if resize_input:
                h, w = dst_shape[1:3]
                data = cv2.resize(data, (w, h))
                logger.info("input {} resized to {}".format(path, data.shape))

            if chl == 1:
                data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
                data = data[:, :, np.newaxis]

            assert data.ndim == 3
            data = data[np.newaxis]
            # data normalized to NHWC format

            if not hwc_format:
                data = np.transpose(data, (0, 3, 1, 2))

            if dim3_format:
                data = np.squeeze(data, 0)

            return data

        def read_input_data(dst_shape, dtype, path):
            def check_shape_equal(dst_shape, data_shape):
                if len(dst_shape):
                    assert len(data_shape) == len(
                        dst_shape
                    ), "input/data shapes mismatch: {} vs {}".format(
                        dst_shape, data_shape
                    )

                    if data_shape[1:] != dst_shape[1:]:
                        logger.warning(
                            "dst_shape is {}; data_shape is {}".format(
                                dst_shape, data_shape
                            )
                        )

            if path.startswith("#"):
                assert not resize_input
                assert not input_transform
                spec = path
                m = re.match(
                    r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec
                )
                assert m, "bad spec {}".format(spec)

                rng_min = float(m.group(1))
                rng_max = float(m.group(2))
                if m.group(3):
                    shape_str = m.group(3)
                    try:
                        shape = shape_str[1:].split(",")
                        if shape[-1].strip() == "...":
                            shape = shape[:-1]
                            shape.extend(list(dst_shape[len(shape) :]))
                        data_shape = tuple(map(int, shape))
                    except ValueError as e:
                        raise ValueError("bad spec {}: {}".format(spec, e.args))
                else:
                    data_shape = dst_shape

                check_shape_equal(dst_shape, data_shape)
                return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)

            # try to load image
            data = cv2.imread(path, cv2.IMREAD_COLOR)
            if data is None:
                assert not resize_input
                data = np.load(path)
                assert isinstance(data, np.ndarray)
            else:
                # load image succeeds, so we expect input format is image format
                data = auto_reformat_image(path, data, dst_shape)

            data = np.repeat(data, repeat, axis=0)
            if repeat > 1:
                logger.info(
                    "repeat input for {} times, data shape is {}".format(
                        repeat, data.shape
                    )
                )

            check_shape_equal(dst_shape, data.shape)

            if input_transform:
                data = eval(input_transform, {"data": data, "np": np})

            return data

        def gen_one_testcase(inputs, spec):
            paths = spec.split(";")
            if len(paths) != len(inputs):
                if len(paths) == 1 and paths[0].startswith("#"):
                    paths = [
                        "{}:{}".format(name, paths[0]) for name in inputs.keys()
                    ]
            assert len(paths) == len(
                inputs
            ), "required inputs: {}; data paths: {}".format(inputs.keys(), paths)

            if len(paths) == 1 and ":" not in paths[0]:
                paths[0] = next(iter(inputs.keys())) + ":" + paths[0]

            ret = {}
            for path in paths:
                var, path = path.split(":")
                ret[var] = read_input_data(
                    inputs[var].shape, inputs[var].dtype, path
                )
            return ret

        inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")
        inputs = {i.name: i for i in inputs}

        if not no_assert:

            replace_varmap = {}
            inp_map = {}
            # replace var use InputNode
            for name, var in inputs.items():
                inp = G.InputNode(
                    device="xpux", dtype=var.dtype, shape=var.shape, graph=graph
                )
                replace_varmap[var] = inp.outputs[0]._node
                inp_map[name] = inp

            new = cgtools.replace_vars(outputs, replace_varmap)
            if isinstance(new, rt.VarNode):
                new = list(new)

            output_nodes = [G.OutputNode(var) for var in new]
            func = graph.compile(*[node.outputs[0]._node for node in output_nodes])

            def make_dev_tensor(value, dtype=None, device=None):
                return tensor(value, dtype=dtype, device=device)._dev_tensor()

            def calculate(*args, **kwargs):
                output_val = []
                # set inputs value
                for name, var in inputs.items():
                    val = kwargs.pop(name, None)
                    assert val is not None, "miss input name{}".format(name)
                    dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
                    inp_map[name].set_value(dev_tensor)

                func.execute()

                for res in output_nodes:
                    output_val.append(res.get_value().numpy())
                return output_val

            def expect_name(var):
                return "{}:expect".format(var.name)

        testcases = []

        np.set_printoptions(precision=2, threshold=4, suppress=True)

        data_list = []
        for item in input_data:
            if item.startswith("@"):
                with open(item[1:], "r") as f:
                    data_list.extend(
                        [line.rstrip() for line in f if line.rstrip() != ""]
                    )
            else:
                data_list.append(item)

        for inp_spec in data_list:
            cur_testcase = gen_one_testcase(inputs, inp_spec)
            assert len(cur_testcase) == len(
                inputs
            ), "required inputs: {}; given data: {}".format(
                inputs.keys(), cur_testcase.keys()
            )

            if not no_assert:
                outputs_get = calculate(**cur_testcase)
                for var, val in zip(outputs, outputs_get):
                    cur_testcase[expect_name(var)] = val
                    logger.info(
                        "generate test groundtruth: var={} shape={} range=({}, {})"
                        " mean={} var={}".format(
                            var,
                            val.shape,
                            val.min(),
                            val.max(),
                            np.mean(val),
                            np.var(val),
                        )
                    )
            testcases.append(cur_testcase)
            logger.info(
                "add testcase: \n {}".format(
                    "\n ".join(
                        "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
                        "mean={:.2f} sd={:.2f}".format(
                            k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
                        )
                        for k, v in sorted(cur_testcase.items())
                    )
                )
            )

        if not no_assert:

            def expect_shp(var):
                ret = var.shape
                if ret:
                    return ret
                return testcases[0][expect_name(var)].shape

            def assert_equal(expect, real, **kwargs):
                op = AssertEqual(**kwargs)
                (res,) = G.apply_normal_varnode(op, expect, real)
                return res._node

            verbose = not silent

            outputs_new = []
            for i in outputs:
                device = rt.CompNode("xpux")
                dtype = i.dtype
                name = expect_name(i)
                shape = expect_shp(i)
                # make expect output as one input of model.
                expect_get = rt.make_h2d(graph, device, dtype, shape, name)
                # insert assert opr to check expect and real.
                outputs_new.append(
                    assert_equal(expect_get, i, verbose=verbose, maxerr=maxerr,)
                )
                inputs[expect_name(i)] = expect_get
            outputs = outputs_new

        return {"outputs": outputs, "testcases": testcases}

    def dump(
        self,
        file,
```
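For reference, the `#rand` spec handled by `read_input_data` above is decomposed by the regular expression shown in the hunk; a small standalone check (illustrative values, not part of the diff) shows how a spec splits into the range and the optional shape part:

```python
import re

spec = "#rand(0, 255, 1, ...)"  # illustrative spec
m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec)
# group(1) -> "0", group(2) -> "255", group(3) -> ", 1, ..."
# (the trailing "..." keeps the remaining dims of the original input shape)
print(m.group(1), m.group(2), m.group(3))
```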
```python
@@ -708,6 +1001,13 @@ class trace:
        optimize_for_inference=True,
        user_info: Any = None,
        enable_metadata: bool = True,
        input_data=None,
        repeat=1,
        silent=False,
        no_assert=False,
        maxerr=1e-4,
        resize_input=False,
        input_transform=None,
        **kwargs
    ):
        r"""Serializes trace to file system.
@@ -738,6 +1038,27 @@ class trace:
                will skip all optimize options if this is False. Default: True
            user_info: any type object, which will be pickled to bytes.
            enable_metadata: whether to save metadata into output file.
            input_data: input test data; the current network output will be used as the
                groundtruth. The format is "var0:file0;var1:file1..." to specify data
                files for input vars. It can also be "#rand(min,max,shape...)" for
                generating random input data, for example, "#rand(0,255)",
                "#rand(0,255,1,3,224,224)" or "#rand(0, 255, 1, ...)" where `...` means
                the remaining part of the original shape. If the shape is not specified,
                the shape of the corresponding input tensors in the network will be used.
                If there is only one input var, its name can be omitted. Each data file
                can either be an image that can be loaded by opencv, or a pickled
                numpy.ndarray. This option can be given multiple times to add multiple
                testcases. If you start the data with the letter @, the rest should be a
                filename, and each line in the file should be a single datum in the
                format described above. *NOTE* If `input_data` is not None, you can only
                use load-and-run to run the output file.
            repeat: how many times the input image is repeated. Useful when benchmarking
                with a batch size other than one. This has no effect on randomly
                generated input data.
            silent: whether to set verbose to False in the assert_equal opr.
            no_assert: whether to skip inserting the assert_equal opr to check the
                result; this option is useful for benchmarking.
            maxerr: max error for the assert_equal check during runtime.
            resize_input: whether to resize the input image to fit the input var shape.
            input_transform: a python expression to transform the input data.
                Example: data / np.std(data)

        Keyword Arguments:
@@ -778,6 +1099,8 @@ class trace:
                input for inference on nvidia backend(this optimization pass will
                result in mismatch of the precision of output of training and
                inference)
            * enable_fuse_preprocess: whether to fuse astype\pad_channel\dimshuffle
                etc. oprs
        """
        if not self._capture_as_const:
            raise ValueError(
```
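As a quick illustration of the new arguments (not part of the diff), a minimal sketch that traces a small function and dumps it together with one randomly generated testcase, mirroring the docstring above and the new unit test; the buffer and numbers are illustrative:

```python
import io

import megengine.functional as F
from megengine import tensor
from megengine.jit import trace


@trace(symbolic=True, capture_as_const=True)
def net(x):
    return F.exp(x)


net(tensor(1.0))    # run once so the graph is captured as const
buf = io.BytesIO()  # any binary file object (or a file path) works
# one testcase: random values in [0, 255) with shape (1,)
net.dump(buf, input_data=["#rand(0, 255, 1)"], maxerr=1e-4)
```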
```python
@@ -892,8 +1215,28 @@ class trace:
                v.name = output_names[i]
            dest_vars.append(v)

        dest_vars = [i._node for i in dest_vars]

        if input_data is not None:
            feeds = self._make_feed(
                graph,
                dest_vars,
                input_data,
                repeat,
                silent,
                no_assert,
                maxerr,
                resize_input,
                input_transform,
            )

            assert (
                isinstance(feeds, dict) and feeds["testcases"]
            ), "testcases can not be empty"

            dest_vars = feeds["outputs"]

        if optimize_for_inference:
            dest_vars, optimize_options = G.optimize_for_inference(dest_vars, **kwargs)
            dest_vars = [i._node for i in dest_vars]

        metadata = SerializationMetadata()
        if enable_metadata:
@@ -910,6 +1253,9 @@ class trace:
        if keep_opr_priority:
            graph._set_priority_to_id(dest_vars)

        if input_data is not None:
            file.write(b"mgbtest0")
            file.write(struct.pack("I", len(feeds["testcases"])))

        dump_content, dump_info = G.dump_graph(
            dest_vars,
            keep_var_name=keep_var_name,
@@ -921,6 +1267,34 @@ class trace:
            metadata=metadata,
        )
        file.write(dump_content)

        if input_data is not None:
            inputs = cgtools.get_dep_vars(dest_vars, "Host2DeviceCopy")
            inputs = sorted((i.name, i.dtype) for i in inputs)

            def make_dev_tensor(value, dtype=None, device=None):
                return tensor(value, dtype=dtype, device=device)._dev_tensor()

            for testcase in feeds["testcases"]:
                assert isinstance(testcase, dict)
                cg = G.Graph()
                output_mgbvars = []
                for name, dtype in inputs:
                    output_mgbvars.append(
                        cg.make_const(
                            make_dev_tensor(
                                testcase.pop(name), dtype=dtype, device="cpux"
                            )
                        )
                    )
                assert not testcase, "extra inputs provided in testcase: {}".format(
                    testcase.keys()
                )

                dump_content, _ = G.dump_graph(
                    output_mgbvars, strip_info_file=strip_info_file, append_json=True,
                )
                file.write(dump_content)

        return dump_info

    def _process_inputs(self, *args, **kwargs):
```
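Taken together, the writes above define the testcase file layout: the 8-byte magic b"mgbtest0", a packed unsigned int with the number of testcases, the serialized model graph, and then one small const-only graph per testcase. A minimal sketch of reading that header back (the file name is illustrative; the byte order follows the native convention used by struct.pack("I", ...)):

```python
import struct

# Peek at the header that trace.dump writes when input_data is given.
with open("model_with_testcases.mge", "rb") as f:  # illustrative path
    assert f.read(8) == b"mgbtest0"                 # magic written by dump
    (num_testcases,) = struct.unpack("I", f.read(4))
    print("number of embedded testcases:", num_testcases)
```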
imperative/python/test/unit/jit/test_tracing.py (view file @ 4f6c5d8f)
```python
@@ -287,6 +287,16 @@ def test_dump_backward_graph():
    np.testing.assert_equal(results[1], dx0)


def test_dump_with_testcase():
    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        return exp(x)

    f(tensor(1.0))
    file = io.BytesIO()
    f.dump(file, input_data=["#rand(0, 255, 1)"])


@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_profiler(trace_mode):
    @trace(symbolic=trace_mode, profiling=True)
```
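The test exercises a single random spec; per the `input_data` docstring, several specs can be passed at once, and an entry starting with `@` points to a text file with one spec per line. A hypothetical combination (the variable name, image path, and spec file are illustrative):

```python
f.dump(
    file,
    input_data=[
        "#rand(0, 255, 1)",  # one random testcase
        "data:./img0.jpg",   # image file for an input var named "data" (hypothetical)
        "@cases.txt",        # more specs, one per line (hypothetical file)
    ],
)
```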
sdk/load-and-run/dump_with_testcase_mge.py — deleted (file mode 100755 → 0), view file @ 182ca25d
```python
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import os
import re
import struct

import cv2
import numpy as np

import megengine as mge
import megengine.core._imperative_rt as rt
import megengine.core.tensor.megbrain_graph as G
from megengine import tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.utils import comp_graph_tools as cgtools

logger = mge.get_logger(__name__)


def auto_reformat_image(args, path, data, dst_shape):
    """reformat image to target shape

    :param data: image data as numpy array
    :param dst_shape: target shape
    """
    dim3_format = False  # required input format does not contain batch
    hwc_format = False  # required input format is NHWC

    if not dst_shape:  # input tensor shape is not predefined
        if len(data.shape) == 2:
            chl = 1
            h = data.shape[0]
            w = data.shape[1]
        else:
            assert len(data.shape) == 3, "Input image must be of dimension 2 or 3"
            h, w, chl = data.shape
        dst_shape = (1, chl, h, w)

    if len(dst_shape) == 3:
        dst_shape = (1,) + dst_shape
        dim3_format = True

    assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape)
    chl = dst_shape[1]
    if chl in [1, 3]:
        n, c, h, w = dst_shape
        dst_shape = (n, h, w, c)
    else:
        chl = dst_shape[3]
        assert chl in [1, 3], "can not infer input format from shape: {}".format(
            dst_shape
        )
        hwc_format = True

    # dst_shape has now been normalized to NHWC format

    if args.resize_input:
        h, w = dst_shape[1:3]
        data = cv2.resize(data, (w, h))
        logger.info("input {} resized to {}".format(path, data.shape))

    if chl == 1:
        data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY)
        data = data[:, :, np.newaxis]

    assert data.ndim == 3
    data = data[np.newaxis]
    # data normalized to NHWC format

    if not hwc_format:
        data = np.transpose(data, (0, 3, 1, 2))

    if dim3_format:
        data = np.squeeze(data, 0)

    return data


def read_input_data(args, dst_shape, dtype, path, repeat):
    def check_shape_equal(dst_shape, data_shape):
        if len(dst_shape):
            assert len(data_shape) == len(
                dst_shape
            ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape)

            if data_shape[1:] != dst_shape[1:]:
                logger.warning(
                    "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape)
                )

    if path.startswith("#"):
        assert not args.resize_input
        assert not args.input_transform
        spec = path
        m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec)
        assert m, "bad spec {}".format(spec)

        rng_min = float(m.group(1))
        rng_max = float(m.group(2))
        if m.group(3):
            shape_str = m.group(3)
            try:
                shape = shape_str[1:].split(",")
                if shape[-1].strip() == "...":
                    shape = shape[:-1]
                    shape.extend(list(dst_shape[len(shape) :]))
                data_shape = tuple(map(int, shape))
            except ValueError as e:
                raise ValueError("bad spec {}: {}".format(spec, e.args))
        else:
            data_shape = dst_shape

        check_shape_equal(dst_shape, data_shape)
        return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype)

    # try to load image
    data = cv2.imread(path, cv2.IMREAD_COLOR)
    if data is None:
        assert not args.resize_input
        data = np.load(path)
        assert isinstance(data, np.ndarray)
    else:
        # load image succeeds, so we expect input format is image format
        data = auto_reformat_image(args, path, data, dst_shape)

    data = np.repeat(data, repeat, axis=0)
    if repeat > 1:
        logger.info(
            "repeat input for {} times, data shape is {}".format(repeat, data.shape)
        )

    check_shape_equal(dst_shape, data.shape)

    if args.input_transform:
        data = eval(args.input_transform, {"data": data, "np": np})

    return data


def gen_one_testcase(args, inputs, spec):
    paths = spec.split(";")
    if len(paths) != len(inputs):
        if len(paths) == 1 and paths[0].startswith("#"):
            paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()]
    assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format(
        inputs.keys(), paths
    )
    if len(paths) == 1 and ":" not in paths[0]:
        paths[0] = next(iter(inputs.keys())) + ":" + paths[0]

    ret = {}
    for path in paths:
        var, path = path.split(":")
        if args.repeat:
            repeat = args.repeat
        else:
            repeat = 1
        ret[var] = read_input_data(
            args, inputs[var].shape, inputs[var].dtype, path, repeat
        )
    return ret


def make_feeds(args):
    ret = G.load_graph(args.input)
    cg_rt, outputs = ret.graph, ret.output_vars_list
    inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy")

    inputs = {i.name: i for i in inputs}
    if not args.no_assert:

        replace_varmap = {}
        inp_map = {}
        # replace var use InputNode
        for name, var in inputs.items():
            inp = G.InputNode(
                device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt
            )
            replace_varmap[var] = inp.outputs[0]
            inp_map[name] = inp

        new = cgtools.replace_vars(outputs, replace_varmap)
        if isinstance(new, rt.VarNode):
            new = list(new)

        output_nodes = [G.OutputNode(var) for var in new]
        func = cg_rt.compile([node.outputs[0] for node in output_nodes])

        def make_dev_tensor(value, dtype=None, device=None):
            return tensor(value, dtype=dtype, device=device)._dev_tensor()

        def calculate(*args, **kwargs):
            output_val = []
            # set inputs value
            for name, var in inputs.items():
                val = kwargs.pop(name, None)
                assert val is not None, "miss input name{}".format(name)
                dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux")
                inp_map[name].set_value(dev_tensor)

            func.execute()

            for res in output_nodes:
                output_val.append(res.get_value().numpy())
            return output_val

        def expect_name(var):
            return "{}:expect".format(var.name)

    testcases = []

    np.set_printoptions(precision=2, threshold=4, suppress=True)

    data_list = []
    for item in args.data:
        if item.startswith("@"):
            with open(item[1:], "r") as f:
                data_list.extend([line.rstrip() for line in f if line.rstrip() != ""])
        else:
            data_list.append(item)

    for inp_spec in data_list:
        cur_testcase = gen_one_testcase(args, inputs, inp_spec)
        assert len(cur_testcase) == len(
            inputs
        ), "required inputs: {}; given data: {}".format(
            inputs.keys(), cur_testcase.keys()
        )

        if not args.no_assert:
            outputs_get = calculate(**cur_testcase)
            for var, val in zip(outputs, outputs_get):
                cur_testcase[expect_name(var)] = val
                logger.info(
                    "generate test groundtruth: var={} shape={} range=({}, {})"
                    " mean={} var={}".format(
                        var, val.shape, val.min(), val.max(), np.mean(val), np.var(val)
                    )
                )
        testcases.append(cur_testcase)
        logger.info(
            "add testcase: \n {}".format(
                "\n ".join(
                    "{}: shape={} dtype={} range=({:.2f},{:.2f}) "
                    "mean={:.2f} sd={:.2f}".format(
                        k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v)
                    )
                    for k, v in sorted(cur_testcase.items())
                )
            )
        )

    if not args.no_assert:

        def expect_shp(var):
            ret = var.shape
            if ret:
                return ret
            return testcases[0][expect_name(var)].shape

        def assert_equal(expect, real, **kwargs):
            op = builtin.AssertEqual(**kwargs)
            (res,) = apply(op, expect, real)
            return res

        verbose = not args.silent

        outputs_new = []
        for i in outputs:
            device = rt.CompNode("xpux")
            dtype = i.dtype
            name = expect_name(i)
            shape = expect_shp(i)
            # make expect output as one input of model.
            expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name)
            # insert assert opr to check expect and real.
            outputs_new.append(
                assert_equal(expect_get, i, verbose=verbose, maxerr=args.maxerr,)
            )

            inputs[expect_name(i)] = expect_get
        outputs = outputs_new

    return {"outputs": outputs, "testcases": testcases}


def optimize_for_inference(args, outputs):
    args_list = [
        "enable_io16xc32",
        "enable_ioc16",
        "enable_hwcd4",
        "enable_nchw4",
        "enable_nchw88",
        "enable_nchw44",
        "enable_nchw44_dot",
        "enable_nchw32",
        "enable_chwn4",
        "enable_fuse_conv_bias_nonlinearity",
        "enable_fuse_conv_bias_with_z",
        "enable_fuse_preprocess",
    ]
    kwargs = {}
    for k in args_list:
        if getattr(args, k):
            kwargs[k] = True

    if args.optimize_for_inference:
        outputs = G.optimize_for_inference(outputs, **kwargs)

    return outputs


def main():
    parser = argparse.ArgumentParser(
        description="Pack computing graph, input values and expected output "
        "values into one file for checking correctness. README.md gives more "
        "details on the usage",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("input", help="MegEngine dumped model file")
    parser.add_argument("-o", "--output", help="output file", required=True)
    parser.add_argument(
        "-d",
        "--data",
        default=[],
        action="append",
        required=True,
        help="Given input test data when input file is a network, "
        "and current network output would be used as groundtruth. "
        "The format is var0:file0;var1:file1... to specify data files for "
        "input vars. It can also be #rand(min,max,shape...) for generating "
        "random input data, for example, #rand(0,255), "
        "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means "
        "the remaining part of the original shape. "
        "If the shape is not specified, the shape of "
        "corresponding input tensors in the network will be used. "
        "If there is only one input var, its name can be omitted. "
        "Each data file can either be an image which can be loaded by opencv, "
        "or a pickled numpy.ndarray. "
        "This option can be given multiple times to add multiple testcases. "
        " *NOTE* "
        "If you start the data with the letter @, the rest should be a "
        "filename, and each line in the file should be a single datum in "
        "the format described above. ",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="Specify how many times the input image is repeated. "
        "Useful when running benchmark for batch size other than one. "
        "Has no effect on randomly generated input data.",
    )
    parser.add_argument(
        "--silent",
        action="store_true",
        help="set verbose to False in assert_equal opr",
    )
    parser.add_argument(
        "--optimize-for-inference",
        action="store_true",
        help="enable optimization for inference",
    )
    parser.add_argument(
        "--no-assert",
        action="store_true",
        help="do not insert assert_equal opr to check result; "
        "this option is useful for benchmarking",
    )
    parser.add_argument(
        "--maxerr",
        type=float,
        default=1e-4,
        help="max error for assert_equal check during runtime",
    )
    parser.add_argument(
        "--resize-input",
        action="store_true",
        help="resize input image to fit input var shape",
    )
    parser.add_argument(
        "--input-transform",
        help="a python expression to transform the input data. "
        "Example: data / np.std(data)",
    )
    parser.add_argument(
        "--discard-var-name",
        action="store_true",
        help="discard variable and param names in the generated output",
    )
    parser.add_argument(
        "--output-strip-info",
        action="store_true",
        help="output code strip information",
    )
    parser.add_argument(
        "--enable-io16xc32",
        action="store_true",
        help="transform the mode to float16 io float32 compute",
    )
    parser.add_argument(
        "--enable-ioc16",
        action="store_true",
        help="transform the dtype of the model to float16 io and compute",
    )
    parser.add_argument(
        "--enable-fuse-conv-bias-nonlinearity",
        action="store_true",
        help="fuse convolution bias and nonlinearity opr to a "
        "conv_bias opr and compute",
    )
    parser.add_argument(
        "--enable-hwcd4",
        action="store_true",
        help="transform the model format from NCHW to NHWCD4 "
        "for inference; you may need to disable CUDA and set "
        "MGB_USE_MEGDNN_DBG=2",
    )
    parser.add_argument(
        "--enable-nchw4",
        action="store_true",
        help="transform the model format from NCHW to NCHW4 for inference",
    )
    parser.add_argument(
        "--enable-nchw88",
        action="store_true",
        help="transform the model format from NCHW to NCHW88 for inference",
    )
    parser.add_argument(
        "--enable-nchw44",
        action="store_true",
        help="transform the model format from NCHW to NCHW44 for inference",
    )
    parser.add_argument(
        "--enable-nchw44-dot",
        action="store_true",
        help="transform the model format from NCHW to NCHW44_DOT "
        "for optimizing armv8.2 dot in inference",
    )
    parser.add_argument(
        "--enable-nchw32",
        action="store_true",
        help="transform the model format from NCHW4 to NCHW32 "
        "for inference on nvidia TensorCore",
    )
    parser.add_argument(
        "--enable-chwn4",
        action="store_true",
        help="transform the model format to CHWN4 "
        "for inference, mainly used for nvidia tensorcore",
    )
    parser.add_argument(
        "--enable-fuse-conv-bias-with-z",
        action="store_true",
        help="fuse conv_bias with z input for inference on "
        "nvidia GPU (this optimization pass will result in mismatch "
        "of the precision of output of training and inference)",
    )
    parser.add_argument(
        "--enable-fuse-preprocess",
        action="store_true",
        help="fuse astype\pad_channel\dimshuffle etc. oprs from h2d opr",
    )
    args = parser.parse_args()

    feeds = make_feeds(args)

    assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty"

    output_mgbvars = feeds["outputs"]
    output_mgbvars = optimize_for_inference(args, output_mgbvars)

    inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
    inputs = sorted((i.name, i.dtype) for i in inputs)

    if args.discard_var_name:
        sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
    else:
        sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)

    strip_info_file = args.output + ".json" if args.output_strip_info else None

    with open(args.output, "wb") as fout:
        fout.write(b"mgbtest0")
        fout.write(struct.pack("I", len(feeds["testcases"])))
        dump_content, stat = G.dump_graph(
            output_mgbvars,
            append_json=True,
            strip_info_file=strip_info_file,
            **sereg_kwargs,
        )
        fout.write(dump_content)

    logger.info(
        "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format(
            stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024
        )
    )

    def make_dev_tensor(value, dtype=None, device=None):
        return tensor(value, dtype=dtype, device=device)._dev_tensor()

    for testcase in feeds["testcases"]:
        assert isinstance(testcase, dict)
        cg = G.Graph()
        output_mgbvars = []
        for name, dtype in inputs:
            output_mgbvars.append(
                cg.make_const(
                    make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux")
                )
            )
        assert not testcase, "extra inputs provided in testcase: {}".format(
            testcase.keys()
        )

        with open(args.output, "ab") as fout:
            dump_content, _ = G.dump_graph(
                output_mgbvars, strip_info_file=strip_info_file, append_json=True
            )
            fout.write(dump_content)


if __name__ == "__main__":
    main()
```