提交 a5bdbc62 编写于 作者: W wuchenghui

add EMBED_MODE_DATA option

上级 24ca9183
......@@ -10,13 +10,9 @@
namespace mace {
namespace {{tag}} {
{% if tensor_info.data_type != 'DT_UINT8' %} alignas(4) {% endif %} unsigned char {{ tensor_info.name }}[] = {
{% for d in tensor_info.data %}{{"0x%02X, " % d }}{%endfor%}
};
void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors) {
void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors, const unsigned char *model_data) {
tensors.emplace_back(mace::ConstTensor(
{{ tensor.name|tojson }}, {{ tensor.name }},
{{ tensor.name|tojson }}, const_cast<unsigned char *>(model_data) + {{ offset }},
{ {{ tensor.dims|join(', ') }} }, {{ tensor_info.data_type }}, {{ tensor.node_id }}));
}
......@@ -24,6 +20,42 @@ void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors) {
} // namespace mace
{% elif mode == 1 %}
{% if not embed_model_data %}
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
{% endif %}
namespace mace {
namespace {{tag}} {
{% if embed_model_data %}
alignas(4) unsigned char model_data[{{ model_data_size }}] = {
{% for d in model_data %}{{"0x%02X, " % d }}{%endfor%}
};
{% endif %}
unsigned char *LoadModelData(const char *model_data_file) {
{% if embed_model_data %}
return model_data;
{% else %}
int fd=open(model_data_file, O_RDONLY);
unsigned char *model_data = (unsigned char *)mmap(nullptr, {{ model_data_size }}, PROT_READ, MAP_PRIVATE, fd, 0);
close(fd);
return model_data;
{% endif %}
}
void UnloadModelData(unsigned char *model_data) {
{% if not embed_model_data %}
munmap(model_data, {{ model_data_size }});
{% endif %}
}
} // namespace {{tag}}
} // namespace mace
{% elif mode == 2 %}
#include <vector>
#include <string>
#include "mace/core/public/mace.h"
......@@ -134,7 +166,7 @@ namespace mace {
namespace {{tag}} {
{% for tensor in tensors %}
extern void Create{{ tensor.name }}(std::vector<mace::ConstTensor> &tensors);
extern void Create{{ tensor.name }}(std::vector<mace::ConstTensor> &tensors, const unsigned char *model_data);
{% endfor %}
......@@ -209,12 +241,12 @@ void CreateOperators(std::vector<mace::OperatorDef> &ops) {
}
void CreateTensors(std::vector<mace::ConstTensor> &tensors) {
void CreateTensors(std::vector<mace::ConstTensor> &tensors, const unsigned char *model_data) {
tensors.reserve({{ net.tensors|length }});
{% for tensor in net.tensors %}
mace::{{tag}}::Create{{tensor.name}}(tensors);
mace::{{tag}}::Create{{tensor.name}}(tensors, model_data);
{% endfor %}
}
......@@ -239,7 +271,7 @@ void CreateMemoryArena(mace::MemoryArena &mem_arena) {
namespace mace {
namespace {{tag}} {
NetDef CreateNet() {
NetDef CreateNet(const unsigned char *model_data) {
NetDef net_def;
net_def.set_name("{{ net.name}}");
net_def.set_version("{{ net.version }}");
......@@ -250,7 +282,7 @@ NetDef CreateNet() {
CreateOperators(net_def.mutable_op());
CreateTensors(net_def.mutable_tensors());
CreateTensors(net_def.mutable_tensors(), model_data);
{% if net.mem_arena.mem_block|length != 0 %}
CreateMemoryArena(net_def.mutable_mem_arena());
......
......@@ -91,7 +91,7 @@ class TensorInfo:
def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value)
def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime):
def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime, embed_model_data):
if obfuscate:
obfuscate_name(net_def)
else:
......@@ -109,18 +109,44 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
counter = 0
output_dir = os.path.dirname(output) + '/'
# generate tensor source files
model_data = []
offset = 0
for t in net_def.tensors:
tensor_info = TensorInfo(t, runtime)
# align
if tensor_info.data_type != 'DT_UINT8' and offset % 4 != 0:
padding = 4 - offset % 4
model_data.extend(bytearray([0] * padding))
offset += padding
source = j2_env.get_template(template_name).render(
tensor_info = TensorInfo(t, runtime),
tensor = t,
tag = model_tag,
mode = 0,
runtime = runtime,
offset = offset,
)
model_data.extend(tensor_info.data)
offset += len(tensor_info.data)
with gfile.GFile(output_dir + 'tensor' + str(counter) + '.cc', "wb") as f:
f.write(source)
counter += 1
# generate tensor data
source = j2_env.get_template(template_name).render(
tag = model_tag,
mode = 1,
embed_model_data = embed_model_data,
model_data_size = offset,
model_data = model_data
)
with gfile.GFile(output_dir + 'tensor_data' + '.cc', "wb") as f:
f.write(source)
if not embed_model_data:
f = open(output_dir + model_tag + '.data', "wb")
f.write(bytearray(model_data))
f.close()
# generate op source files
counter = 0
op_size = len(net_def.op)
......@@ -130,7 +156,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
end = min(start+10, op_size),
net = net_def,
tag = model_tag,
mode = 1,
mode = 2,
runtime = runtime,
)
with gfile.GFile(output_dir + 'op' + str(counter) + '.cc', "wb") as f:
......@@ -143,9 +169,9 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
tensors = tensors,
net = net_def,
tag = model_tag,
mode = 2,
mode = 3,
runtime = runtime,
model_pb_checksum = mode_pb_checksum,
model_pb_checksum = mode_pb_checksum
)
with gfile.GFile(output, "wb") as f:
f.write(source)
......@@ -43,7 +43,7 @@ def main(unused_args):
if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime)
FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data)
else:
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
......@@ -133,6 +133,11 @@ def parse_args():
type=str,
default="",
help="input shape.")
parser.add_argument(
"--embed_model_data",
type=str2bool,
default=True,
help="input shape.")
return parser.parse_known_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册