提交 0ecfd7a0 编写于 作者: L Liangliang He

Add model weights checksum for caffe model

上级 1a86f7b8
......@@ -45,6 +45,7 @@ def main(unused_args):
FLAGS.model_checksum))
sys.exit(-1)
weight_checksum = None
if FLAGS.platform == 'caffe':
if not os.path.isfile(FLAGS.weight_file):
print("Input weight file '" + FLAGS.weight_file +
......@@ -82,8 +83,8 @@ def main(unused_args):
if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(
output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime,
output_graph_def, model_checksum, weight_checksum, FLAGS.template,
FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output, FLAGS.runtime,
FLAGS.embed_model_data, FLAGS.winograd)
else:
with open(FLAGS.output, "wb") as f:
......
......@@ -159,7 +159,7 @@ const std::string ModelName() {
}
const std::string ModelChecksum() {
return {{ model_pb_checksum|tojson }};
return {{ checksum|tojson }};
}
const std::string ModelBuildTime() {
......
......@@ -124,8 +124,8 @@ def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value)
def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate,
model_tag, output, runtime, embed_model_data,
def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
obfuscate, model_tag, output, runtime, embed_model_data,
winograd_conv):
if obfuscate:
obfuscate_name(net_def)
......@@ -201,6 +201,9 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate,
TensorInfo(i, net_def.tensors[i], runtime)
for i in range(len(net_def.tensors))
]
checksum = model_checksum
if weight_checksum is not None:
checksum = "{},{}".format(model_checksum, weight_checksum)
source = j2_env.get_template(template_name).render(
tensors=tensors,
net=net_def,
......@@ -209,7 +212,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate,
obfuscate=obfuscate,
embed_model_data=embed_model_data,
winograd_conv=winograd_conv,
model_pb_checksum=mode_pb_checksum,
checksum=checksum,
build_time=build_time)
with open(output, "wb") as f:
f.write(source)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册