Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
慢慢CG
Mace
提交
a5bdbc62
Mace
项目概览
慢慢CG
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a5bdbc62
编写于
2月 27, 2018
作者:
W
wuchenghui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add EMBED_MODE_DATA option
上级
24ca9183
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
79 addition
and
16 deletion
+79
-16
python/tools/model.template
python/tools/model.template
+43
-11
python/tools/source_converter_lib.py
python/tools/source_converter_lib.py
+30
-4
python/tools/tf_converter.py
python/tools/tf_converter.py
+6
-1
未找到文件。
python/tools/model.template
浏览文件 @
a5bdbc62
...
...
@@ -10,13 +10,9 @@
namespace mace {
namespace {{tag}} {
{% if tensor_info.data_type != 'DT_UINT8' %} alignas(4) {% endif %} unsigned char {{ tensor_info.name }}[] = {
{% for d in tensor_info.data %}{{"0x%02X, " % d }}{%endfor%}
};
void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors) {
void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors, const unsigned char *model_data) {
tensors.emplace_back(mace::ConstTensor(
{{ tensor.name|tojson }},
{{ tensor.name
}},
{{ tensor.name|tojson }},
const_cast<unsigned char *>(model_data) + {{ offset
}},
{ {{ tensor.dims|join(', ') }} }, {{ tensor_info.data_type }}, {{ tensor.node_id }}));
}
...
...
@@ -24,6 +20,42 @@ void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors) {
} // namespace mace
{% elif mode == 1 %}
{% if not embed_model_data %}
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
{% endif %}
namespace mace {
namespace {{tag}} {
{% if embed_model_data %}
alignas(4) unsigned char model_data[{{ model_data_size }}] = {
{% for d in model_data %}{{"0x%02X, " % d }}{%endfor%}
};
{% endif %}
unsigned char *LoadModelData(const char *model_data_file) {
{% if embed_model_data %}
return model_data;
{% else %}
int fd=open(model_data_file, O_RDONLY);
unsigned char *model_data = (unsigned char *)mmap(nullptr, {{ model_data_size }}, PROT_READ, MAP_PRIVATE, fd, 0);
close(fd);
return model_data;
{% endif %}
}
void UnloadModelData(unsigned char *model_data) {
{% if not embed_model_data %}
munmap(model_data, {{ model_data_size }});
{% endif %}
}
} // namespace {{tag}}
} // namespace mace
{% elif mode == 2 %}
#include <vector>
#include <string>
#include "mace/core/public/mace.h"
...
...
@@ -134,7 +166,7 @@ namespace mace {
namespace {{tag}} {
{% for tensor in tensors %}
extern void Create{{ tensor.name }}(std::vector<mace::ConstTensor> &tensors);
extern void Create{{ tensor.name }}(std::vector<mace::ConstTensor> &tensors
, const unsigned char *model_data
);
{% endfor %}
...
...
@@ -209,12 +241,12 @@ void CreateOperators(std::vector<mace::OperatorDef> &ops) {
}
void CreateTensors(std::vector<mace::ConstTensor> &tensors) {
void CreateTensors(std::vector<mace::ConstTensor> &tensors
, const unsigned char *model_data
) {
tensors.reserve({{ net.tensors|length }});
{% for tensor in net.tensors %}
mace::{{tag}}::Create{{tensor.name}}(tensors);
mace::{{tag}}::Create{{tensor.name}}(tensors
, model_data
);
{% endfor %}
}
...
...
@@ -239,7 +271,7 @@ void CreateMemoryArena(mace::MemoryArena &mem_arena) {
namespace mace {
namespace {{tag}} {
NetDef CreateNet() {
NetDef CreateNet(
const unsigned char *model_data
) {
NetDef net_def;
net_def.set_name("{{ net.name}}");
net_def.set_version("{{ net.version }}");
...
...
@@ -250,7 +282,7 @@ NetDef CreateNet() {
CreateOperators(net_def.mutable_op());
CreateTensors(net_def.mutable_tensors());
CreateTensors(net_def.mutable_tensors()
, model_data
);
{% if net.mem_arena.mem_block|length != 0 %}
CreateMemoryArena(net_def.mutable_mem_arena());
...
...
python/tools/source_converter_lib.py
浏览文件 @
a5bdbc62
...
...
@@ -91,7 +91,7 @@ class TensorInfo:
def
stringfy
(
value
):
return
', '
.
join
(
'"{0}"'
.
format
(
w
)
for
w
in
value
)
def
convert_to_source
(
net_def
,
mode_pb_checksum
,
template
,
obfuscate
,
model_tag
,
output
,
runtime
):
def
convert_to_source
(
net_def
,
mode_pb_checksum
,
template
,
obfuscate
,
model_tag
,
output
,
runtime
,
embed_model_data
):
if
obfuscate
:
obfuscate_name
(
net_def
)
else
:
...
...
@@ -109,18 +109,44 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
counter
=
0
output_dir
=
os
.
path
.
dirname
(
output
)
+
'/'
# generate tensor source files
model_data
=
[]
offset
=
0
for
t
in
net_def
.
tensors
:
tensor_info
=
TensorInfo
(
t
,
runtime
)
# align
if
tensor_info
.
data_type
!=
'DT_UINT8'
and
offset
%
4
!=
0
:
padding
=
4
-
offset
%
4
model_data
.
extend
(
bytearray
([
0
]
*
padding
))
offset
+=
padding
source
=
j2_env
.
get_template
(
template_name
).
render
(
tensor_info
=
TensorInfo
(
t
,
runtime
),
tensor
=
t
,
tag
=
model_tag
,
mode
=
0
,
runtime
=
runtime
,
offset
=
offset
,
)
model_data
.
extend
(
tensor_info
.
data
)
offset
+=
len
(
tensor_info
.
data
)
with
gfile
.
GFile
(
output_dir
+
'tensor'
+
str
(
counter
)
+
'.cc'
,
"wb"
)
as
f
:
f
.
write
(
source
)
counter
+=
1
# generate tensor data
source
=
j2_env
.
get_template
(
template_name
).
render
(
tag
=
model_tag
,
mode
=
1
,
embed_model_data
=
embed_model_data
,
model_data_size
=
offset
,
model_data
=
model_data
)
with
gfile
.
GFile
(
output_dir
+
'tensor_data'
+
'.cc'
,
"wb"
)
as
f
:
f
.
write
(
source
)
if
not
embed_model_data
:
f
=
open
(
output_dir
+
model_tag
+
'.data'
,
"wb"
)
f
.
write
(
bytearray
(
model_data
))
f
.
close
()
# generate op source files
counter
=
0
op_size
=
len
(
net_def
.
op
)
...
...
@@ -130,7 +156,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
end
=
min
(
start
+
10
,
op_size
),
net
=
net_def
,
tag
=
model_tag
,
mode
=
1
,
mode
=
2
,
runtime
=
runtime
,
)
with
gfile
.
GFile
(
output_dir
+
'op'
+
str
(
counter
)
+
'.cc'
,
"wb"
)
as
f
:
...
...
@@ -143,9 +169,9 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
tensors
=
tensors
,
net
=
net_def
,
tag
=
model_tag
,
mode
=
2
,
mode
=
3
,
runtime
=
runtime
,
model_pb_checksum
=
mode_pb_checksum
,
model_pb_checksum
=
mode_pb_checksum
)
with
gfile
.
GFile
(
output
,
"wb"
)
as
f
:
f
.
write
(
source
)
python/tools/tf_converter.py
浏览文件 @
a5bdbc62
...
...
@@ -43,7 +43,7 @@ def main(unused_args):
if
FLAGS
.
output_type
==
'source'
:
source_converter_lib
.
convert_to_source
(
output_graph_def
,
mode_pb_checksum
,
FLAGS
.
template
,
FLAGS
.
obfuscate
,
FLAGS
.
model_tag
,
FLAGS
.
output
,
FLAGS
.
runtime
)
FLAGS
.
model_tag
,
FLAGS
.
output
,
FLAGS
.
runtime
,
FLAGS
.
embed_model_data
)
else
:
with
gfile
.
GFile
(
FLAGS
.
output
,
"wb"
)
as
f
:
f
.
write
(
output_graph_def
.
SerializeToString
())
...
...
@@ -133,6 +133,11 @@ def parse_args():
type
=
str
,
default
=
""
,
help
=
"input shape."
)
parser
.
add_argument
(
"--embed_model_data"
,
type
=
str2bool
,
default
=
True
,
help
=
"input shape."
)
return
parser
.
parse_known_args
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录