提交 a5bdbc62 编写于 作者: W wuchenghui

add EMBED_MODE_DATA option

上级 24ca9183
...@@ -10,13 +10,9 @@ ...@@ -10,13 +10,9 @@
namespace mace { namespace mace {
namespace {{tag}} { namespace {{tag}} {
{% if tensor_info.data_type != 'DT_UINT8' %} alignas(4) {% endif %} unsigned char {{ tensor_info.name }}[] = { void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors, const unsigned char *model_data) {
{% for d in tensor_info.data %}{{"0x%02X, " % d }}{%endfor%}
};
void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors) {
tensors.emplace_back(mace::ConstTensor( tensors.emplace_back(mace::ConstTensor(
{{ tensor.name|tojson }}, {{ tensor.name }}, {{ tensor.name|tojson }}, const_cast<unsigned char *>(model_data) + {{ offset }},
{ {{ tensor.dims|join(', ') }} }, {{ tensor_info.data_type }}, {{ tensor.node_id }})); { {{ tensor.dims|join(', ') }} }, {{ tensor_info.data_type }}, {{ tensor.node_id }}));
} }
...@@ -24,6 +20,42 @@ void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors) { ...@@ -24,6 +20,42 @@ void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors) {
} // namespace mace } // namespace mace
{% elif mode == 1 %} {% elif mode == 1 %}
{% if not embed_model_data %}
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
{% endif %}
namespace mace {
namespace {{tag}} {
{% if embed_model_data %}
alignas(4) unsigned char model_data[{{ model_data_size }}] = {
{% for d in model_data %}{{"0x%02X, " % d }}{%endfor%}
};
{% endif %}
unsigned char *LoadModelData(const char *model_data_file) {
{% if embed_model_data %}
return model_data;
{% else %}
int fd=open(model_data_file, O_RDONLY);
unsigned char *model_data = (unsigned char *)mmap(nullptr, {{ model_data_size }}, PROT_READ, MAP_PRIVATE, fd, 0);
close(fd);
return model_data;
{% endif %}
}
void UnloadModelData(unsigned char *model_data) {
{% if not embed_model_data %}
munmap(model_data, {{ model_data_size }});
{% endif %}
}
} // namespace {{tag}}
} // namespace mace
{% elif mode == 2 %}
#include <vector> #include <vector>
#include <string> #include <string>
#include "mace/core/public/mace.h" #include "mace/core/public/mace.h"
...@@ -134,7 +166,7 @@ namespace mace { ...@@ -134,7 +166,7 @@ namespace mace {
namespace {{tag}} { namespace {{tag}} {
{% for tensor in tensors %} {% for tensor in tensors %}
extern void Create{{ tensor.name }}(std::vector<mace::ConstTensor> &tensors); extern void Create{{ tensor.name }}(std::vector<mace::ConstTensor> &tensors, const unsigned char *model_data);
{% endfor %} {% endfor %}
...@@ -209,12 +241,12 @@ void CreateOperators(std::vector<mace::OperatorDef> &ops) { ...@@ -209,12 +241,12 @@ void CreateOperators(std::vector<mace::OperatorDef> &ops) {
} }
void CreateTensors(std::vector<mace::ConstTensor> &tensors) { void CreateTensors(std::vector<mace::ConstTensor> &tensors, const unsigned char *model_data) {
tensors.reserve({{ net.tensors|length }}); tensors.reserve({{ net.tensors|length }});
{% for tensor in net.tensors %} {% for tensor in net.tensors %}
mace::{{tag}}::Create{{tensor.name}}(tensors); mace::{{tag}}::Create{{tensor.name}}(tensors, model_data);
{% endfor %} {% endfor %}
} }
...@@ -239,7 +271,7 @@ void CreateMemoryArena(mace::MemoryArena &mem_arena) { ...@@ -239,7 +271,7 @@ void CreateMemoryArena(mace::MemoryArena &mem_arena) {
namespace mace { namespace mace {
namespace {{tag}} { namespace {{tag}} {
NetDef CreateNet() { NetDef CreateNet(const unsigned char *model_data) {
NetDef net_def; NetDef net_def;
net_def.set_name("{{ net.name}}"); net_def.set_name("{{ net.name}}");
net_def.set_version("{{ net.version }}"); net_def.set_version("{{ net.version }}");
...@@ -250,7 +282,7 @@ NetDef CreateNet() { ...@@ -250,7 +282,7 @@ NetDef CreateNet() {
CreateOperators(net_def.mutable_op()); CreateOperators(net_def.mutable_op());
CreateTensors(net_def.mutable_tensors()); CreateTensors(net_def.mutable_tensors(), model_data);
{% if net.mem_arena.mem_block|length != 0 %} {% if net.mem_arena.mem_block|length != 0 %}
CreateMemoryArena(net_def.mutable_mem_arena()); CreateMemoryArena(net_def.mutable_mem_arena());
......
...@@ -91,7 +91,7 @@ class TensorInfo: ...@@ -91,7 +91,7 @@ class TensorInfo:
def stringfy(value): def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value) return ', '.join('"{0}"'.format(w) for w in value)
def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime): def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime, embed_model_data):
if obfuscate: if obfuscate:
obfuscate_name(net_def) obfuscate_name(net_def)
else: else:
...@@ -109,18 +109,44 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -109,18 +109,44 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
counter = 0 counter = 0
output_dir = os.path.dirname(output) + '/' output_dir = os.path.dirname(output) + '/'
# generate tensor source files # generate tensor source files
model_data = []
offset = 0
for t in net_def.tensors: for t in net_def.tensors:
tensor_info = TensorInfo(t, runtime)
# align
if tensor_info.data_type != 'DT_UINT8' and offset % 4 != 0:
padding = 4 - offset % 4
model_data.extend(bytearray([0] * padding))
offset += padding
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
tensor_info = TensorInfo(t, runtime), tensor_info = TensorInfo(t, runtime),
tensor = t, tensor = t,
tag = model_tag, tag = model_tag,
mode = 0, mode = 0,
runtime = runtime, runtime = runtime,
offset = offset,
) )
model_data.extend(tensor_info.data)
offset += len(tensor_info.data)
with gfile.GFile(output_dir + 'tensor' + str(counter) + '.cc', "wb") as f: with gfile.GFile(output_dir + 'tensor' + str(counter) + '.cc', "wb") as f:
f.write(source) f.write(source)
counter += 1 counter += 1
# generate tensor data
source = j2_env.get_template(template_name).render(
tag = model_tag,
mode = 1,
embed_model_data = embed_model_data,
model_data_size = offset,
model_data = model_data
)
with gfile.GFile(output_dir + 'tensor_data' + '.cc', "wb") as f:
f.write(source)
if not embed_model_data:
f = open(output_dir + model_tag + '.data', "wb")
f.write(bytearray(model_data))
f.close()
# generate op source files # generate op source files
counter = 0 counter = 0
op_size = len(net_def.op) op_size = len(net_def.op)
...@@ -130,7 +156,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -130,7 +156,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
end = min(start+10, op_size), end = min(start+10, op_size),
net = net_def, net = net_def,
tag = model_tag, tag = model_tag,
mode = 1, mode = 2,
runtime = runtime, runtime = runtime,
) )
with gfile.GFile(output_dir + 'op' + str(counter) + '.cc', "wb") as f: with gfile.GFile(output_dir + 'op' + str(counter) + '.cc', "wb") as f:
...@@ -143,9 +169,9 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -143,9 +169,9 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
tensors = tensors, tensors = tensors,
net = net_def, net = net_def,
tag = model_tag, tag = model_tag,
mode = 2, mode = 3,
runtime = runtime, runtime = runtime,
model_pb_checksum = mode_pb_checksum, model_pb_checksum = mode_pb_checksum
) )
with gfile.GFile(output, "wb") as f: with gfile.GFile(output, "wb") as f:
f.write(source) f.write(source)
...@@ -43,7 +43,7 @@ def main(unused_args): ...@@ -43,7 +43,7 @@ def main(unused_args):
if FLAGS.output_type == 'source': if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime) FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data)
else: else:
with gfile.GFile(FLAGS.output, "wb") as f: with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString()) f.write(output_graph_def.SerializeToString())
...@@ -133,6 +133,11 @@ def parse_args(): ...@@ -133,6 +133,11 @@ def parse_args():
type=str, type=str,
default="", default="",
help="input shape.") help="input shape.")
parser.add_argument(
"--embed_model_data",
type=str2bool,
default=True,
help="input shape.")
return parser.parse_known_args() return parser.parse_known_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册