From 90f41ccee8c069efde46f860d608c8bfc9147e3d Mon Sep 17 00:00:00 2001 From: liuqi Date: Fri, 20 Apr 2018 19:00:11 +0800 Subject: [PATCH] Fix validation wrong input file bug. --- tools/validate.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tools/validate.py b/tools/validate.py index bc93d709..79b9fc0a 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -28,7 +28,7 @@ def load_data(file): return np.empty([0]) -def format_output_name(name): +def format_name(name): return re.sub('[^0-9a-zA-Z]+', '_', name) @@ -71,7 +71,7 @@ def validate_tf_model(input_names, input_shapes, output_names): input_dict = {} for i in range(len(input_names)): input_value = load_data( - FLAGS.input_file + "_" + input_names[i]) + FLAGS.input_file + "_" + format_name(input_names[i])) input_value = input_value.reshape(input_shapes[i]) input_node = graph.get_tensor_by_name( input_names[i] + ':0') @@ -84,7 +84,7 @@ def validate_tf_model(input_names, input_shapes, output_names): output_values = session.run(output_nodes, feed_dict=input_dict) for i in range(len(output_names)): output_file_name = FLAGS.mace_out_file + "_" + \ - format_output_name(output_names[i]) + format_name(output_names[i]) mace_out_value = load_data(output_file_name) compare_output(output_names[i], mace_out_value, output_values[i]) @@ -92,7 +92,7 @@ def validate_tf_model(input_names, input_shapes, output_names): def validate_caffe_model(input_names, input_shapes, output_names, output_shapes): - os.environ['GLOG_minloglevel'] = '1' # suprress Caffe verbose prints + os.environ['GLOG_minloglevel'] = '1' # suppress Caffe verbose prints import caffe if not os.path.isfile(FLAGS.model_file): print("Input graph file '" + FLAGS.model_file + "' does not exist!") @@ -106,7 +106,8 @@ def validate_caffe_model(input_names, input_shapes, output_names, net = caffe.Net(FLAGS.model_file, caffe.TEST, weights=FLAGS.weight_file) for i in range(len(input_names)): - input_value = load_data(FLAGS.input_file + "_" + input_names[i]) + input_value = load_data(FLAGS.input_file + "_" + + format_name(input_names[i])) input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, 2)) input_blob_name = input_names[i] @@ -125,7 +126,7 @@ def validate_caffe_model(input_names, input_shapes, output_names, out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[ 1], out_shape[2] value = value.reshape(out_shape).transpose((0, 2, 3, 1)) - output_file_name = FLAGS.mace_out_file + "_" + format_output_name( + output_file_name = FLAGS.mace_out_file + "_" + format_name( output_names[i]) mace_out_value = load_data(output_file_name) compare_output(output_names[i], mace_out_value, value) -- GitLab