提交 91197211 编写于 作者: 刘琦

Merge branch 'fix_caffe_validate' into 'master'

fix output key error and support batch input in caffe validate

See merge request !838
......@@ -914,7 +914,7 @@ def validate_model(abi,
device_type,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes),
validation_threshold)
validation_threshold, ",".join(input_data_types))
elif caffe_env == common.CaffeEnvType.DOCKER:
docker_image_id = sh.docker("images", "-q", image_name)
if not docker_image_id:
......@@ -979,6 +979,7 @@ def validate_model(abi,
"--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes),
"--validation_threshold=%f" % validation_threshold,
"--input_data_type=%s" % ",".join(input_data_types),
_fg=True)
six.print_("Validation done!\n")
......
......@@ -163,12 +163,15 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
input_blob_name = net.top_names[input_names[i]][0]
except ValueError:
pass
net.blobs[input_blob_name].data[0] = input_value
new_shape = input_value.shape
net.blobs[input_blob_name].reshape(*new_shape)
for index in range(input_value.shape[0]):
net.blobs[input_blob_name].data[index] = input_value[index]
net.forward()
for i in range(len(output_names)):
value = net.blobs[net.top_names[output_names[i]][0]].data
value = net.blobs[output_names[i]].data
out_shape = output_shapes[i]
if len(out_shape) == 4:
out_shape[1], out_shape[2], out_shape[3] = \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册