Mace (forked from Xiaomi / Mace)
Commit 8173b231
Authored Jan 05, 2019 by 李寅 (Li Yin)

Merge branch 'hexagon' into 'master'

Support layers validate. See merge request !946

Parents: 22e40d66, 300aacbb
Showing 11 changed files with 426 additions and 110 deletions (+426 −110):

    .gitlab-ci.yml                                           +2    -0
    docs/development/how_to_debug.rst                        +18   -1
    mace/python/tools/BUILD                                  +12   -0
    mace/python/tools/converter_tool/hexagon_converter.py    +41   -17
    mace/python/tools/layers_validate.py                     +171  -0
    mace/python/tools/model_saver.py                         +2    -0
    tools/common.py                                          +15   -5
    tools/converter.py                                       +5    -0
    tools/device.py                                          +125  -71
    tools/sh_commands.py                                     +9    -4
    tools/validate.py                                        +26   -12
.gitlab-ci.yml

@@ -147,6 +147,7 @@ python_tools_tests:
     python tools/converter.py convert --config=${CONF_FILE} --model_graph_format=file --model_data_format=file || exit 1;
     python tools/converter.py run --config=${CONF_FILE} --device_yml=${DEVICE_CONF_FILE} --round=1 --target_abis=armeabi-v7a,armhf --validate --model_graph_format=file --model_data_format=file || exit 1;
     python tools/converter.py run --config=${CONF_FILE} --device_yml=${DEVICE_CONF_FILE} --example --target_abis=armeabi-v7a,armhf --round=1 --validate --model_graph_format=file --model_data_format=file || exit 1;
+    python tools/converter.py run --config=${CONF_FILE} --device_yml=${DEVICE_CONF_FILE} --example --target_abis=armeabi-v7a,armhf --round=1 --validate_all_layers --model_graph_format=file --model_data_format=file || exit 1;

 model_tests:
   stage: model_tests
@@ -189,6 +190,7 @@ quantization_tests:
     python tools/converter.py convert --config=${CONF_FILE} --model_graph_format=file --model_data_format=file || exit 1;
     python tools/converter.py run --config=${CONF_FILE} --device_yml=${DEVICE_CONF_FILE} --round=1 --validate --model_graph_format=file --model_data_format=file || exit 1;
     python tools/converter.py run --config=${CONF_FILE} --device_yml=${DEVICE_CONF_FILE} --example --round=1 --validate --model_graph_format=file --model_data_format=file || exit 1;
+    python tools/converter.py run --config=${CONF_FILE} --device_yml=${DEVICE_CONF_FILE} --example --round=1 --validate_all_layers --model_graph_format=file --model_data_format=file || exit 1;
     done
   - rm -rf mace-models
docs/development/how_to_debug.rst

@@ -34,6 +34,14 @@ It is usually used to measure classification accuracy. The higher the better.
 where :math:`X` is expected output (from training platform) whereas :math:`X'` is actual output (from MACE).

+You can validate it by specifying `--validate` while running the model.
+
+.. code:: sh
+
+    # Validate the correctness by comparing the results against the
+    # original model and framework
+    python tools/converter.py run --config=/path/to/your/model_deployment_file.yml --validate
+
 MACE automatically validate these metrics by running models with synthetic inputs.
 If you want to specify input data to use, you can add an option in yaml config under 'subgraphs', e.g.,
@@ -53,13 +61,22 @@ If you want to specify input data to use, you can add an option in yaml config u
         - MobilenetV1/Predictions/Reshape_1
       output_shapes:
         - 1,1001
+      check_tensors:
+        - MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd:0
+      check_shapes:
+        - 1,1,1,1001
       validation_inputs_data:
         - https://cnbj1.fds.api.xiaomi.com/mace/inputs/dog.npy

 If model's output is suspected to be incorrect, it might be useful to debug your model layer by layer by specifying an intermediate layer as output,
 or use binary search method until suspicious layer is found.
+You can also specify `--validate_all_layers` to validate all the layers of the model (excluding some layers changed by MACE, e.g., BatchToSpaceND);
+it only supports TensorFlow now. You can find validation results in `builds/your_model/model/runtime_in_yaml/log.csv`.
+For quantized model, if you want to check one layer, you can add `check_tensors` and `check_shapes` like in the yaml above. You can only specify
+MACE op's output.

 Debug memory usage
 --------------------------
mace/python/tools/BUILD

@@ -55,3 +55,15 @@ py_binary(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
 )
+
+py_binary(
+    name = "layers_validate",
+    srcs = ["layers_validate.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":converter_lib",
+        ":model_saver_lib",
+    ],
+)
mace/python/tools/converter_tool/hexagon_converter.py

@@ -14,6 +14,7 @@
 import copy
 import numpy as np
+from enum import Enum
 from operator import mul

 from mace.proto import mace_pb2
@@ -28,22 +29,44 @@ from mace.python.tools.convert_util import mace_check
 from mace.python.tools import graph_util

+HexagonSupportedOps = [
+    'BatchToSpaceND_8',
+    'DepthwiseSupernode_8x8p32to8',
+    'DequantizeOUTPUT_8tof',
+    'QuantizedAdd_8p8to8',
+    'QuantizedAvgPool_8',
+    'QuantizedConcat_8',
+    'QuantizedMaxPool_8',
+    'QuantizedResizeBilinear_8',
+    'QuantizedSoftmax_8',
+    'QuantizeINPUT_f_to_8',
+    'SpaceToBatchND_8',
+    'Supernode_8x8p32to8',
+    'Nop',
+]
+
+HexagonOp = Enum('HexagonOp', [(op, op) for op in HexagonSupportedOps],
+                 type=str)
+

 class HexagonOps(object):
     def __init__(self):
         self.hexagon_ops = {
-            'Quantize': 'QuantizeINPUT_f_to_8',
-            'Dequantize': 'DequantizeOUTPUT_8tof',
-            'Concat': 'QuantizedConcat_8',
-            'Conv2D': 'Supernode_8x8p32to8',
-            'DepthwiseConv2d': 'DepthwiseSupernode_8x8p32to8',
-            'ResizeBilinear': 'QuantizedResizeBilinear_8',
-            'SpaceToBatchND': 'SpaceToBatchND_8',
-            'BatchToSpaceND': 'BatchToSpaceND_8',
-            'Softmax': 'QuantizedSoftmax_8',
-            'Eltwise': 'Eltwise',
-            'Pooling': 'Pooling',
-            'Identity': 'Nop',
-            'Squeeze': 'Nop',
+            MaceOp.BatchToSpaceND.name: HexagonOp.BatchToSpaceND_8.name,
+            MaceOp.Concat.name: HexagonOp.QuantizedConcat_8.name,
+            MaceOp.Conv2D.name: HexagonOp.Supernode_8x8p32to8.name,
+            MaceOp.DepthwiseConv2d.name:
+                HexagonOp.DepthwiseSupernode_8x8p32to8.name,
+            MaceOp.Dequantize.name: HexagonOp.DequantizeOUTPUT_8tof.name,
+            MaceOp.Eltwise.name: [HexagonOp.QuantizedAdd_8p8to8],
+            MaceOp.Identity.name: HexagonOp.Nop.name,
+            MaceOp.Quantize.name: HexagonOp.QuantizeINPUT_f_to_8.name,
+            MaceOp.Pooling.name: [HexagonOp.QuantizedAvgPool_8.name,
+                                  HexagonOp.QuantizedMaxPool_8.name],
+            MaceOp.ResizeBilinear.name:
+                HexagonOp.QuantizedResizeBilinear_8.name,
+            MaceOp.SpaceToBatchND.name: HexagonOp.SpaceToBatchND_8.name,
+            MaceOp.Softmax.name: HexagonOp.QuantizedSoftmax_8.name,
         }

     def has_op(self, tf_op):
@@ -116,7 +139,6 @@ class HexagonConverter(base_converter.ConverterInterface):
         for op in self._model.op:
             if not self._hexagon_ops.has_op(op.type):
                 raise Exception('Unsupported op: ', op)
-            print('Op: %s (%s)' % (op.name, op.type))
             for i in range(len(op.input)):
                 if ':' not in op.input[i]:
                     node_name = op.input[i]
@@ -250,14 +272,14 @@ class HexagonConverter(base_converter.ConverterInterface):
                     and ConverterUtil.get_arg(
                         op, MaceKeyword.mace_element_type_str).i
                         == EltwiseType.SUM.value):
-                op.type = 'QuantizedAdd_8p8to8'
+                op.type = HexagonOp.QuantizedAdd_8p8to8.name
             elif op.type == MaceOp.Pooling.name:
                 pooling_type_arg = ConverterUtil.get_arg(
                     op, MaceKeyword.mace_pooling_type_str)
                 if PoolingType(pooling_type_arg.i) == PoolingType.AVG:
-                    op.type = 'QuantizedAvgPool_8'
+                    op.type = HexagonOp.QuantizedAvgPool_8.name
                 else:
-                    op.type = 'QuantizedMaxPool_8'
+                    op.type = HexagonOp.QuantizedMaxPool_8.name
             else:
                 op.type = self._hexagon_ops.map_nn_op(op.type)
@@ -342,8 +364,10 @@ class HexagonConverter(base_converter.ConverterInterface):
             tensor_op, port = get_op_and_port_from_tensor(tensor.name)
             node_id_map[tensor_op] = tensor.node_id

+        print("Hexagon op:")
         for op in self._model.op:
             op.node_id = node_id_counter
+            print('Op: %s (%s, %d)' % (op.name, op.type, op.node_id))
             node_id_counter += 1
             node_id_map[op.name] = op.node_id
             for ipt in op.input:
...
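The converter's `HexagonOp` uses the functional `Enum` API with a `str` mixin, so every member is simultaneously an enum member and its own string value; that lets the converter assign members' names into plain-string protobuf `op.type` fields and compare against them. A minimal standalone sketch of the pattern (the op names here are a subset of `HexagonSupportedOps`):

```python
from enum import Enum

# Functional Enum API: each (name, value) pair comes from the list, and
# mixing in `type=str` makes every member an instance of str as well.
supported_ops = ['QuantizeINPUT_f_to_8', 'Supernode_8x8p32to8', 'Nop']
HexagonOp = Enum('HexagonOp', [(op, op) for op in supported_ops], type=str)

# Members compare equal to plain strings, so protobuf string fields can
# be matched without any conversion.
print(HexagonOp.Nop.name)                   # Nop
print(HexagonOp.Nop == 'Nop')               # True
print(isinstance(HexagonOp.Supernode_8x8p32to8, str))  # True
```

This is why the diff can replace literals like `'QuantizedAvgPool_8'` with `HexagonOp.QuantizedAvgPool_8.name`: the behavior is unchanged, but typos become `AttributeError`s at import time instead of silent mismatches.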
mace/python/tools/layers_validate.py (new file, 0 → 100644)

# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy
import os
import sys
import yaml

from mace.proto import mace_pb2
from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.hexagon_converter import HexagonOp
from mace.python.tools.convert_util import mace_check
from mace.python.tools.model_saver import save_model_to_proto


def normalize_op_name(name):
    return name.replace('/', '_').replace(':', '_')


def main(unused_args):
    mace_check(os.path.isfile(FLAGS.model_file),
               "Input graph file '" + FLAGS.model_file + "' does not exist!")
    mace_check(os.path.isdir(FLAGS.output_dir),
               "Output directory '" + FLAGS.output_dir + "' does not exist!")
    net_def = mace_pb2.NetDef()
    with open(FLAGS.model_file, "rb") as f:
        net_def.ParseFromString(f.read())

    quantize_flag = ConverterUtil.get_arg(
        net_def, MaceKeyword.mace_quantize_flag_arg_str)
    quantize_flag = False if quantize_flag is None else quantize_flag.i == 1
    hexagon_flag = False
    index = 0
    end_index = len(net_def.op)
    if quantize_flag:
        while index < end_index:
            # omit op quantize
            if net_def.op[index].type == MaceOp.Quantize.name or \
                    net_def.op[index].type == \
                    HexagonOp.QuantizeINPUT_f_to_8.name:
                index += 1
            # omit op dequantize
            elif net_def.op[end_index - 1].type == MaceOp.Dequantize.name or \
                    net_def.op[end_index - 1].type == \
                    HexagonOp.DequantizeOUTPUT_8tof.name:
                end_index -= 1
            else:
                break
        mace_check(0 < index < end_index < len(net_def.op),
                   "Wrong number of op quantize(%d) or dequantize(%d)." %
                   (index, len(net_def.op) - end_index))
        if net_def.op[-1].type == HexagonOp.DequantizeOUTPUT_8tof.name:
            hexagon_flag = True
    # omit original output
    end_index -= 1

    data_format = net_def.output_info[0].data_format
    output_configs = {"subgraphs": []}
    while index < end_index:
        # omit BatchToSpaceND and op before that due to changed graph
        if net_def.op[index].type == MaceOp.BatchToSpaceND.name or \
                net_def.op[index].type == \
                HexagonOp.BatchToSpaceND_8.name or \
                (index + 1 < end_index and
                 (net_def.op[index + 1].type == MaceOp.BatchToSpaceND.name or
                  net_def.op[index + 1].type ==
                  HexagonOp.BatchToSpaceND_8.name)):  # noqa
            index += 1
            continue
        net = copy.deepcopy(net_def)
        if hexagon_flag:
            # reuse dequantize op and it's min/max tensor's node_id
            del net.op[index + 1:end_index + 1]
        else:
            del net.op[index + 1:]
        del net.output_info[:]
        op = net.op[index]
        index += 1

        output_tensors = []
        output_shapes = []
        op_name = op.name
        if quantize_flag:
            op.name = MaceKeyword.mace_output_node_name + '_' + op.name
        if hexagon_flag:
            mace_check(len(op.output) == 1,
                       "Only supports number of outputs of Hexagon op be 1.")
        for i in range(len(op.output)):
            output_tensors.append(str(op.output[i]))
            output_shapes.append(
                ",".join([str(dim) for dim in op.output_shape[i].dims]))
            # modify output info
            output_info = net.output_info.add()
            output_info.name = op.output[i]
            output_info.data_format = data_format
            output_info.dims.extend(op.output_shape[i].dims)
            output_info.data_type = mace_pb2.DT_FLOAT
            # modify output op
            if quantize_flag:
                output_name = op.output[i]
                new_output_name = \
                    MaceKeyword.mace_output_node_name + '_' + op.output[i]
                op.output[i] = new_output_name
                if not hexagon_flag:
                    dequantize_op = net.op.add()
                    dequantize_op.name = normalize_op_name(output_name)
                    dequantize_op.type = MaceOp.Dequantize.name
                    dequantize_op.input.append(new_output_name)
                    dequantize_op.output.append(output_name)
                    output_shape = dequantize_op.output_shape.add()
                    output_shape.dims.extend(op.output_shape[i].dims)
                    dequantize_op.output_type.append(mace_pb2.DT_FLOAT)
                    ConverterUtil.add_data_type_arg(dequantize_op,
                                                    mace_pb2.DT_UINT8)
                else:
                    dequantize_op = net.op[-1]
                    dequantize_op.name = normalize_op_name(output_name)
                    del dequantize_op.input[:]
                    del dequantize_op.output[:]
                    dequantize_op.input.append(new_output_name)
                    dequantize_op.output.append(output_name)
                    input_min = new_output_name[:-1] + '1'
                    input_max = new_output_name[:-1] + '2'
                    dequantize_op.input.extend([input_min, input_max])
                    dequantize_op.node_input[0].node_id = op.node_id
                    dequantize_op.node_input[1].node_id = op.node_id
                    dequantize_op.node_input[2].node_id = op.node_id

        model_path = save_model_to_proto(net, normalize_op_name(op_name),
                                         FLAGS.output_dir)
        output_config = {"model_file_path": str(model_path),
                         "output_tensors": output_tensors,
                         "output_shapes": output_shapes}
        output_configs["subgraphs"].append(output_config)

    output_configs_path = FLAGS.output_dir + "outputs.yml"
    with open(output_configs_path, "w") as f:
        yaml.dump(output_configs, f, default_flow_style=False)


def parse_args():
    """Parses command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_file",
        type=str,
        default="",
        help="pb file to load.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="Directory to save the output graph to.")
    return parser.parse_known_args()


if __name__ == '__main__':
    FLAGS, unparsed = parse_args()
    main(unused_args=[sys.argv[0]] + unparsed)
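Each sliced sub-model is saved under a file name derived from its op name, so the `normalize_op_name` helper flattens TensorFlow-style tensor names ('/' scopes and ':port' suffixes) into filesystem-safe tags. A quick sketch of what it produces, using an op name from the how_to_debug example above:

```python
def normalize_op_name(name):
    # Replace the scope separator '/' and port separator ':' so the op
    # name can be used directly as a protobuf file name for the sub-model.
    return name.replace('/', '_').replace(':', '_')

print(normalize_op_name('MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd:0'))
# MobilenetV1_Logits_Conv2d_1c_1x1_BiasAdd_0
```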
mace/python/tools/model_saver.py

@@ -190,6 +190,8 @@ def save_model_to_proto(net_def, model_tag, output_dir):
     with open(proto_file_path + '_txt', "w") as f:
         f.write(str(net_def))

+    return proto_file_path
+

 def save_model_to_code(net_def, model_tag, device,
                        template_dir, output_dir, embed_model_data,
...
tools/common.py

@@ -14,6 +14,7 @@
 import enum
 import hashlib
+import inspect
 import re
 import os
@@ -34,6 +35,12 @@ class CMDColors:
     UNDERLINE = '\033[4m'


+def get_frame_info(level=2):
+    caller_frame = inspect.stack()[level]
+    info = inspect.getframeinfo(caller_frame[0])
+    return info.filename + ':' + str(info.lineno) + ': '
+
+
 class MaceLogger:
     @staticmethod
     def header(message):
@@ -45,22 +52,25 @@ class MaceLogger:
     @staticmethod
     def info(message):
-        six.print_(message)
+        six.print_(get_frame_info() + message)

     @staticmethod
     def warning(message):
-        six.print_(CMDColors.YELLOW + 'WARNING:' + message + CMDColors.ENDC)
+        six.print_(CMDColors.YELLOW + 'WARNING:' + get_frame_info()
+                   + message + CMDColors.ENDC)

     @staticmethod
-    def error(module, message):
+    def error(module, message, location_info=""):
+        if not location_info:
+            location_info = get_frame_info()
         six.print_(CMDColors.RED + 'ERROR: [' + module + '] '
-                   + message + CMDColors.ENDC)
+                   + location_info + message + CMDColors.ENDC)
         exit(1)


 def mace_check(condition, module, message):
     if not condition:
-        MaceLogger.error(module, message)
+        MaceLogger.error(module, message, get_frame_info())


 ################################
...
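The `get_frame_info` helper added above walks up the interpreter stack with `inspect` so log lines carry the caller's file and line number; `level=2` skips the helper itself and the logger method to land on the user's call site. A standalone sketch of the same idea (the `log` wrapper is a stand-in for `MaceLogger.info`):

```python
import inspect


def get_frame_info(level=2):
    # stack()[0] is this frame, [1] the logger that called us, and [2]
    # the user's call site -- which is the location worth reporting.
    caller_frame = inspect.stack()[level]
    info = inspect.getframeinfo(caller_frame[0])
    return info.filename + ':' + str(info.lineno) + ': '


def log(message):
    # Two frames up from get_frame_info is whoever called log().
    print(get_frame_info() + message)


log("model converted")  # prints something like "/path/to/script.py:19: model converted"
```

Passing the pre-computed location into `MaceLogger.error` from `mace_check` (as the diff does) keeps the reported line at the failing check rather than inside the logger.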
tools/converter.py

@@ -1177,6 +1177,11 @@ def parse_args():
         action="store_true",
         help="whether to verify the results are consistent with "
              "the frameworks.")
+    run.add_argument(
+        "--validate_all_layers",
+        action="store_true",
+        help="whether to verify the results of all layers are "
+             "consistent with the frameworks.")
     run.add_argument(
         "--caffe_env",
         type=str_to_caffe_env_type,
...
tools/device.py

@@ -179,6 +179,7 @@ class DeviceWrapper:
                    address_sanitizer=False,
                    link_dynamic=False,
                    quantize_stat=False,
+                   layers_validate_file="",
                    ):
         six.print_("* Run '%s' with round=%s, restart_round=%s, tuning=%s, "
                    "out_of_range_check=%s, omp_num_threads=%s, "
@@ -189,7 +190,8 @@ class DeviceWrapper:
                     cpu_affinity_policy, gpu_perf_hint, gpu_priority_hint))
         mace_model_path = ""
         if model_graph_format == ModelFormat.file:
-            mace_model_path = "%s/%s.pb" % (mace_model_dir, model_tag)
+            mace_model_path = layers_validate_file if layers_validate_file \
+                else "%s/%s.pb" % (mace_model_dir, model_tag)
         if self.system == SystemType.host:
             libmace_dynamic_lib_path = \
                 os.path.dirname(libmace_dynamic_library_path)
@@ -413,6 +415,28 @@ class DeviceWrapper:
             six.print_('Tuning done!\n')

+    @staticmethod
+    def get_layers(model_dir, model_name):
+        sh_commands.bazel_build_common("//mace/python/tools:layers_validate")
+
+        model_file = "%s/%s.pb" % (model_dir, model_name)
+        output_dir = "%s/output_models/" % model_dir
+        if os.path.exists(output_dir):
+            sh.rm('-rf', output_dir)
+        os.makedirs(output_dir)
+        sh.python("bazel-bin/mace/python/tools/layers_validate",
+                  "-u",
+                  "--model_file=%s" % model_file,
+                  "--output_dir=%s" % output_dir,
+                  _fg=True)
+
+        output_configs_path = output_dir + "outputs.yml"
+        with open(output_configs_path) as f:
+            output_configs = yaml.load(f)
+        output_configs = output_configs[YAMLKeyword.subgraphs]
+
+        return output_configs
+
     def run_specify_abi(self, flags, configs, target_abi):
         if target_abi not in self.target_abis:
             six.print_('There is no device with soc: %s abi: %s' %
@@ -527,6 +551,27 @@ class DeviceWrapper:
                 else:
                     output_nodes = subgraphs[0][YAMLKeyword.check_tensors]
                     output_shapes = subgraphs[0][YAMLKeyword.check_shapes]
+                output_configs = []
+                log_file = ""
+                if flags.validate_all_layers:
+                    mace_check(
+                        configs[YAMLKeyword.model_graph_format]
+                        == ModelFormat.file
+                        and configs[YAMLKeyword.model_data_format]
+                        == ModelFormat.file, "Device",
+                        "'--validate_all_layers' only supports model format 'file'.")  # noqa
+                    output_configs = \
+                        self.get_layers(mace_model_dir, model_name)
+                    log_dir = mace_model_dir + "/" + runtime
+                    if os.path.exists(log_dir):
+                        sh.rm('-rf', log_dir)
+                    os.makedirs(log_dir)
+                    log_file = log_dir + "/log.csv"
+                model_path = "%s/%s.pb" % (mace_model_dir, model_name)
+                output_config = {YAMLKeyword.model_file_path: model_path,
+                                 YAMLKeyword.output_tensors: output_nodes,
+                                 YAMLKeyword.output_shapes: output_shapes}
+                output_configs.append(output_config)
+                for output_config in output_configs:
                     run_output = self.tuning_run(
                         abi=target_abi,
                         target_dir=build_tmp_binary_dir,
@@ -535,9 +580,10 @@ class DeviceWrapper:
                         embed_model_data=embed_model_data,
                         model_output_dir=model_output_dir,
                         input_nodes=subgraphs[0][YAMLKeyword.input_tensors],
-                        output_nodes=output_nodes,
+                        output_nodes=output_config[
+                            YAMLKeyword.output_tensors],
                         input_shapes=subgraphs[0][YAMLKeyword.input_shapes],
-                        output_shapes=output_shapes,
+                        output_shapes=output_config[
+                            YAMLKeyword.output_shapes],
                         mace_model_dir=mace_model_dir,
                         model_tag=model_name,
                         device_type=device_type,
@@ -547,7 +593,8 @@ class DeviceWrapper:
                             YAMLKeyword.limit_opencl_kernel_time],
                         tuning=False,
                         out_of_range_check=flags.gpu_out_of_range_check,
                         model_graph_format=configs[
                             YAMLKeyword.model_graph_format],
                         omp_num_threads=flags.omp_num_threads,
                         cpu_affinity_policy=flags.cpu_affinity_policy,
                         gpu_perf_hint=flags.gpu_perf_hint,
@@ -561,8 +608,10 @@ class DeviceWrapper:
                         quantize_stat=flags.quantize_stat,
                         input_dir=flags.input_dir,
                         output_dir=flags.output_dir,
+                        layers_validate_file=output_config[
+                            YAMLKeyword.model_file_path]
                     )
-                if flags.validate:
+                    if flags.validate or flags.validate_all_layers:
                         model_file_path, weight_file_path = get_model_files(
                             model_config[YAMLKeyword.model_file_path],
                             model_config[YAMLKeyword.model_sha256_checksum],
@@ -570,7 +619,6 @@ class DeviceWrapper:
                             model_config[YAMLKeyword.weight_file_path],
                             model_config[YAMLKeyword.weight_sha256_checksum])
                         validate_type = device_type
                         if model_config[YAMLKeyword.quantize] == 1:
                             validate_type = device_type + '_QUANTIZE'
@@ -581,17 +629,23 @@ class DeviceWrapper:
                             weight_file_path=weight_file_path,
                             platform=model_config[YAMLKeyword.platform],
                             device_type=device_type,
-                            input_nodes=subgraphs[0][YAMLKeyword.input_tensors],
-                            output_nodes=output_nodes,
-                            input_shapes=subgraphs[0][YAMLKeyword.input_shapes],
-                            output_shapes=output_shapes,
+                            input_nodes=subgraphs[0][
+                                YAMLKeyword.input_tensors],
+                            output_nodes=output_config[
+                                YAMLKeyword.output_tensors],
+                            input_shapes=subgraphs[0][
+                                YAMLKeyword.input_shapes],
+                            output_shapes=output_config[
+                                YAMLKeyword.output_shapes],
                             model_output_dir=model_output_dir,
                             input_data_types=subgraphs[0][
                                 YAMLKeyword.input_data_types],
                             caffe_env=flags.caffe_env,
-                            validation_threshold=subgraphs[0][
-                                YAMLKeyword.validation_threshold][validate_type],
-                            backend=subgraphs[0][YAMLKeyword.backend]
+                            validation_threshold=subgraphs[0][
+                                YAMLKeyword.validation_threshold][
+                                validate_type],
+                            backend=subgraphs[0][YAMLKeyword.backend],
+                            log_file=log_file,
                         )
             if flags.report and flags.round > 0:
                 tuned = is_tuned and device_type == DeviceType.GPU
...
tools/sh_commands.py
浏览文件 @
8173b231
...
@@ -622,7 +622,9 @@ def validate_model(abi,
...
@@ -622,7 +622,9 @@ def validate_model(abi,
                    input_file_name="model_input",
                    output_file_name="model_out",
                    validation_threshold=0.9,
-                   backend="tensorflow"):
+                   backend="tensorflow",
+                   log_file="",
+                   ):
     six.print_("* Validate with %s" % platform)
     if abi != "host":
         for output_name in output_nodes:
@@ -639,14 +641,16 @@ def validate_model(abi,
                 "%s/%s" % (model_output_dir, output_file_name), device_type,
                 ":".join(input_shapes), ":".join(output_shapes),
                 ",".join(input_nodes), ",".join(output_nodes),
-                validation_threshold, ",".join(input_data_types), backend)
+                validation_threshold, ",".join(input_data_types), backend,
+                log_file)
     elif platform == "onnx":
         validate(platform, model_file_path, "",
                  "%s/%s" % (model_output_dir, input_file_name),
                  "%s/%s" % (model_output_dir, output_file_name), device_type,
                  ":".join(input_shapes), ":".join(output_shapes),
                  ",".join(input_nodes), ",".join(output_nodes),
-                 validation_threshold, ",".join(input_data_types), backend)
+                 validation_threshold, ",".join(input_data_types), backend,
+                 log_file)
     elif platform == "caffe":
         image_name = "mace-caffe:latest"
         container_name = "mace_caffe_validator"
@@ -662,7 +666,8 @@ def validate_model(abi,
                 device_type,
                 ":".join(input_shapes), ":".join(output_shapes),
                 ",".join(input_nodes), ",".join(output_nodes),
-                validation_threshold, ",".join(input_data_types), backend)
+                validation_threshold, ",".join(input_data_types), backend,
+                log_file)
         elif caffe_env == common.CaffeEnvType.DOCKER:
             docker_image_id = sh.docker("images", "-q", image_name)
             if not docker_image_id:
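The `tools/device.py` side of this change is purely plumbing: `validate_model` gains an optional `log_file=""` keyword and forwards it to each backend-specific `validate(...)` call. Reduced to a minimal sketch (function names and return values here are hypothetical, chosen only to illustrate the pattern), the idea is that a default-empty argument threads through the call chain without breaking existing callers:

```python
def backend_validate(platform, validation_threshold, log_file):
    # Innermost call: an empty log_file keeps the old print-only behavior.
    return 'logged' if log_file else 'printed'


def validate_model(platform, validation_threshold=0.9, log_file=''):
    # Entry point: the new keyword defaults to "", so callers that never
    # pass log_file are unaffected by the change.
    return backend_validate(platform, validation_threshold, log_file)


result_default = validate_model('tensorflow')
result_logged = validate_model('onnx', log_file='validation.log')
```

Because the default is the empty string, every pre-existing call site compiles and behaves as before; only the new layers-validate path opts in by passing a log path.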
tools/validate.py View file @ 8173b231
@@ -79,7 +79,7 @@ def calculate_pixel_accuracy(out_value, mace_out_value):


 def compare_output(platform, device_type, output_name, mace_out_value,
-                   out_value, validation_threshold):
+                   out_value, validation_threshold, log_file):
     if mace_out_value.size != 0:
         pixel_accuracy = calculate_pixel_accuracy(out_value, mace_out_value)
         out_value = out_value.reshape(-1)
@@ -91,7 +91,18 @@ def compare_output(platform, device_type, output_name, mace_out_value,
             output_name + ' MACE VS ' + platform.upper()
             + ' similarity: ' + str(similarity) + ' , sqnr: ' + str(sqnr)
             + ' , pixel_accuracy: ' + str(pixel_accuracy))
-        if similarity > validation_threshold:
+        if log_file:
+            if not os.path.exists(log_file):
+                with open(log_file, 'w') as f:
+                    f.write('output_name,similarity,sqnr,pixel_accuracy\n')
+            summary = '{output_name},{similarity},{sqnr},{pixel_accuracy}\n' \
+                .format(output_name=output_name,
+                        similarity=similarity,
+                        sqnr=sqnr,
+                        pixel_accuracy=pixel_accuracy)
+            with open(log_file, "a") as f:
+                f.write(summary)
+        elif similarity > validation_threshold:
             common.MaceLogger.summary(
                 common.StringFormatter.block("Similarity Test Passed"))
         else:
@@ -112,7 +123,8 @@ def normalize_tf_tensor_name(name):


 def validate_tf_model(platform, device_type, model_file, input_file,
                       mace_out_file, input_names, input_shapes,
-                      output_names, validation_threshold, input_data_types):
+                      output_names, validation_threshold, input_data_types,
+                      log_file):
     import tensorflow as tf
     if not os.path.isfile(model_file):
         common.MaceLogger.error(
@@ -151,12 +163,13 @@ def validate_tf_model(platform, device_type, model_file, input_file,
             mace_out_value = load_data(output_file_name)
             compare_output(platform, device_type, output_names[i],
                            mace_out_value, output_values[i],
-                           validation_threshold)
+                           validation_threshold, log_file)


 def validate_caffe_model(platform, device_type, model_file, input_file,
                          mace_out_file, weight_file, input_names, input_shapes,
-                         output_names, output_shapes, validation_threshold):
+                         output_names, output_shapes, validation_threshold,
+                         log_file):
     os.environ['GLOG_minloglevel'] = '1'  # suprress Caffe verbose prints
     import caffe
     if not os.path.isfile(model_file):
@@ -201,13 +214,13 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
                                              mace_out_file, output_names[i])
         mace_out_value = load_data(output_file_name)
         compare_output(platform, device_type, output_names[i], mace_out_value,
-                       value, validation_threshold)
+                       value, validation_threshold, log_file)


 def validate_onnx_model(platform, device_type, model_file, input_file,
                         mace_out_file, input_names, input_shapes,
                         output_names, output_shapes, validation_threshold,
-                        input_data_types, backend):
+                        input_data_types, backend, log_file):
     import onnx
     if backend == "tensorflow":
         from onnx_tf.backend import prepare
@@ -257,12 +270,12 @@ def validate_onnx_model(platform, device_type, model_file, input_file,
         mace_out_value = load_data(output_file_name)
         compare_output(platform, device_type, output_names[i],
                        mace_out_value, value,
-                       validation_threshold)
+                       validation_threshold, log_file)


 def validate(platform, model_file, weight_file, input_file, mace_out_file,
              device_type, input_shape, output_shape, input_node, output_node,
-             validation_threshold, input_data_type, backend):
+             validation_threshold, input_data_type, backend, log_file):
     input_names = [name for name in input_node.split(',')]
     input_shape_strs = [shape for shape in input_shape.split(':')]
     input_shapes = [[int(x) for x in shape.split(',')]
@@ -278,7 +291,8 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
     if platform == 'tensorflow':
         validate_tf_model(platform, device_type, model_file, input_file,
                           mace_out_file, input_names, input_shapes,
-                          output_names, validation_threshold, input_data_types)
+                          output_names, validation_threshold, input_data_types,
+                          log_file)
     elif platform == 'caffe':
         output_shape_strs = [shape for shape in output_shape.split(':')]
         output_shapes = [[int(x) for x in shape.split(',')]
@@ -286,7 +300,7 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
         validate_caffe_model(platform, device_type, model_file, input_file,
                              mace_out_file, weight_file, input_names,
                              input_shapes, output_names, output_shapes,
-                             validation_threshold)
+                             validation_threshold, log_file)
     elif platform == 'onnx':
         output_shape_strs = [shape for shape in output_shape.split(':')]
         output_shapes = [[int(x) for x in shape.split(',')]
@@ -295,7 +309,7 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
                             mace_out_file, input_names, input_shapes,
                             output_names, output_shapes,
                             validation_threshold,
-                            input_data_types, backend)
+                            input_data_types, backend, log_file)


 def parse_args():
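The core of the `tools/validate.py` change is the logging branch in `compare_output`: when `log_file` is set, the per-output metrics (similarity, sqnr, pixel_accuracy) are appended to a CSV file with a header written on first use, instead of only passing or failing the similarity check. A self-contained sketch of that write pattern, plus reading the log back (the file path, output names, and metric values below are made up for illustration):

```python
import csv
import os
import tempfile


def append_metrics(log_file, output_name, similarity, sqnr, pixel_accuracy):
    # Write the CSV header once, on first use, as compare_output does.
    if not os.path.exists(log_file):
        with open(log_file, 'w') as f:
            f.write('output_name,similarity,sqnr,pixel_accuracy\n')
    # Append one row per validated output tensor.
    with open(log_file, 'a') as f:
        f.write('{},{},{},{}\n'.format(
            output_name, similarity, sqnr, pixel_accuracy))


log_path = os.path.join(tempfile.mkdtemp(), 'validation.log')
append_metrics(log_path, 'layer_3:0', 0.9997, 41.2, 1.0)
append_metrics(log_path, 'layer_4:0', 0.9983, 36.8, 1.0)

# The resulting log is plain CSV, so layer-by-layer results can be
# inspected with the standard csv module (or any spreadsheet tool).
with open(log_path) as f:
    rows = list(csv.DictReader(f))
```

Appending rather than rewriting means a single log accumulates one row per output across the whole layers-validate run, which is what makes it possible to scan for the first layer whose similarity drops.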