From 22a2766af0a23119a8a2bd5549ad58d24701778f Mon Sep 17 00:00:00 2001 From: yejianwu Date: Thu, 18 Oct 2018 20:23:52 +0800 Subject: [PATCH] fix output key error and support batch input in caffe validate --- tools/sh_commands.py | 3 ++- tools/validate.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tools/sh_commands.py b/tools/sh_commands.py index b732366f..b7e3446c 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -914,7 +914,7 @@ def validate_model(abi, device_type, ":".join(input_shapes), ":".join(output_shapes), ",".join(input_nodes), ",".join(output_nodes), - validation_threshold) + validation_threshold, ",".join(input_data_types)) elif caffe_env == common.CaffeEnvType.DOCKER: docker_image_id = sh.docker("images", "-q", image_name) if not docker_image_id: @@ -979,6 +979,7 @@ def validate_model(abi, "--input_shape=%s" % ":".join(input_shapes), "--output_shape=%s" % ":".join(output_shapes), "--validation_threshold=%f" % validation_threshold, + "--input_data_type=%s" % ",".join(input_data_types), _fg=True) six.print_("Validation done!\n") diff --git a/tools/validate.py b/tools/validate.py index bc3a9c5d..be499c1a 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -163,12 +163,15 @@ def validate_caffe_model(platform, device_type, model_file, input_file, input_blob_name = net.top_names[input_names[i]][0] except ValueError: pass - net.blobs[input_blob_name].data[0] = input_value + new_shape = input_value.shape + net.blobs[input_blob_name].reshape(*new_shape) + for index in range(input_value.shape[0]): + net.blobs[input_blob_name].data[index] = input_value[index] net.forward() for i in range(len(output_names)): - value = net.blobs[net.top_names[output_names[i]][0]].data + value = net.blobs[output_names[i]].data out_shape = output_shapes[i] if len(out_shape) == 4: out_shape[1], out_shape[2], out_shape[3] = \ -- GitLab