Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
5e02dc0e
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5e02dc0e
编写于
5月 04, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'check_model_configs' into 'master'
Check model configs See merge request !458
上级
933f75b7
63d76727
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
76 addition
and
14 deletion
+76
-14
mace/kernels/opencl/cl/common.h
mace/kernels/opencl/cl/common.h
+2
-2
mace/python/tools/converter.py
mace/python/tools/converter.py
+4
-1
tools/mace_tools.py
tools/mace_tools.py
+70
-11
未找到文件。
mace/kernels/opencl/cl/common.h
浏览文件 @
5e02dc0e
...
...
@@ -32,8 +32,8 @@
#define CHECK_OUT_OF_RANGE_FOR_IMAGE2D(image, coord)
#endif
#define READ_IMAGET(image,
coord, value
) \
CMD_TYPE(read_image, CMD_DATA_TYPE)(image,
coord, value
)
#define READ_IMAGET(image,
sampler, coord
) \
CMD_TYPE(read_image, CMD_DATA_TYPE)(image,
sampler, coord
)
#define WRITE_IMAGET(image, coord, value) \
CHECK_OUT_OF_RANGE_FOR_IMAGE2D(image, coord) \
CMD_TYPE(write_image, CMD_DATA_TYPE)(image, coord, value);
...
...
mace/python/tools/converter.py
浏览文件 @
5e02dc0e
...
...
@@ -175,7 +175,10 @@ def parse_args():
parser
.
add_argument
(
"--platform"
,
type
=
str
,
default
=
"tensorflow"
,
help
=
"tensorflow/caffe"
)
parser
.
add_argument
(
"--embed_model_data"
,
type
=
str2bool
,
default
=
True
,
help
=
"input shape."
)
"--embed_model_data"
,
type
=
str2bool
,
default
=
True
,
help
=
"embed model data."
)
return
parser
.
parse_known_args
()
...
...
tools/mace_tools.py
浏览文件 @
5e02dc0e
...
...
@@ -336,8 +336,74 @@ def str_to_caffe_env_type(v):
def
parse_model_configs
():
print
(
"============== Load and Parse configs =============="
)
with
open
(
FLAGS
.
config
)
as
f
:
configs
=
yaml
.
load
(
f
)
target_abis
=
configs
.
get
(
"target_abis"
,
[])
if
not
isinstance
(
target_abis
,
list
)
or
not
target_abis
:
print
(
"CONFIG ERROR:"
)
print
(
"target_abis list is needed!"
)
print
(
"For example: 'target_abis: [armeabi-v7a, arm64-v8a]'"
)
exit
(
1
)
embed_model_data
=
configs
.
get
(
"embed_model_data"
,
""
)
if
embed_model_data
==
""
or
not
isinstance
(
embed_model_data
,
int
)
or
\
embed_model_data
<
0
or
embed_model_data
>
1
:
print
(
"CONFIG ERROR:"
)
print
(
"embed_model_data must be integer in range [0, 1]"
)
exit
(
1
)
model_names
=
configs
.
get
(
"models"
,
""
)
if
not
model_names
:
print
(
"CONFIG ERROR:"
)
print
(
"models attribute not found in config file"
)
exit
(
1
)
for
model_name
in
model_names
:
model_config
=
configs
[
"models"
][
model_name
]
platform
=
model_config
.
get
(
"platform"
,
""
)
if
platform
==
""
or
platform
not
in
[
"tensorflow"
,
"caffe"
]:
print
(
"CONFIG ERROR:"
)
print
(
"'platform' must be 'tensorflow' or 'caffe'"
)
exit
(
1
)
for
key
in
[
"model_file_path"
,
"model_sha256_checksum"
,
"runtime"
]:
value
=
model_config
.
get
(
key
,
""
)
if
value
==
""
:
print
(
"CONFIG ERROR:"
)
print
(
"'%s' is necessary"
%
key
)
exit
(
1
)
for
key
in
[
"input_nodes"
,
"input_shapes"
,
"output_nodes"
,
"output_shapes"
]:
value
=
model_config
.
get
(
key
,
""
)
if
value
==
""
:
print
(
"CONFIG ERROR:"
)
print
(
"'%s' is necessary"
%
key
)
exit
(
1
)
if
not
isinstance
(
value
,
list
):
model_config
[
key
]
=
[
value
]
for
key
in
[
"limit_opencl_kernel_time"
,
"dsp_mode"
,
"obfuscate"
,
"fast_conv"
]:
value
=
model_config
.
get
(
key
,
""
)
if
value
==
""
:
model_config
[
key
]
=
0
print
(
"'%s' for %s is set to default value: 0"
%
(
key
,
model_name
))
validation_inputs_data
=
model_config
.
get
(
"validation_inputs_data"
,
[])
model_config
[
"validation_inputs_data"
]
=
validation_inputs_data
if
not
isinstance
(
validation_inputs_data
,
list
):
model_config
[
"validation_inputs_data"
]
=
[
validation_inputs_data
]
weight_file_path
=
model_config
.
get
(
"weight_file_path"
,
""
)
model_config
[
"weight_file_path"
]
=
weight_file_path
print
(
"Parse model configs successfully!
\n
"
)
return
configs
...
...
@@ -434,16 +500,10 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
for
model_name
in
configs
[
"models"
]:
print
'==================='
,
model_name
,
'==================='
model_config
=
configs
[
"models"
][
model_name
]
input_file_list
=
model_config
.
get
(
"validation_inputs_data"
,
[])
input_file_list
=
model_config
[
"validation_inputs_data"
]
data_type
,
device_type
=
get_data_and_device_type
(
model_config
[
"runtime"
])
for
key
in
[
"input_nodes"
,
"output_nodes"
,
"input_shapes"
,
"output_shapes"
]:
if
not
isinstance
(
model_config
[
key
],
list
):
model_config
[
key
]
=
[
model_config
[
key
]]
# Create model build directory
model_path_digest
=
md5sum
(
model_config
[
"model_file_path"
])
...
...
@@ -472,7 +532,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_file_path
,
weight_file_path
=
get_model_files
(
model_config
[
"model_file_path"
],
model_output_dir
,
model_config
.
get
(
"weight_file_path"
,
""
)
)
model_config
[
"weight_file_path"
]
)
if
FLAGS
.
mode
==
"build"
or
FLAGS
.
mode
==
"run"
or
\
FLAGS
.
mode
==
"validate"
or
\
...
...
@@ -604,8 +664,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
if
os
.
path
.
exists
(
throughput_test_output_dir
):
sh
.
rm
(
"-rf"
,
throughput_test_output_dir
)
os
.
makedirs
(
throughput_test_output_dir
)
input_file_list
=
model_config
.
get
(
"validation_inputs_data"
,
[])
input_file_list
=
model_config
[
"validation_inputs_data"
]
sh_commands
.
gen_random_input
(
throughput_test_output_dir
,
first_model
[
"input_nodes"
],
first_model
[
"input_shapes"
],
...
...
@@ -654,7 +713,7 @@ def main(unused_args):
target_socs
=
get_target_socs
(
configs
)
embed_model_data
=
configs
.
get
(
"embed_model_data"
,
1
)
embed_model_data
=
configs
[
"embed_model_data"
]
vlog_level
=
FLAGS
.
vlog_level
phone_data_dir
=
"/data/local/tmp/mace_run/"
for
target_abi
in
configs
[
"target_abis"
]:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录