提交 f64a27e5 编写于 作者: L liuqi

Add weight checksum for caffe model.

上级 81d83a95
...@@ -20,29 +20,45 @@ def main(unused_args): ...@@ -20,29 +20,45 @@ def main(unused_args):
print("Input graph file '" + FLAGS.model_file + "' does not exist!") print("Input graph file '" + FLAGS.model_file + "' does not exist!")
return -1 return -1
model_checksum = file_checksum(FLAGS.input) model_checksum = file_checksum(FLAGS.model_file)
if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum: if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum)) print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum))
return -1 return -1
if FLAGS.runtime == 'dsp': if FLAGS.platform == 'caffe':
from lib.python.tools import tf_dsp_converter_lib if not os.path.isfile(FLAGS.weight_file):
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( print("Input weight file '" + FLAGS.weight_file + "' does not exist!")
FLAGS.model_file, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) return -1
else:
weight_checksum = file_checksum(FLAGS.weight_file)
if FLAGS.weight_checksum != "" and FLAGS.weight_checksum != weight_checksum:
print("Weight checksum mismatch: %s != %s" % (weight_checksum, FLAGS.weight_checksum))
return -1
if FLAGS.runtime == 'dsp':
print("DSP not support caffe model yet.")
return -1
input_shape = [] input_shape = []
if FLAGS.input_shape != "": if FLAGS.input_shape != "":
input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')]) input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')])
if FLAGS.platform == 'tensorflow': from lib.python.tools import caffe_converter_lib
output_graph_def = caffe_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.weight_file, FLAGS.input_node, input_shape, FLAGS.output_node,
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
elif FLAGS.platform == 'tensorflow':
if FLAGS.runtime == 'dsp':
from lib.python.tools import tf_dsp_converter_lib
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode)
else:
input_shape = []
if FLAGS.input_shape != "":
input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')])
from lib.python.tools import tf_converter_lib from lib.python.tools import tf_converter_lib
output_graph_def = tf_converter_lib.convert_to_mace_pb( output_graph_def = tf_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.input_node, input_shape, FLAGS.output_node, FLAGS.model_file, FLAGS.input_node, input_shape, FLAGS.output_node,
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
elif FLAGS.platform == 'caffe':
from lib.python.tools import caffe_converter_lib
output_graph_def = caffe_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.weight_file, FLAGS.input_node, input_shape, FLAGS.output_node,
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
if FLAGS.output_type == 'source': if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate, source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate,
...@@ -82,6 +98,11 @@ def parse_args(): ...@@ -82,6 +98,11 @@ def parse_args():
type=str, type=str,
default="", default="",
help="Model file sha256 checksum") help="Model file sha256 checksum")
parser.add_argument(
"--weight_checksum",
type=str,
default="",
help="Weight file sha256 checksum")
parser.add_argument( parser.add_argument(
"--output", "--output",
type=str, type=str,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册