diff --git a/python/tools/model_header.template b/python/tools/model_header.template new file mode 100644 index 0000000000000000000000000000000000000000..9f5c776d52bd6456bf3c410216f5b4de1ce1fa58 --- /dev/null +++ b/python/tools/model_header.template @@ -0,0 +1,22 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// Generated by the mace converter. DO NOT EDIT! +// + +#include + +#include "mace/public/mace.h" + +namespace mace { +namespace {{tag}} { + +extern const unsigned char *LoadModelData(const char *model_data_file); + +extern void UnloadModelData(const unsigned char *model_data); + +extern NetDef CreateNet(const unsigned char *model_data); + +extern const std::string ModelChecksum(); + +} // namespace {{ tag }} +} // namespace mace diff --git a/python/tools/source_converter_lib.py b/python/tools/source_converter_lib.py index 48620344c084c1cd30069c362f07ffacf5c591f3..8dd1cd7c2befea4ce9e64dce91aa1531a3d00794 100644 --- a/python/tools/source_converter_lib.py +++ b/python/tools/source_converter_lib.py @@ -177,3 +177,11 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_ ) with open(output, "wb") as f: f.write(source) + + # generate model header file + template_name = 'model_header.template' + source = j2_env.get_template(template_name).render( + tag = model_tag, + ) + with open(output_dir + model_tag + '.h', "wb") as f: + f.write(source)