diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py index 056a0b2c1077ea4dc45abcfe68e9c47d6cb25ea4..0f367da0fe81a9421cfdb200aea15cd9048310a4 100644 --- a/mace/python/tools/converter.py +++ b/mace/python/tools/converter.py @@ -45,6 +45,7 @@ def main(unused_args): FLAGS.model_checksum)) sys.exit(-1) + weight_checksum = None if FLAGS.platform == 'caffe': if not os.path.isfile(FLAGS.weight_file): print("Input weight file '" + FLAGS.weight_file + @@ -82,8 +83,8 @@ def main(unused_args): if FLAGS.output_type == 'source': source_converter_lib.convert_to_source( - output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate, - FLAGS.model_tag, FLAGS.output, FLAGS.runtime, + output_graph_def, model_checksum, weight_checksum, FLAGS.template, + FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data, FLAGS.winograd) else: with open(FLAGS.output, "wb") as f: diff --git a/mace/python/tools/model.jinja2 b/mace/python/tools/model.jinja2 index 3555b20513fc3bdeabfc1a51f8a8ce73c32d1dd3..bd228229b5e339b22172ddda650af2340b169a28 100644 --- a/mace/python/tools/model.jinja2 +++ b/mace/python/tools/model.jinja2 @@ -159,7 +159,7 @@ const std::string ModelName() { } const std::string ModelChecksum() { - return {{ model_pb_checksum|tojson }}; + return {{ checksum|tojson }}; } const std::string ModelBuildTime() { diff --git a/mace/python/tools/source_converter_lib.py b/mace/python/tools/source_converter_lib.py index c9fda3de87fdb76e2a9109a984510c566564259c..9e57d02bb1863626344dfdefa6811dcf72ab95cd 100644 --- a/mace/python/tools/source_converter_lib.py +++ b/mace/python/tools/source_converter_lib.py @@ -124,8 +124,8 @@ def stringfy(value): return ', '.join('"{0}"'.format(w) for w in value) -def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, - model_tag, output, runtime, embed_model_data, +def convert_to_source(net_def, model_checksum, weight_checksum, template_dir, + obfuscate, model_tag, output, runtime, embed_model_data, winograd_conv): if obfuscate: obfuscate_name(net_def) @@ -201,6 +201,9 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, TensorInfo(i, net_def.tensors[i], runtime) for i in range(len(net_def.tensors)) ] + checksum = model_checksum + if weight_checksum is not None: + checksum = "{},{}".format(model_checksum, weight_checksum) source = j2_env.get_template(template_name).render( tensors=tensors, net=net_def, @@ -209,7 +212,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, obfuscate=obfuscate, embed_model_data=embed_model_data, winograd_conv=winograd_conv, - model_pb_checksum=mode_pb_checksum, + checksum=checksum, build_time=build_time) with open(output, "wb") as f: f.write(source)