Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
毕竟曾有刹那
Mace
提交
710e319f
Mace
项目概览
毕竟曾有刹那
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
710e319f
编写于
8月 03, 2018
作者:
L
Liangliang He
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support customized similarity threshold for model validation
上级
4a95917c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
59 addition
and
16 deletion
+59
-16
docs/installation/env_requirement.rst
docs/installation/env_requirement.rst
+3
-0
docs/user_guide/advanced_usage.rst
docs/user_guide/advanced_usage.rst
+2
-0
tools/converter.py
tools/converter.py
+28
-1
tools/sh_commands.py
tools/sh_commands.py
+7
-3
tools/validate.py
tools/validate.py
+19
-12
未找到文件。
docs/installation/env_requirement.rst
浏览文件 @
710e319f
...
...
@@ -33,6 +33,9 @@ Required dependencies
* - Numpy
- pip install -I numpy==1.14.0
- Required by model validation
* - six
- pip install -I six==1.11.0
- Required for Python 2 and 3 compatibility (TODO)
Optional dependencies
---------------------
...
...
docs/user_guide/advanced_usage.rst
浏览文件 @
710e319f
...
...
@@ -76,6 +76,8 @@ in one deployment file.
- The numerical range of the input tensors' data, default [-1, 1]. It is only for test.
* - validation_inputs_data
- [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used.
* - validation_threshold
- [optional] Specify the similarity threshold for validation. A dict with key in 'CPU', 'GPU' and/or 'HEXAGON' and value <= 1.0.
* - runtime
- The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
* - data_type
...
...
tools/converter.py
浏览文件 @
710e319f
...
...
@@ -19,6 +19,7 @@ import os
import
re
import
sh
import
subprocess
import
six
import
sys
import
urllib
import
yaml
...
...
@@ -189,6 +190,7 @@ class YAMLKeyword(object):
quantize
=
'quantize'
quantize_range_file
=
'quantize_range_file'
validation_inputs_data
=
'validation_inputs_data'
validation_threshold
=
'validation_threshold'
graph_optimize_options
=
'graph_optimize_options'
# internal use for now
...
...
@@ -444,6 +446,30 @@ def format_model_config(flags):
"'%s' is necessary in subgraph"
%
key
)
if
not
isinstance
(
value
,
list
):
subgraph
[
key
]
=
[
value
]
validation_threshold
=
subgraph
.
get
(
YAMLKeyword
.
validation_threshold
,
{})
if
not
isinstance
(
validation_threshold
,
dict
):
raise
argparse
.
ArgumentTypeError
(
'similarity threshold must be a dict.'
)
threshold_dict
=
{
DeviceType
.
CPU
:
0.999
,
DeviceType
.
GPU
:
0.995
,
DeviceType
.
HEXAGON
:
0.930
,
}
for
k
,
v
in
six
.
iteritems
(
validation_threshold
):
if
k
.
upper
()
==
'DSP'
:
k
=
DeviceType
.
HEXAGON
if
k
.
upper
()
not
in
(
DeviceType
.
CPU
,
DeviceType
.
GPU
,
DeviceType
.
HEXAGON
):
raise
argparse
.
ArgumentTypeError
(
'Unsupported validation threshold runtime: %s'
%
k
)
threshold_dict
[
k
.
upper
()]
=
v
subgraph
[
YAMLKeyword
.
validation_threshold
]
=
threshold_dict
validation_inputs_data
=
subgraph
.
get
(
YAMLKeyword
.
validation_inputs_data
,
[])
if
not
isinstance
(
validation_inputs_data
,
list
):
...
...
@@ -1202,7 +1228,8 @@ def run_specific_target(flags, configs, target_abi,
output_shapes
=
subgraphs
[
0
][
YAMLKeyword
.
output_shapes
],
model_output_dir
=
model_output_dir
,
phone_data_dir
=
PHONE_DATA_DIR
,
caffe_env
=
flags
.
caffe_env
)
caffe_env
=
flags
.
caffe_env
,
validation_threshold
=
subgraphs
[
0
][
YAMLKeyword
.
validation_threshold
][
device_type
])
# noqa
if
flags
.
report
and
flags
.
round
>
0
:
tuned
=
is_tuned
and
device_type
==
DeviceType
.
GPU
report_run_statistics
(
...
...
tools/sh_commands.py
浏览文件 @
710e319f
...
...
@@ -799,7 +799,8 @@ def validate_model(abi,
phone_data_dir
,
caffe_env
,
input_file_name
=
"model_input"
,
output_file_name
=
"model_out"
):
output_file_name
=
"model_out"
,
validation_threshold
=
0.9
):
print
(
"* Validate with %s"
%
platform
)
if
abi
!=
"host"
:
for
output_name
in
output_nodes
:
...
...
@@ -816,7 +817,8 @@ def validate_model(abi,
"%s/%s"
%
(
model_output_dir
,
input_file_name
),
"%s/%s"
%
(
model_output_dir
,
output_file_name
),
device_type
,
":"
.
join
(
input_shapes
),
":"
.
join
(
output_shapes
),
","
.
join
(
input_nodes
),
","
.
join
(
output_nodes
))
","
.
join
(
input_nodes
),
","
.
join
(
output_nodes
),
validation_threshold
)
elif
platform
==
"caffe"
:
image_name
=
"mace-caffe:latest"
container_name
=
"mace_caffe_validator"
...
...
@@ -832,7 +834,8 @@ def validate_model(abi,
"%s/%s"
%
(
model_output_dir
,
output_file_name
),
device_type
,
":"
.
join
(
input_shapes
),
":"
.
join
(
output_shapes
),
","
.
join
(
input_nodes
),
","
.
join
(
output_nodes
))
","
.
join
(
input_nodes
),
","
.
join
(
output_nodes
),
validation_threshold
)
elif
caffe_env
==
common
.
CaffeEnvType
.
DOCKER
:
docker_image_id
=
sh
.
docker
(
"images"
,
"-q"
,
image_name
)
if
not
docker_image_id
:
...
...
@@ -896,6 +899,7 @@ def validate_model(abi,
"--output_node=%s"
%
","
.
join
(
output_nodes
),
"--input_shape=%s"
%
":"
.
join
(
input_shapes
),
"--output_shape=%s"
%
":"
.
join
(
output_shapes
),
"--validation_threshold=%f"
%
validation_threshold
,
_fg
=
True
)
print
(
"Validation done!
\n
"
)
...
...
tools/validate.py
浏览文件 @
710e319f
...
...
@@ -35,6 +35,7 @@ import common
# --output_node output_node \
# --input_shape 1,64,64,3 \
# --output_shape 1,64,64,2
# --validation_threshold 0.995
VALIDATION_MODULE
=
'VALIDATION'
...
...
@@ -47,7 +48,7 @@ def load_data(file):
def
compare_output
(
platform
,
device_type
,
output_name
,
mace_out_value
,
out_value
):
out_value
,
validation_threshold
):
if
mace_out_value
.
size
!=
0
:
out_value
=
out_value
.
reshape
(
-
1
)
mace_out_value
=
mace_out_value
.
reshape
(
-
1
)
...
...
@@ -56,9 +57,7 @@ def compare_output(platform, device_type, output_name, mace_out_value,
common
.
MaceLogger
.
summary
(
output_name
+
' MACE VS '
+
platform
.
upper
()
+
' similarity: '
+
str
(
similarity
))
if
(
device_type
==
"CPU"
and
similarity
>
0.999
)
or
\
(
device_type
==
"GPU"
and
similarity
>
0.995
)
or
\
(
device_type
==
"HEXAGON"
and
similarity
>
0.930
):
if
similarity
>
validation_threshold
:
common
.
MaceLogger
.
summary
(
common
.
StringFormatter
.
block
(
"Similarity Test Passed"
))
else
:
...
...
@@ -78,7 +77,8 @@ def normalize_tf_tensor_name(name):
def
validate_tf_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
output_names
):
mace_out_file
,
input_names
,
input_shapes
,
output_names
,
validation_threshold
):
import
tensorflow
as
tf
if
not
os
.
path
.
isfile
(
model_file
):
common
.
MaceLogger
.
error
(
...
...
@@ -115,12 +115,13 @@ def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file
,
output_names
[
i
])
mace_out_value
=
load_data
(
output_file_name
)
compare_output
(
platform
,
device_type
,
output_names
[
i
],
mace_out_value
,
output_values
[
i
])
mace_out_value
,
output_values
[
i
],
validation_threshold
)
def
validate_caffe_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
weight_file
,
input_names
,
input_shapes
,
output_names
,
output_shapes
):
output_names
,
output_shapes
,
validation_threshold
):
os
.
environ
[
'GLOG_minloglevel'
]
=
'1'
# suprress Caffe verbose prints
import
caffe
if
not
os
.
path
.
isfile
(
model_file
):
...
...
@@ -162,11 +163,12 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file
,
output_names
[
i
])
mace_out_value
=
load_data
(
output_file_name
)
compare_output
(
platform
,
device_type
,
output_names
[
i
],
mace_out_value
,
value
)
value
,
validation_threshold
)
def
validate
(
platform
,
model_file
,
weight_file
,
input_file
,
mace_out_file
,
device_type
,
input_shape
,
output_shape
,
input_node
,
output_node
):
device_type
,
input_shape
,
output_shape
,
input_node
,
output_node
,
validation_threshold
):
input_names
=
[
name
for
name
in
input_node
.
split
(
','
)]
input_shape_strs
=
[
shape
for
shape
in
input_shape
.
split
(
':'
)]
input_shapes
=
[[
int
(
x
)
for
x
in
shape
.
split
(
','
)]
...
...
@@ -177,14 +179,15 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
if
platform
==
'tensorflow'
:
validate_tf_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
output_names
)
output_names
,
validation_threshold
)
elif
platform
==
'caffe'
:
output_shape_strs
=
[
shape
for
shape
in
output_shape
.
split
(
':'
)]
output_shapes
=
[[
int
(
x
)
for
x
in
shape
.
split
(
','
)]
for
shape
in
output_shape_strs
]
validate_caffe_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
weight_file
,
input_names
,
input_shapes
,
output_names
,
output_shapes
)
input_shapes
,
output_names
,
output_shapes
,
validation_threshold
)
def
parse_args
():
...
...
@@ -219,6 +222,9 @@ def parse_args():
"--input_node"
,
type
=
str
,
default
=
"input_node"
,
help
=
"input node"
)
parser
.
add_argument
(
"--output_node"
,
type
=
str
,
default
=
"output_node"
,
help
=
"output node"
)
parser
.
add_argument
(
"--validation_threshold"
,
type
=
float
,
default
=
0.995
,
help
=
"validation similarity threshold"
)
return
parser
.
parse_known_args
()
...
...
@@ -234,4 +240,5 @@ if __name__ == '__main__':
FLAGS
.
input_shape
,
FLAGS
.
output_shape
,
FLAGS
.
input_node
,
FLAGS
.
output_node
)
FLAGS
.
output_node
,
FLAGS
.
validation_threshold
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录