diff --git a/python/tools/model.template b/python/tools/model.template index ff1b531fb164f377036f9b5cf7898b74f280895a..a11ccf850fc0c3817d7dc7767450ad29cbcf84bc 100644 --- a/python/tools/model.template +++ b/python/tools/model.template @@ -10,13 +10,9 @@ namespace mace { namespace {{tag}} { -{% if tensor_info.data_type != 'DT_UINT8' %} alignas(4) {% endif %} unsigned char {{ tensor_info.name }}[] = { -{% for d in tensor_info.data %}{{"0x%02X, " % d }}{%endfor%} -}; - -void Create{{tensor.name}}(std::vector &tensors) { +void Create{{tensor.name}}(std::vector &tensors, const unsigned char *model_data) { tensors.emplace_back(mace::ConstTensor( - {{ tensor.name|tojson }}, {{ tensor.name }}, + {{ tensor.name|tojson }}, const_cast(model_data) + {{ offset }}, { {{ tensor.dims|join(', ') }} }, {{ tensor_info.data_type }}, {{ tensor.node_id }})); } @@ -24,6 +20,42 @@ void Create{{tensor.name}}(std::vector &tensors) { } // namespace mace {% elif mode == 1 %} +{% if not embed_model_data %} +#include +#include +#include +{% endif %} + +namespace mace { +namespace {{tag}} { + +{% if embed_model_data %} +alignas(4) unsigned char model_data[{{ model_data_size }}] = { +{% for d in model_data %}{{"0x%02X, " % d }}{%endfor%} +}; +{% endif %} + +unsigned char *LoadModelData(const char *model_data_file) { +{% if embed_model_data %} + return model_data; +{% else %} + int fd=open(model_data_file, O_RDONLY); + unsigned char *model_data = (unsigned char *)mmap(nullptr, {{ model_data_size }}, PROT_READ, MAP_PRIVATE, fd, 0); + close(fd); + return model_data; +{% endif %} +} + +void UnloadModelData(unsigned char *model_data) { +{% if not embed_model_data %} + munmap(model_data, {{ model_data_size }}); +{% endif %} +} + +} // namespace {{tag}} +} // namespace mace + +{% elif mode == 2 %} #include #include #include "mace/core/public/mace.h" @@ -134,7 +166,7 @@ namespace mace { namespace {{tag}} { {% for tensor in tensors %} -extern void Create{{ tensor.name }}(std::vector &tensors); +extern void Create{{ tensor.name }}(std::vector &tensors, const unsigned char *model_data); {% endfor %} @@ -209,12 +241,12 @@ void CreateOperators(std::vector &ops) { } -void CreateTensors(std::vector &tensors) { +void CreateTensors(std::vector &tensors, const unsigned char *model_data) { tensors.reserve({{ net.tensors|length }}); {% for tensor in net.tensors %} - mace::{{tag}}::Create{{tensor.name}}(tensors); + mace::{{tag}}::Create{{tensor.name}}(tensors, model_data); {% endfor %} } @@ -239,7 +271,7 @@ void CreateMemoryArena(mace::MemoryArena &mem_arena) { namespace mace { namespace {{tag}} { -NetDef CreateNet() { +NetDef CreateNet(const unsigned char *model_data) { NetDef net_def; net_def.set_name("{{ net.name}}"); net_def.set_version("{{ net.version }}"); @@ -250,7 +282,7 @@ NetDef CreateNet() { CreateOperators(net_def.mutable_op()); - CreateTensors(net_def.mutable_tensors()); + CreateTensors(net_def.mutable_tensors(), model_data); {% if net.mem_arena.mem_block|length != 0 %} CreateMemoryArena(net_def.mutable_mem_arena()); diff --git a/python/tools/source_converter_lib.py b/python/tools/source_converter_lib.py index d842ffab49631f597cfb9e8925cd2f876e029106..c7061b98b65b425a5cdfe0901fa054609427ea78 100644 --- a/python/tools/source_converter_lib.py +++ b/python/tools/source_converter_lib.py @@ -91,7 +91,7 @@ class TensorInfo: def stringfy(value): return ', '.join('"{0}"'.format(w) for w in value) -def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime): +def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime, embed_model_data): if obfuscate: obfuscate_name(net_def) else: @@ -109,18 +109,44 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, counter = 0 output_dir = os.path.dirname(output) + '/' # generate tensor source files + model_data = [] + offset = 0 for t in net_def.tensors: + tensor_info = TensorInfo(t, runtime) + # align + if tensor_info.data_type != 'DT_UINT8' and offset % 4 != 0: + padding = 4 - offset % 4 + model_data.extend(bytearray([0] * padding)) + offset += padding source = j2_env.get_template(template_name).render( tensor_info = TensorInfo(t, runtime), tensor = t, tag = model_tag, mode = 0, runtime = runtime, + offset = offset, ) + model_data.extend(tensor_info.data) + offset += len(tensor_info.data) with gfile.GFile(output_dir + 'tensor' + str(counter) + '.cc', "wb") as f: f.write(source) counter += 1 + # generate tensor data + source = j2_env.get_template(template_name).render( + tag = model_tag, + mode = 1, + embed_model_data = embed_model_data, + model_data_size = offset, + model_data = model_data + ) + with gfile.GFile(output_dir + 'tensor_data' + '.cc', "wb") as f: + f.write(source) + if not embed_model_data: + f = open(output_dir + model_tag + '.data', "wb") + f.write(bytearray(model_data)) + f.close() + # generate op source files counter = 0 op_size = len(net_def.op) @@ -130,7 +156,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, end = min(start+10, op_size), net = net_def, tag = model_tag, - mode = 1, + mode = 2, runtime = runtime, ) with gfile.GFile(output_dir + 'op' + str(counter) + '.cc', "wb") as f: @@ -143,9 +169,9 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, tensors = tensors, net = net_def, tag = model_tag, - mode = 2, + mode = 3, runtime = runtime, - model_pb_checksum = mode_pb_checksum, + model_pb_checksum = mode_pb_checksum ) with gfile.GFile(output, "wb") as f: f.write(source) diff --git a/python/tools/tf_converter.py b/python/tools/tf_converter.py index 0da5420fd954646679511335647d63a63867c1c1..4ad8bddd1531c81aacdecc08f5d9081905876a30 100644 --- a/python/tools/tf_converter.py +++ b/python/tools/tf_converter.py @@ -43,7 +43,7 @@ def main(unused_args): if FLAGS.output_type == 'source': source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, - FLAGS.model_tag, FLAGS.output, FLAGS.runtime) + FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data) else: with gfile.GFile(FLAGS.output, "wb") as f: f.write(output_graph_def.SerializeToString()) @@ -133,6 +133,11 @@ def parse_args(): type=str, default="", help="input shape.") + parser.add_argument( + "--embed_model_data", + type=str2bool, + default=True, + help="input shape.") return parser.parse_known_args()