提交 0ecfd7a0 编写于 作者: L Liangliang He

Add model weights checksum for caffe model

上级 1a86f7b8
...@@ -45,6 +45,7 @@ def main(unused_args): ...@@ -45,6 +45,7 @@ def main(unused_args):
FLAGS.model_checksum)) FLAGS.model_checksum))
sys.exit(-1) sys.exit(-1)
weight_checksum = None
if FLAGS.platform == 'caffe': if FLAGS.platform == 'caffe':
if not os.path.isfile(FLAGS.weight_file): if not os.path.isfile(FLAGS.weight_file):
print("Input weight file '" + FLAGS.weight_file + print("Input weight file '" + FLAGS.weight_file +
...@@ -82,8 +83,8 @@ def main(unused_args): ...@@ -82,8 +83,8 @@ def main(unused_args):
if FLAGS.output_type == 'source': if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source( source_converter_lib.convert_to_source(
output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate, output_graph_def, model_checksum, weight_checksum, FLAGS.template,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output, FLAGS.runtime,
FLAGS.embed_model_data, FLAGS.winograd) FLAGS.embed_model_data, FLAGS.winograd)
else: else:
with open(FLAGS.output, "wb") as f: with open(FLAGS.output, "wb") as f:
......
...@@ -159,7 +159,7 @@ const std::string ModelName() { ...@@ -159,7 +159,7 @@ const std::string ModelName() {
} }
const std::string ModelChecksum() { const std::string ModelChecksum() {
return {{ model_pb_checksum|tojson }}; return {{ checksum|tojson }};
} }
const std::string ModelBuildTime() { const std::string ModelBuildTime() {
......
...@@ -124,8 +124,8 @@ def stringfy(value): ...@@ -124,8 +124,8 @@ def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value) return ', '.join('"{0}"'.format(w) for w in value)
def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
model_tag, output, runtime, embed_model_data, obfuscate, model_tag, output, runtime, embed_model_data,
winograd_conv): winograd_conv):
if obfuscate: if obfuscate:
obfuscate_name(net_def) obfuscate_name(net_def)
...@@ -201,6 +201,9 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, ...@@ -201,6 +201,9 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate,
TensorInfo(i, net_def.tensors[i], runtime) TensorInfo(i, net_def.tensors[i], runtime)
for i in range(len(net_def.tensors)) for i in range(len(net_def.tensors))
] ]
checksum = model_checksum
if weight_checksum is not None:
checksum = "{},{}".format(model_checksum, weight_checksum)
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
tensors=tensors, tensors=tensors,
net=net_def, net=net_def,
...@@ -209,7 +212,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, ...@@ -209,7 +212,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate,
obfuscate=obfuscate, obfuscate=obfuscate,
embed_model_data=embed_model_data, embed_model_data=embed_model_data,
winograd_conv=winograd_conv, winograd_conv=winograd_conv,
model_pb_checksum=mode_pb_checksum, checksum=checksum,
build_time=build_time) build_time=build_time)
with open(output, "wb") as f: with open(output, "wb") as f:
f.write(source) f.write(source)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册