diff --git a/tools/sh_commands.py b/tools/sh_commands.py index b732366fa8082b9a8cf4f63c94e550858ccf71df..b7e3446c55c5859ec425fc69a182e6e340cad13b 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -914,7 +914,7 @@ def validate_model(abi, device_type, ":".join(input_shapes), ":".join(output_shapes), ",".join(input_nodes), ",".join(output_nodes), - validation_threshold) + validation_threshold, ",".join(input_data_types)) elif caffe_env == common.CaffeEnvType.DOCKER: docker_image_id = sh.docker("images", "-q", image_name) if not docker_image_id: @@ -979,6 +979,7 @@ def validate_model(abi, "--input_shape=%s" % ":".join(input_shapes), "--output_shape=%s" % ":".join(output_shapes), "--validation_threshold=%f" % validation_threshold, + "--input_data_type=%s" % ",".join(input_data_types), _fg=True) six.print_("Validation done!\n") diff --git a/tools/validate.py b/tools/validate.py index bc3a9c5db46e2851d6309ad2b0c181b6a9acd26d..be499c1a3bed51ea2e0631d71dd3d3630ad97bff 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -163,12 +163,15 @@ def validate_caffe_model(platform, device_type, model_file, input_file, input_blob_name = net.top_names[input_names[i]][0] except ValueError: pass - net.blobs[input_blob_name].data[0] = input_value + new_shape = input_value.shape + net.blobs[input_blob_name].reshape(*new_shape) + for index in range(input_value.shape[0]): + net.blobs[input_blob_name].data[index] = input_value[index] net.forward() for i in range(len(output_names)): - value = net.blobs[net.top_names[output_names[i]][0]].data + value = net.blobs[output_names[i]].data out_shape = output_shapes[i] if len(out_shape) == 4: out_shape[1], out_shape[2], out_shape[3] = \