Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Mr.Vain
Mace
提交
cefa11d4
Mace
项目概览
Mr.Vain
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cefa11d4
编写于
5月 18, 2018
作者:
Y
yejianwu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add CreateMaceEngineFromPB api, fix merge_libs
上级
c59be5ae
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
231 addition
and
95 deletion
+231
-95
mace/benchmark/benchmark_model.cc
mace/benchmark/benchmark_model.cc
+7
-8
mace/core/mace.cc
mace/core/mace.cc
+61
-0
mace/core/operator.h
mace/core/operator.h
+2
-2
mace/core/runtime/hexagon/hexagon_control_wrapper.cc
mace/core/runtime/hexagon/hexagon_control_wrapper.cc
+2
-2
mace/examples/example.cc
mace/examples/example.cc
+23
-9
mace/public/mace.h
mace/public/mace.h
+7
-2
mace/python/tools/mace_engine_factory.h.jinja2
mace/python/tools/mace_engine_factory.h.jinja2
+22
-4
mace/python/tools/mace_engine_factory_codegen.py
mace/python/tools/mace_engine_factory_codegen.py
+8
-2
mace/python/tools/source_converter_lib.py
mace/python/tools/source_converter_lib.py
+39
-38
mace/tools/validation/mace_run.cc
mace/tools/validation/mace_run.cc
+7
-8
tools/mace_tools.py
tools/mace_tools.py
+15
-7
tools/sh_commands.py
tools/sh_commands.py
+38
-13
未找到文件。
mace/benchmark/benchmark_model.cc
浏览文件 @
cefa11d4
...
...
@@ -280,16 +280,15 @@ int Main(int argc, char **argv) {
LOG
(
FATAL
)
<<
"Failed to read file: "
<<
FLAGS_model_file
;
}
create_engine_status
=
CreateMaceEngine
(
FLAGS_model_name
.
c_str
(),
model_data_file_ptr
,
input_names
,
output_names
,
device_type
,
&
engine
,
model_pb_data
);
CreateMaceEngineFromPB
(
model_data_file_ptr
,
input_names
,
output_names
,
device_type
,
&
engine
,
model_pb_data
);
}
else
{
create_engine_status
=
CreateMaceEngine
(
FLAGS_model_name
.
c_str
()
,
CreateMaceEngine
(
FLAGS_model_name
,
model_data_file_ptr
,
input_names
,
output_names
,
...
...
mace/core/mace.cc
浏览文件 @
cefa11d4
...
...
@@ -12,12 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <errno.h>
#include <fcntl.h>
#include <memory>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>
#include "mace/core/net.h"
#include "mace/core/types.h"
#include "mace/public/mace.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_runtime.h"
#endif // MACE_ENABLE_OPENCL
...
...
@@ -269,4 +276,58 @@ MaceStatus MaceEngine::Run(const std::map<std::string, MaceTensor> &inputs,
return
impl_
->
Run
(
inputs
,
outputs
,
nullptr
);
}
namespace
{
const
unsigned
char
*
LoadModelData
(
const
char
*
model_data_file
)
{
int
fd
=
open
(
model_data_file
,
O_RDONLY
);
MACE_CHECK
(
fd
>=
0
,
"Failed to open model data file "
,
model_data_file
,
", error code: "
,
errno
);
const
unsigned
char
*
model_data
=
static_cast
<
const
unsigned
char
*>
(
mmap
(
nullptr
,
2453764
,
PROT_READ
,
MAP_PRIVATE
,
fd
,
0
));
MACE_CHECK
(
model_data
!=
MAP_FAILED
,
"Failed to map model data file "
,
model_data_file
,
", error code: "
,
errno
);
int
ret
=
close
(
fd
);
MACE_CHECK
(
ret
==
0
,
"Failed to close model data file "
,
model_data_file
,
", error code: "
,
errno
);
return
model_data
;
}
void
UnloadModelData
(
const
unsigned
char
*
model_data
)
{
int
ret
=
munmap
(
const_cast
<
unsigned
char
*>
(
model_data
),
2453764
);
MACE_CHECK
(
ret
==
0
,
"Failed to unmap model data file, error code: "
,
errno
);
}
}
// namespace
MaceStatus
CreateMaceEngineFromPB
(
const
char
*
model_data_file
,
const
std
::
vector
<
std
::
string
>
&
input_nodes
,
const
std
::
vector
<
std
::
string
>
&
output_nodes
,
const
DeviceType
device_type
,
std
::
shared_ptr
<
MaceEngine
>
*
engine
,
const
std
::
vector
<
unsigned
char
>
model_pb
)
{
LOG
(
INFO
)
<<
"Create MaceEngine from model pb"
;
// load model
if
(
engine
==
nullptr
)
{
return
MaceStatus
::
MACE_INVALID_ARGS
;
}
const
unsigned
char
*
model_data
=
nullptr
;
model_data
=
LoadModelData
(
model_data_file
);
NetDef
net_def
;
net_def
.
ParseFromArray
(
&
model_pb
[
0
],
model_pb
.
size
());
engine
->
reset
(
new
mace
::
MaceEngine
(
&
net_def
,
device_type
,
input_nodes
,
output_nodes
,
model_data
));
if
(
device_type
==
DeviceType
::
GPU
||
device_type
==
DeviceType
::
HEXAGON
)
{
UnloadModelData
(
model_data
);
}
return
MACE_SUCCESS
;
}
}
// namespace mace
mace/core/operator.h
浏览文件 @
cefa11d4
...
...
@@ -108,7 +108,7 @@ class Operator : public OperatorBase {
inputs_
.
push_back
(
tensor
);
}
for
(
size_t
i
=
0
;
i
<
(
size_t
)
operator_def
.
output_size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
operator_def
.
output_size
();
++
i
)
{
const
std
::
string
output_str
=
operator_def
.
output
(
i
);
if
(
ws
->
HasTensor
(
output_str
))
{
outputs_
.
push_back
(
ws
->
GetTensor
(
output_str
));
...
...
@@ -120,7 +120,7 @@ class Operator : public OperatorBase {
operator_def
.
output_size
(),
operator_def
.
output_type_size
());
DataType
output_type
;
if
(
i
<
(
size_t
)
operator_def
.
output_type_size
())
{
if
(
i
<
operator_def
.
output_type_size
())
{
output_type
=
operator_def
.
output_type
(
i
);
}
else
{
output_type
=
DataTypeToEnum
<
T
>::
v
();
...
...
mace/core/runtime/hexagon/hexagon_control_wrapper.cc
浏览文件 @
cefa11d4
...
...
@@ -134,12 +134,12 @@ bool HexagonControlWrapper::SetupGraph(const NetDef &net_def,
for
(
const
OperatorDef
&
op
:
net_def
.
op
())
{
int
op_id
=
op_map
.
GetOpId
(
op
.
type
());
inputs
.
resize
(
op
.
node_input
().
size
());
for
(
size_t
i
=
0
;
i
<
(
size_t
)
op
.
node_input
().
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
op
.
node_input
().
size
();
++
i
)
{
inputs
[
i
].
src_id
=
node_id
(
op
.
node_input
()[
i
].
node_id
());
inputs
[
i
].
output_idx
=
op
.
node_input
()[
i
].
output_port
();
}
outputs
.
resize
(
op
.
out_max_byte_size
().
size
());
for
(
size_t
i
=
0
;
i
<
(
size_t
)
op
.
out_max_byte_size
().
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
op
.
out_max_byte_size
().
size
();
++
i
)
{
outputs
[
i
].
max_size
=
op
.
out_max_byte_size
()[
i
];
}
cached_inputs
.
push_back
(
inputs
);
...
...
mace/examples/example.cc
浏览文件 @
cefa11d4
...
...
@@ -120,6 +120,9 @@ DEFINE_string(output_file,
DEFINE_string
(
model_data_file
,
""
,
"model data file name, used when EMBED_MODEL_DATA set to 0"
);
DEFINE_string
(
model_file
,
""
,
"model file name, used when load mace model in pb"
);
DEFINE_string
(
device
,
"GPU"
,
"CPU/GPU/HEXAGON"
);
DEFINE_int32
(
round
,
1
,
"round"
);
DEFINE_int32
(
restart_round
,
1
,
"restart round"
);
...
...
@@ -163,23 +166,33 @@ bool RunModel(const std::vector<std::string> &input_names,
std
::
shared_ptr
<
mace
::
MaceEngine
>
engine
;
MaceStatus
create_engine_status
;
// Create Engine
if
(
FLAGS_model_data_file
.
empty
())
{
MaceStatus
create_engine_status
;
// Create Engine
int64_t
t0
=
NowMicros
();
const
char
*
model_data_file_ptr
=
FLAGS_model_data_file
.
empty
()
?
nullptr
:
FLAGS_model_data_file
.
c_str
();
if
(
FLAGS_model_file
!=
""
)
{
std
::
vector
<
unsigned
char
>
model_pb_data
;
if
(
!
mace
::
ReadBinaryFile
(
&
model_pb_data
,
FLAGS_model_file
))
{
LOG
(
FATAL
)
<<
"Failed to read file: "
<<
FLAGS_model_file
;
}
create_engine_status
=
CreateMaceEngine
(
FLAGS_model_name
.
c_str
()
,
nullptr
,
in
put_names
,
output_names
,
device_typ
e
,
&
engine
);
CreateMaceEngine
FromPB
(
model_data_file_ptr
,
input_names
,
out
put_names
,
device_type
,
&
engin
e
,
model_pb_data
);
}
else
{
create_engine_status
=
CreateMaceEngine
(
FLAGS_model_name
.
c_str
()
,
FLAGS_model_data_file
.
c_str
()
,
CreateMaceEngine
(
model_name
,
model_data_file_ptr
,
input_names
,
output_names
,
device_type
,
&
engine
);
}
if
(
create_engine_status
!=
MaceStatus
::
MACE_SUCCESS
)
{
LOG
(
FATAL
)
<<
"Create engine error, please check the arguments"
;
}
...
...
@@ -258,6 +271,7 @@ int Main(int argc, char **argv) {
LOG
(
INFO
)
<<
"input_file: "
<<
FLAGS_input_file
;
LOG
(
INFO
)
<<
"output_file: "
<<
FLAGS_output_file
;
LOG
(
INFO
)
<<
"model_data_file: "
<<
FLAGS_model_data_file
;
LOG
(
INFO
)
<<
"model_file: "
<<
FLAGS_model_file
;
LOG
(
INFO
)
<<
"device: "
<<
FLAGS_device
;
LOG
(
INFO
)
<<
"round: "
<<
FLAGS_round
;
LOG
(
INFO
)
<<
"restart_round: "
<<
FLAGS_restart_round
;
...
...
mace/public/mace.h
浏览文件 @
cefa11d4
...
...
@@ -56,8 +56,6 @@ class RunMetadata {
const
char
*
MaceVersion
();
// enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3 };
enum
MaceStatus
{
MACE_SUCCESS
=
0
,
MACE_INVALID_ARGS
=
1
};
// MACE input/output tensor
...
...
@@ -108,6 +106,13 @@ class MaceEngine {
MaceEngine
&
operator
=
(
const
MaceEngine
&
)
=
delete
;
};
MaceStatus
CreateMaceEngineFromPB
(
const
char
*
model_data_file
,
const
std
::
vector
<
std
::
string
>
&
input_nodes
,
const
std
::
vector
<
std
::
string
>
&
output_nodes
,
const
DeviceType
device_type
,
std
::
shared_ptr
<
MaceEngine
>
*
engine
,
const
std
::
vector
<
unsigned
char
>
model_pb
);
}
// namespace mace
#endif // MACE_PUBLIC_MACE_H_
mace/python/tools/mace_engine_factory.h.jinja2
浏览文件 @
cefa11d4
...
...
@@ -19,11 +19,13 @@
#include <string>
#include <vector>
#include "mace/core/macros.h"
#include "mace/public/mace.h"
#include "mace/public/mace_runtime.h"
namespace mace {
{% if model_type == 'source' %}
{% for tag in model_tags %}
namespace {{tag}} {
...
...
@@ -50,13 +52,12 @@ std::map<std::string, int> model_name_map {
} // namespace
MaceStatus CreateMaceEngine(
const
char *
model_name,
const
std::string &
model_name,
const char *model_data_file,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const DeviceType device_type,
std::shared_ptr<MaceEngine> *engine,
const std::vector<unsigned char> model_pb = {}) {
std::shared_ptr<MaceEngine> *engine) {
// load model
if (engine == nullptr) {
return MaceStatus::MACE_INVALID_ARGS;
...
...
@@ -68,7 +69,7 @@ MaceStatus CreateMaceEngine(
case {{ i }}:
model_data =
mace::{{model_tags[i]}}::LoadModelData(model_data_file);
net_def = mace::{{model_tags[i]}}::CreateNet(
model_pb
);
net_def = mace::{{model_tags[i]}}::CreateNet();
engine->reset(
new mace::MaceEngine(&net_def, device_type, input_nodes, output_nodes,
...
...
@@ -84,5 +85,22 @@ MaceStatus CreateMaceEngine(
return MaceStatus::MACE_SUCCESS;
}
{% else %}
MaceStatus CreateMaceEngine(
const std::string &model_name,
const char *model_data_file,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
const DeviceType device_type,
std::shared_ptr<MaceEngine> *engine) {
MACE_UNUSED(model_name);
MACE_UNUSED(model_data_file);
MACE_UNUSED(input_nodes);
MACE_UNUSED(output_nodes);
MACE_UNUSED(device_type);
MACE_UNUSED(engine);
return MaceStatus::MACE_INVALID_ARGS;
}
{% endif %}
} // namespace mace
mace/python/tools/mace_engine_factory_codegen.py
浏览文件 @
cefa11d4
...
...
@@ -20,7 +20,7 @@ from jinja2 import Environment, FileSystemLoader
FLAGS
=
None
def
gen_mace_engine_factory
(
model_tags
,
template_dir
,
output_dir
):
def
gen_mace_engine_factory
(
model_tags
,
template_dir
,
model_type
,
output_dir
):
# Create the jinja2 environment.
j2_env
=
Environment
(
loader
=
FileSystemLoader
(
template_dir
),
trim_blocks
=
True
)
...
...
@@ -29,6 +29,7 @@ def gen_mace_engine_factory(model_tags, template_dir, output_dir):
template_name
=
'mace_engine_factory.h.jinja2'
source
=
j2_env
.
get_template
(
template_name
).
render
(
model_tags
=
model_tags
,
model_type
=
model_type
,
)
with
open
(
output_dir
+
'/mace_engine_factory.h'
,
"wb"
)
as
f
:
f
.
write
(
source
)
...
...
@@ -46,10 +47,15 @@ def parse_args():
"--template_dir"
,
type
=
str
,
default
=
""
,
help
=
"template path"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
""
,
help
=
"output path"
)
parser
.
add_argument
(
"--model_type"
,
type
=
str
,
default
=
""
,
help
=
"[source|pb] model load type"
)
return
parser
.
parse_known_args
()
if
__name__
==
'__main__'
:
FLAGS
,
unparsed
=
parse_args
()
gen_mace_engine_creator
(
FLAGS
.
model_tag
,
FLAGS
.
template_dir
,
FLAGS
.
output_dir
)
FLAGS
.
model_type
,
FLAGS
.
output_dir
)
mace/python/tools/source_converter_lib.py
浏览文件 @
cefa11d4
...
...
@@ -184,20 +184,21 @@ def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
model_data
.
extend
(
tensor_info
.
data
)
offset
+=
len
(
tensor_info
.
data
)
# generate tensor data
template_name
=
'tensor_data.jinja2'
source
=
j2_env
.
get_template
(
template_name
).
render
(
tag
=
model_tag
,
embed_model_data
=
embed_model_data
,
model_data_size
=
offset
,
model_data
=
model_data
)
with
open
(
output_dir
+
'tensor_data'
+
'.cc'
,
"wb"
)
as
f
:
f
.
write
(
source
)
if
not
embed_model_data
:
with
open
(
output_dir
+
model_tag
+
'.data'
,
"wb"
)
as
f
:
f
.
write
(
bytearray
(
model_data
))
if
model_load_type
==
'source'
:
# generate tensor data
template_name
=
'tensor_data.jinja2'
source
=
j2_env
.
get_template
(
template_name
).
render
(
tag
=
model_tag
,
embed_model_data
=
embed_model_data
,
model_data_size
=
offset
,
model_data
=
model_data
)
with
open
(
output_dir
+
'tensor_data'
+
'.cc'
,
"wb"
)
as
f
:
f
.
write
(
source
)
# generate op source files
template_name
=
'operator.jinja2'
counter
=
0
...
...
@@ -214,35 +215,35 @@ def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
f
.
write
(
source
)
counter
+=
1
# generate model source files
build_time
=
datetime
.
datetime
.
now
().
strftime
(
'%Y-%m-%d %H:%M:%S'
)
template_name
=
'model.jinja2'
tensors
=
[
TensorInfo
(
i
,
net_def
.
tensors
[
i
],
runtime
)
for
i
in
range
(
len
(
net_def
.
tensors
))
]
checksum
=
model_checksum
if
weight_checksum
is
not
None
:
checksum
=
"{},{}"
.
format
(
model_checksum
,
weight_checksum
)
source
=
j2_env
.
get_template
(
template_name
).
render
(
tensors
=
tensors
,
net
=
net_def
,
tag
=
model_tag
,
runtime
=
runtime
,
obfuscate
=
obfuscate
,
embed_model_data
=
embed_model_data
,
winograd_conv
=
winograd_conv
,
checksum
=
checksum
,
build_time
=
build_time
,
model_type
=
model_load_type
)
with
open
(
output
,
"wb"
)
as
f
:
f
.
write
(
source
)
# generate model header file
template_name
=
'model_header.jinja2'
source
=
j2_env
.
get_template
(
template_name
).
render
(
tag
=
model_tag
,
)
with
open
(
output_dir
+
model_tag
+
'.h'
,
"wb"
)
as
f
:
f
.
write
(
source
)
# generate model source files
build_time
=
datetime
.
datetime
.
now
().
strftime
(
'%Y-%m-%d %H:%M:%S'
)
template_name
=
'model.jinja2'
tensors
=
[
TensorInfo
(
i
,
net_def
.
tensors
[
i
],
runtime
)
for
i
in
range
(
len
(
net_def
.
tensors
))
]
checksum
=
model_checksum
if
weight_checksum
is
not
None
:
checksum
=
"{},{}"
.
format
(
model_checksum
,
weight_checksum
)
source
=
j2_env
.
get_template
(
template_name
).
render
(
tensors
=
tensors
,
net
=
net_def
,
tag
=
model_tag
,
runtime
=
runtime
,
obfuscate
=
obfuscate
,
embed_model_data
=
embed_model_data
,
winograd_conv
=
winograd_conv
,
checksum
=
checksum
,
build_time
=
build_time
,
model_type
=
model_load_type
)
with
open
(
output
,
"wb"
)
as
f
:
f
.
write
(
source
)
# generate model header file
template_name
=
'model_header.jinja2'
source
=
j2_env
.
get_template
(
template_name
).
render
(
tag
=
model_tag
,
)
with
open
(
output_dir
+
model_tag
+
'.h'
,
"wb"
)
as
f
:
f
.
write
(
source
)
for
t
in
net_def
.
tensors
:
if
t
.
data_type
==
mace_pb2
.
DT_FLOAT
:
...
...
mace/tools/validation/mace_run.cc
浏览文件 @
cefa11d4
...
...
@@ -239,16 +239,15 @@ bool RunModel(const std::string &model_name,
LOG
(
FATAL
)
<<
"Failed to read file: "
<<
FLAGS_model_file
;
}
create_engine_status
=
CreateMaceEngine
(
model_name
.
c_str
(),
model_data_file_ptr
,
input_names
,
output_names
,
device_type
,
&
engine
,
model_pb_data
);
CreateMaceEngineFromPB
(
model_data_file_ptr
,
input_names
,
output_names
,
device_type
,
&
engine
,
model_pb_data
);
}
else
{
create_engine_status
=
CreateMaceEngine
(
model_name
.
c_str
()
,
CreateMaceEngine
(
model_name
,
model_data_file_ptr
,
input_names
,
output_names
,
...
...
tools/mace_tools.py
浏览文件 @
cefa11d4
...
...
@@ -205,8 +205,8 @@ def tuning_run(target_abi,
stdout
,
target_abi
,
serialno
,
model_name
,
device_type
)
def
build_mace_run_prod
(
hexagon_mode
,
runtime
,
target_abi
,
serialno
,
vlog_level
,
embed_model_data
,
def
build_mace_run_prod
(
hexagon_mode
,
runtime
,
target_abi
,
serialno
,
vlog_level
,
embed_model_data
,
model_load_type
,
model_output_dir
,
input_nodes
,
output_nodes
,
input_shapes
,
output_shapes
,
mace_model_dir
,
model_name
,
device_type
,
running_round
,
restart_round
,
...
...
@@ -228,7 +228,7 @@ def build_mace_run_prod(hexagon_mode, runtime, target_abi,
hexagon_mode
=
hexagon_mode
,
enable_openmp
=
enable_openmp
)
sh_commands
.
update_mace_run_lib
(
model_output_dir
,
sh_commands
.
update_mace_run_lib
(
model_output_dir
,
model_load_type
,
model_name
,
embed_model_data
)
device_type
=
parse_device_type
(
"gpu"
)
...
...
@@ -250,7 +250,7 @@ def build_mace_run_prod(hexagon_mode, runtime, target_abi,
debug
=
debug
,
enable_openmp
=
enable_openmp
)
sh_commands
.
update_mace_run_lib
(
model_output_dir
,
sh_commands
.
update_mace_run_lib
(
model_output_dir
,
model_load_type
,
model_name
,
embed_model_data
)
else
:
gen_opencl_and_tuning_code
(
target_abi
,
serialno
,
[],
False
)
...
...
@@ -263,7 +263,7 @@ def build_mace_run_prod(hexagon_mode, runtime, target_abi,
debug
=
debug
,
enable_openmp
=
enable_openmp
)
sh_commands
.
update_mace_run_lib
(
model_output_dir
,
sh_commands
.
update_mace_run_lib
(
model_output_dir
,
model_load_type
,
model_name
,
embed_model_data
)
...
...
@@ -274,11 +274,12 @@ def merge_libs_and_tuning_results(target_soc,
output_dir
,
model_output_dirs
,
mace_model_dirs_kv
,
model_load_type
,
hexagon_mode
,
embed_model_data
):
gen_opencl_and_tuning_code
(
target_abi
,
serialno
,
model_output_dirs
,
False
)
sh_commands
.
build_production_code
(
target_abi
)
sh_commands
.
build_production_code
(
model_load_type
,
target_abi
)
sh_commands
.
merge_libs
(
target_soc
,
target_abi
,
...
...
@@ -286,6 +287,7 @@ def merge_libs_and_tuning_results(target_soc,
output_dir
,
model_output_dirs
,
mace_model_dirs_kv
,
model_load_type
,
hexagon_mode
,
embed_model_data
)
...
...
@@ -370,6 +372,9 @@ def parse_model_configs():
print
(
"CONFIG ERROR:"
)
print
(
"embed_model_data must be integer in range [0, 1]"
)
exit
(
1
)
elif
FLAGS
.
model_load_type
==
"pb"
:
configs
[
"embed_model_data"
]
=
0
print
(
"emebed_model_data is set 0"
)
model_names
=
configs
.
get
(
"models"
,
""
)
if
not
model_names
:
...
...
@@ -599,6 +604,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
serialno
,
vlog_level
,
embed_model_data
,
model_load_type
,
model_output_dir
,
model_config
[
"input_nodes"
],
model_config
[
"output_nodes"
],
...
...
@@ -688,6 +694,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
FLAGS
.
output_dir
,
model_output_dirs
,
mace_model_dirs_kv
,
model_load_type
,
hexagon_mode
,
embed_model_data
)
...
...
@@ -748,7 +755,8 @@ def main(unused_args):
# generate source
sh_commands
.
gen_mace_version
()
sh_commands
.
gen_encrypted_opencl_source
()
sh_commands
.
gen_mace_engine_factory_source
(
configs
[
'models'
].
keys
())
sh_commands
.
gen_mace_engine_factory_source
(
configs
[
'models'
].
keys
(),
FLAGS
.
model_load_type
)
embed_model_data
=
configs
[
"embed_model_data"
]
target_socs
=
get_target_socs
(
configs
)
...
...
tools/sh_commands.py
浏览文件 @
cefa11d4
...
...
@@ -370,6 +370,7 @@ def gen_encrypted_opencl_source(codegen_path="mace/codegen"):
def
gen_mace_engine_factory_source
(
model_tags
,
model_load_type
,
codegen_path
=
"mace/codegen"
):
print
(
"* Genearte mace engine creator source"
)
codegen_tools_dir
=
"%s/engine"
%
codegen_path
...
...
@@ -378,6 +379,7 @@ def gen_mace_engine_factory_source(model_tags,
gen_mace_engine_factory
(
model_tags
,
"mace/python/tools"
,
model_load_type
,
codegen_tools_dir
)
print
(
"Genearte mace engine creator source done!
\n
"
)
...
...
@@ -547,6 +549,7 @@ def gen_random_input(model_output_dir,
def
update_mace_run_lib
(
model_output_dir
,
model_load_type
,
model_tag
,
embed_model_data
):
mace_run_filepath
=
model_output_dir
+
"/mace_run"
...
...
@@ -558,8 +561,9 @@ def update_mace_run_lib(model_output_dir,
sh
.
cp
(
"-f"
,
"mace/codegen/models/%s/%s.data"
%
(
model_tag
,
model_tag
),
model_output_dir
)
sh
.
cp
(
"-f"
,
"mace/codegen/models/%s/%s.h"
%
(
model_tag
,
model_tag
),
model_output_dir
)
if
model_load_type
==
"source"
:
sh
.
cp
(
"-f"
,
"mace/codegen/models/%s/%s.h"
%
(
model_tag
,
model_tag
),
model_output_dir
)
def
create_internal_storage_dir
(
serialno
,
phone_data_dir
):
...
...
@@ -833,13 +837,17 @@ def validate_model(abi,
print
(
"Validation done!
\n
"
)
def
build_production_code
(
abi
):
def
build_production_code
(
model_load_type
,
abi
):
bazel_build
(
"//mace/codegen:generated_opencl"
,
abi
=
abi
)
bazel_build
(
"//mace/codegen:generated_tuning_params"
,
abi
=
abi
)
if
abi
==
'host'
:
bazel_build
(
"//mace/codegen:generated_models"
,
abi
=
abi
)
if
model_load_type
==
"source"
:
bazel_build
(
"//mace/codegen:generated_models"
,
abi
=
abi
)
else
:
bazel_build
(
"//mace/core:core"
,
abi
=
abi
)
bazel_build
(
"//mace/ops:ops"
,
abi
=
abi
)
def
merge_libs
(
target_soc
,
...
...
@@ -848,6 +856,7 @@ def merge_libs(target_soc,
libmace_output_dir
,
model_output_dirs
,
mace_model_dirs_kv
,
model_load_type
,
hexagon_mode
,
embed_model_data
):
print
(
"* Merge mace lib"
)
...
...
@@ -879,12 +888,24 @@ def merge_libs(target_soc,
mri_stream
+=
(
"addlib "
"bazel-bin/mace/codegen/libgenerated_tuning_params.pic.a
\n
"
)
mri_stream
+=
(
"addlib "
"bazel-bin/mace/codegen/libgenerated_models.pic.a
\n
"
)
if
model_load_type
==
"source"
:
mri_stream
+=
(
"addlib "
"bazel-bin/mace/codegen/libgenerated_models.pic.a
\n
"
)
else
:
mri_stream
+=
(
"addlib "
"bazel-bin/mace/core/libcore.pic.a
\n
"
)
mri_stream
+=
(
"addlib "
"bazel-bin/mace/ops/libops.pic.lo
\n
"
)
else
:
mri_stream
+=
"create %s/libmace_%s.%s.a
\n
"
%
\
(
model_bin_dir
,
project_name
,
target_soc
)
if
model_load_type
==
"source"
:
mri_stream
+=
(
"addlib "
"bazel-bin/mace/codegen/libgenerated_models.a
\n
"
)
mri_stream
+=
(
"addlib "
"bazel-bin/mace/codegen/libgenerated_opencl.a
\n
"
)
...
...
@@ -894,9 +915,6 @@ def merge_libs(target_soc,
mri_stream
+=
(
"addlib "
"bazel-bin/mace/codegen/libgenerated_version.a
\n
"
)
mri_stream
+=
(
"addlib "
"bazel-bin/mace/codegen/libgenerated_models.a
\n
"
)
mri_stream
+=
(
"addlib "
"bazel-bin/mace/core/libcore.a
\n
"
)
...
...
@@ -909,6 +927,12 @@ def merge_libs(target_soc,
mri_stream
+=
(
"addlib "
"bazel-bin/mace/utils/libutils_prod.a
\n
"
)
mri_stream
+=
(
"addlib "
"bazel-bin/mace/proto/libmace_cc.a
\n
"
)
mri_stream
+=
(
"addlib "
"bazel-bin/external/com_google_protobuf/libprotobuf_lite.a
\n
"
)
mri_stream
+=
(
"addlib "
"bazel-bin/mace/ops/libops.lo
\n
"
)
...
...
@@ -917,7 +941,8 @@ def merge_libs(target_soc,
if
not
embed_model_data
:
sh
.
cp
(
"-f"
,
glob
.
glob
(
"%s/*.data"
%
model_output_dir
),
model_data_dir
)
sh
.
cp
(
"-f"
,
glob
.
glob
(
"%s/*.h"
%
model_output_dir
),
model_header_dir
)
if
model_load_type
==
"source"
:
sh
.
cp
(
"-f"
,
glob
.
glob
(
"%s/*.h"
%
model_output_dir
),
model_header_dir
)
for
model_name
in
mace_model_dirs_kv
:
sh
.
cp
(
"-f"
,
"%s/%s.pb"
%
(
mace_model_dirs_kv
[
model_name
],
model_name
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录