Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
54a44cec
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
54a44cec
编写于
1月 25, 2019
作者:
Y
yejianwu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix mace_run & benchmark in code model format
上级
7be6c667
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
116 addition
and
49 deletion
+116
-49
mace/benchmark/benchmark_model.cc
mace/benchmark/benchmark_model.cc
+10
-5
mace/examples/cli/example.cc
mace/examples/cli/example.cc
+35
-33
mace/python/tools/mace_engine_factory.h.jinja2
mace/python/tools/mace_engine_factory.h.jinja2
+44
-1
mace/tools/validation/mace_run.cc
mace/tools/validation/mace_run.cc
+6
-3
tools/device.py
tools/device.py
+21
-7
未找到文件。
mace/benchmark/benchmark_model.cc
浏览文件 @
54a44cec
...
...
@@ -278,19 +278,24 @@ int Main(int argc, char **argv) {
MaceStatus
create_engine_status
;
// Create Engine
std
::
vector
<
unsigned
char
>
model_graph_data
;
if
(
!
mace
::
ReadBinaryFile
(
&
model_graph_data
,
FLAGS_model_file
))
{
LOG
(
FATAL
)
<<
"Failed to read file: "
<<
FLAGS_model_file
;
if
(
FLAGS_model_file
!=
""
)
{
if
(
!
mace
::
ReadBinaryFile
(
&
model_graph_data
,
FLAGS_model_file
))
{
LOG
(
FATAL
)
<<
"Failed to read file: "
<<
FLAGS_model_file
;
}
}
std
::
vector
<
unsigned
char
>
model_weights_data
;
if
(
!
mace
::
ReadBinaryFile
(
&
model_weights_data
,
FLAGS_model_data_file
))
{
LOG
(
FATAL
)
<<
"Failed to read file: "
<<
FLAGS_model_data_file
;
if
(
FLAGS_model_data_file
!=
""
)
{
if
(
!
mace
::
ReadBinaryFile
(
&
model_weights_data
,
FLAGS_model_data_file
))
{
LOG
(
FATAL
)
<<
"Failed to read file: "
<<
FLAGS_model_data_file
;
}
}
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status
=
CreateMaceEngineFromCode
(
FLAGS_model_name
,
model_data_file_ptr
,
model_weights_data
.
data
(),
model_weights_data
.
size
(),
input_names
,
output_names
,
config
,
...
...
mace/examples/cli/example.cc
浏览文件 @
54a44cec
...
...
@@ -59,6 +59,29 @@ std::vector<std::string> Split(const std::string &str, char delims) {
}
// namespace str_util
namespace
{
bool
ReadBinaryFile
(
std
::
vector
<
unsigned
char
>
*
data
,
const
std
::
string
&
filename
)
{
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
if
(
!
ifs
.
is_open
())
{
return
false
;
}
ifs
.
seekg
(
0
,
ifs
.
end
);
size_t
length
=
ifs
.
tellg
();
ifs
.
seekg
(
0
,
ifs
.
beg
);
data
->
reserve
(
length
);
data
->
insert
(
data
->
begin
(),
std
::
istreambuf_iterator
<
char
>
(
ifs
),
std
::
istreambuf_iterator
<
char
>
());
if
(
ifs
.
fail
())
{
return
false
;
}
ifs
.
close
();
return
true
;
}
}
// namespace
void
ParseShape
(
const
std
::
string
&
str
,
std
::
vector
<
int64_t
>
*
shape
)
{
std
::
string
tmp
=
str
;
while
(
!
tmp
.
empty
())
{
...
...
@@ -142,30 +165,6 @@ DEFINE_int32(gpu_priority_hint, 1, "0:DEFAULT/1:LOW/2:NORMAL/3:HIGH");
DEFINE_int32
(
omp_num_threads
,
-
1
,
"num of openmp threads"
);
DEFINE_int32
(
cpu_affinity_policy
,
1
,
"0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY"
);
#ifndef MODEL_GRAPH_FORMAT_CODE
namespace
{
bool
ReadBinaryFile
(
std
::
vector
<
unsigned
char
>
*
data
,
const
std
::
string
&
filename
)
{
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
if
(
!
ifs
.
is_open
())
{
return
false
;
}
ifs
.
seekg
(
0
,
ifs
.
end
);
size_t
length
=
ifs
.
tellg
();
ifs
.
seekg
(
0
,
ifs
.
beg
);
data
->
reserve
(
length
);
data
->
insert
(
data
->
begin
(),
std
::
istreambuf_iterator
<
char
>
(
ifs
),
std
::
istreambuf_iterator
<
char
>
());
if
(
ifs
.
fail
())
{
return
false
;
}
ifs
.
close
();
return
true
;
}
}
// namespace
#endif
bool
RunModel
(
const
std
::
vector
<
std
::
string
>
&
input_names
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
input_shapes
,
...
...
@@ -212,6 +211,16 @@ bool RunModel(const std::vector<std::string> &input_names,
// Create Engine
std
::
shared_ptr
<
mace
::
MaceEngine
>
engine
;
MaceStatus
create_engine_status
;
std
::
vector
<
unsigned
char
>
model_graph_data
;
if
(
!
ReadBinaryFile
(
&
model_graph_data
,
FLAGS_model_file
))
{
std
::
cerr
<<
"Failed to read file: "
<<
FLAGS_model_file
<<
std
::
endl
;
}
std
::
vector
<
unsigned
char
>
model_weights_data
;
if
(
!
ReadBinaryFile
(
&
model_weights_data
,
FLAGS_model_data_file
))
{
std
::
cerr
<<
"Failed to read file: "
<<
FLAGS_model_data_file
<<
std
::
endl
;
}
// Only choose one of the two type based on the `model_graph_format`
// in model deployment file(.yml).
#ifdef MODEL_GRAPH_FORMAT_CODE
...
...
@@ -219,20 +228,13 @@ bool RunModel(const std::vector<std::string> &input_names,
// to model_data_file parameter.
create_engine_status
=
CreateMaceEngineFromCode
(
FLAGS_model_name
,
FLAGS_model_data_file
,
model_weights_data
.
data
(),
model_weights_data
.
size
(),
input_names
,
output_names
,
config
,
&
engine
);
#else
std
::
vector
<
unsigned
char
>
model_graph_data
;
if
(
!
ReadBinaryFile
(
&
model_graph_data
,
FLAGS_model_file
))
{
std
::
cerr
<<
"Failed to read file: "
<<
FLAGS_model_file
<<
std
::
endl
;
}
std
::
vector
<
unsigned
char
>
model_weights_data
;
if
(
!
ReadBinaryFile
(
&
model_weights_data
,
FLAGS_model_data_file
))
{
std
::
cerr
<<
"Failed to read file: "
<<
FLAGS_model_data_file
<<
std
::
endl
;
}
create_engine_status
=
CreateMaceEngineFromProto
(
model_graph_data
.
data
(),
model_graph_data
.
size
(),
...
...
mace/python/tools/mace_engine_factory.h.jinja2
浏览文件 @
54a44cec
...
...
@@ -62,7 +62,7 @@ std::map<std::string, int> model_name_map {
/// \param engine[out]: output MaceEngine object
/// \return MaceStatus::MACE_SUCCESS for success, MACE_INVALID_ARGS for wrong arguments,
/// MACE_OUT_OF_RESOURCES for resources is out of range.
MaceStatus CreateMaceEngineFromCode(
__attribute__((deprecated))
MaceStatus CreateMaceEngineFromCode(
const std::string &model_name,
const std::string &model_data_file,
const std::vector<std::string> &input_nodes,
...
...
@@ -101,5 +101,48 @@ MaceStatus CreateMaceEngineFromCode(
return status;
}
MaceStatus CreateMaceEngineFromCode(
const std::string &model_name,
const unsigned char *model_weights_data,
const size_t model_weights_data_size,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const MaceEngineConfig &config,
std::shared_ptr<MaceEngine> *engine) {
// load model
if (engine == nullptr) {
return MaceStatus::MACE_INVALID_ARGS;
}
std::shared_ptr<NetDef> net_def;
{% if embed_model_data %}
const unsigned char * model_data;
(void)model_weights_data;
{% endif %}
// TODO(yejianwu) Add buffer range checking
(void)model_weights_data_size;
MaceStatus status = MaceStatus::MACE_SUCCESS;
switch (model_name_map[model_name]) {
{% for i in range(model_tags |length) %}
case {{ i }}:
net_def = mace::{{model_tags[i]}}::CreateNet();
engine->reset(new mace::MaceEngine(config));
{% if embed_model_data %}
model_data = mace::{{model_tags[i]}}::LoadModelData();
status = (*engine)->Init(net_def.get(), input_nodes, output_nodes,
model_data);
{% else %}
status = (*engine)->Init(net_def.get(), input_nodes, output_nodes,
model_weights_data);
{% endif %}
break;
{% endfor %}
default:
status = MaceStatus::MACE_INVALID_ARGS;
}
return status;
}
} // namespace mace
#endif // MACE_CODEGEN_ENGINE_MACE_ENGINE_FACTORY_H_
mace/tools/validation/mace_run.cc
浏览文件 @
54a44cec
...
...
@@ -264,7 +264,8 @@ bool RunModel(const std::string &model_name,
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status
=
CreateMaceEngineFromCode
(
model_name
,
FLAGS_model_data_file
,
model_weights_data
.
data
(),
model_weights_data
.
size
(),
input_names
,
output_names
,
config
,
...
...
@@ -340,7 +341,8 @@ bool RunModel(const std::string &model_name,
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status
=
CreateMaceEngineFromCode
(
model_name
,
FLAGS_model_data_file
,
model_weights_data
.
data
(),
model_weights_data
.
size
(),
input_names
,
output_names
,
config
,
...
...
@@ -382,7 +384,8 @@ bool RunModel(const std::string &model_name,
#ifdef MODEL_GRAPH_FORMAT_CODE
create_engine_status
=
CreateMaceEngineFromCode
(
model_name
,
FLAGS_model_data_file
,
model_weights_data
.
data
(),
model_weights_data
.
size
(),
input_names
,
output_names
,
config
,
...
...
tools/device.py
浏览文件 @
54a44cec
...
...
@@ -192,6 +192,14 @@ class DeviceWrapper:
if
model_graph_format
==
ModelFormat
.
file
:
mace_model_path
=
layers_validate_file
if
layers_validate_file
\
else
"%s/%s.pb"
%
(
mace_model_dir
,
model_tag
)
model_data_file
=
""
if
not
embed_model_data
:
if
self
.
system
==
SystemType
.
host
:
model_data_file
=
"%s/%s.data"
%
(
mace_model_dir
,
model_tag
)
else
:
model_data_file
=
"%s/%s.data"
%
(
self
.
data_dir
,
model_tag
)
if
self
.
system
==
SystemType
.
host
:
libmace_dynamic_lib_path
=
\
os
.
path
.
dirname
(
libmace_dynamic_library_path
)
...
...
@@ -214,8 +222,7 @@ class DeviceWrapper:
output_file_name
),
"--input_dir=%s"
%
input_dir
,
"--output_dir=%s"
%
output_dir
,
"--model_data_file=%s/%s.data"
%
(
mace_model_dir
,
model_tag
),
"--model_data_file=%s"
%
model_data_file
,
"--device=%s"
%
device_type
,
"--round=%s"
%
running_round
,
"--restart_round=%s"
%
restart_round
,
...
...
@@ -229,7 +236,7 @@ class DeviceWrapper:
stdout
=
subprocess
.
PIPE
)
out
,
err
=
p
.
communicate
()
self
.
stdout
=
err
+
out
six
.
print_
(
self
.
stdout
)
six
.
print_
(
self
.
stdout
.
decode
(
'UTF-8'
)
)
six
.
print_
(
"Running finished!
\n
"
)
elif
self
.
system
in
[
SystemType
.
android
,
SystemType
.
arm_linux
]:
self
.
rm
(
self
.
data_dir
)
...
...
@@ -304,7 +311,7 @@ class DeviceWrapper:
"--output_file=%s/%s"
%
(
self
.
data_dir
,
output_file_name
),
"--input_dir=%s"
%
input_dir
,
"--output_dir=%s"
%
output_dir
,
"--model_data_file=%s
/%s.data"
%
(
self
.
data_dir
,
model_tag
)
,
"--model_data_file=%s
"
%
model_data_file
,
"--device=%s"
%
device_type
,
"--round=%s"
%
running_round
,
"--restart_round=%s"
%
restart_round
,
...
...
@@ -753,6 +760,14 @@ class DeviceWrapper:
mace_model_path
=
''
if
model_graph_format
==
ModelFormat
.
file
:
mace_model_path
=
'%s/%s.pb'
%
(
mace_model_dir
,
model_tag
)
model_data_file
=
""
if
not
embed_model_data
:
if
self
.
system
==
SystemType
.
host
:
model_data_file
=
"%s/%s.data"
%
(
mace_model_dir
,
model_tag
)
else
:
model_data_file
=
"%s/%s.data"
%
(
self
.
data_dir
,
model_tag
)
if
abi
==
ABIType
.
host
:
libmace_dynamic_lib_dir_path
=
\
os
.
path
.
dirname
(
libmace_dynamic_library_path
)
...
...
@@ -768,8 +783,7 @@ class DeviceWrapper:
'--input_shape=%s'
%
':'
.
join
(
input_shapes
),
'--output_shape=%s'
%
':'
.
join
(
output_shapes
),
'--input_file=%s/%s'
%
(
model_output_dir
,
input_file_name
),
'--model_data_file=%s/%s.data'
%
(
mace_model_dir
,
model_tag
),
"--model_data_file=%s"
%
model_data_file
,
'--device=%s'
%
device_type
,
'--omp_num_threads=%s'
%
omp_num_threads
,
'--cpu_affinity_policy=%s'
%
cpu_affinity_policy
,
...
...
@@ -822,7 +836,7 @@ class DeviceWrapper:
'--input_shape=%s'
%
':'
.
join
(
input_shapes
),
'--output_shape=%s'
%
':'
.
join
(
output_shapes
),
'--input_file=%s/%s'
%
(
self
.
data_dir
,
input_file_name
),
'--model_data_file=%s/%s.data'
%
(
self
.
data_dir
,
model_tag
)
,
"--model_data_file=%s"
%
model_data_file
,
'--device=%s'
%
device_type
,
'--omp_num_threads=%s'
%
omp_num_threads
,
'--cpu_affinity_policy=%s'
%
cpu_affinity_policy
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录