提交 90f41cce 编写于 作者: L liuqi

Fix validation wrong input file bug.

上级 83623605
......@@ -28,7 +28,7 @@ def load_data(file):
return np.empty([0])
def format_output_name(name):
def format_name(name):
return re.sub('[^0-9a-zA-Z]+', '_', name)
......@@ -71,7 +71,7 @@ def validate_tf_model(input_names, input_shapes, output_names):
input_dict = {}
for i in range(len(input_names)):
input_value = load_data(
FLAGS.input_file + "_" + input_names[i])
FLAGS.input_file + "_" + format_name(input_names[i]))
input_value = input_value.reshape(input_shapes[i])
input_node = graph.get_tensor_by_name(
input_names[i] + ':0')
......@@ -84,7 +84,7 @@ def validate_tf_model(input_names, input_shapes, output_names):
output_values = session.run(output_nodes, feed_dict=input_dict)
for i in range(len(output_names)):
output_file_name = FLAGS.mace_out_file + "_" + \
format_output_name(output_names[i])
format_name(output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(output_names[i], mace_out_value,
output_values[i])
......@@ -92,7 +92,7 @@ def validate_tf_model(input_names, input_shapes, output_names):
def validate_caffe_model(input_names, input_shapes, output_names,
output_shapes):
os.environ['GLOG_minloglevel'] = '1' # suprress Caffe verbose prints
os.environ['GLOG_minloglevel'] = '1' # suppress Caffe verbose prints
import caffe
if not os.path.isfile(FLAGS.model_file):
print("Input graph file '" + FLAGS.model_file + "' does not exist!")
......@@ -106,7 +106,8 @@ def validate_caffe_model(input_names, input_shapes, output_names,
net = caffe.Net(FLAGS.model_file, caffe.TEST, weights=FLAGS.weight_file)
for i in range(len(input_names)):
input_value = load_data(FLAGS.input_file + "_" + input_names[i])
input_value = load_data(FLAGS.input_file + "_" +
format_name(input_names[i]))
input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1,
2))
input_blob_name = input_names[i]
......@@ -125,7 +126,7 @@ def validate_caffe_model(input_names, input_shapes, output_names,
out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[
1], out_shape[2]
value = value.reshape(out_shape).transpose((0, 2, 3, 1))
output_file_name = FLAGS.mace_out_file + "_" + format_output_name(
output_file_name = FLAGS.mace_out_file + "_" + format_name(
output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(output_names[i], mace_out_value, value)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册