提交 710e319f 编写于 作者: L Liangliang He

Support customized similarity threshold for model validation

上级 4a95917c
...@@ -33,6 +33,9 @@ Required dependencies ...@@ -33,6 +33,9 @@ Required dependencies
* - Numpy * - Numpy
- pip install -I numpy==1.14.0 - pip install -I numpy==1.14.0
- Required by model validation - Required by model validation
* - six
- pip install -I six==1.11.0
- Required for Python 2 and 3 compatibility (TODO)
Optional dependencies Optional dependencies
--------------------- ---------------------
......
...@@ -76,6 +76,8 @@ in one deployment file. ...@@ -76,6 +76,8 @@ in one deployment file.
- The numerical range of the input tensors' data, default [-1, 1]. It is only for test. - The numerical range of the input tensors' data, default [-1, 1]. It is only for test.
* - validation_inputs_data * - validation_inputs_data
- [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used. - [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used.
* - validation_threshold
- [optional] Specify the similarity threshold for validation. A dict with key in 'CPU', 'GPU' and/or 'HEXAGON' and value <= 1.0.
* - runtime * - runtime
- The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU. - The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
* - data_type * - data_type
......
...@@ -19,6 +19,7 @@ import os ...@@ -19,6 +19,7 @@ import os
import re import re
import sh import sh
import subprocess import subprocess
import six
import sys import sys
import urllib import urllib
import yaml import yaml
...@@ -189,6 +190,7 @@ class YAMLKeyword(object): ...@@ -189,6 +190,7 @@ class YAMLKeyword(object):
quantize = 'quantize' quantize = 'quantize'
quantize_range_file = 'quantize_range_file' quantize_range_file = 'quantize_range_file'
validation_inputs_data = 'validation_inputs_data' validation_inputs_data = 'validation_inputs_data'
validation_threshold = 'validation_threshold'
graph_optimize_options = 'graph_optimize_options' # internal use for now graph_optimize_options = 'graph_optimize_options' # internal use for now
...@@ -444,6 +446,30 @@ def format_model_config(flags): ...@@ -444,6 +446,30 @@ def format_model_config(flags):
"'%s' is necessary in subgraph" % key) "'%s' is necessary in subgraph" % key)
if not isinstance(value, list): if not isinstance(value, list):
subgraph[key] = [value] subgraph[key] = [value]
validation_threshold = subgraph.get(
YAMLKeyword.validation_threshold, {})
if not isinstance(validation_threshold, dict):
raise argparse.ArgumentTypeError(
'similarity threshold must be a dict.')
threshold_dict = {
DeviceType.CPU: 0.999,
DeviceType.GPU: 0.995,
DeviceType.HEXAGON: 0.930,
}
for k, v in six.iteritems(validation_threshold):
if k.upper() == 'DSP':
k = DeviceType.HEXAGON
if k.upper() not in (DeviceType.CPU,
DeviceType.GPU,
DeviceType.HEXAGON):
raise argparse.ArgumentTypeError(
'Unsupported validation threshold runtime: %s' % k)
threshold_dict[k.upper()] = v
subgraph[YAMLKeyword.validation_threshold] = threshold_dict
validation_inputs_data = subgraph.get( validation_inputs_data = subgraph.get(
YAMLKeyword.validation_inputs_data, []) YAMLKeyword.validation_inputs_data, [])
if not isinstance(validation_inputs_data, list): if not isinstance(validation_inputs_data, list):
...@@ -1202,7 +1228,8 @@ def run_specific_target(flags, configs, target_abi, ...@@ -1202,7 +1228,8 @@ def run_specific_target(flags, configs, target_abi,
output_shapes=subgraphs[0][YAMLKeyword.output_shapes], output_shapes=subgraphs[0][YAMLKeyword.output_shapes],
model_output_dir=model_output_dir, model_output_dir=model_output_dir,
phone_data_dir=PHONE_DATA_DIR, phone_data_dir=PHONE_DATA_DIR,
caffe_env=flags.caffe_env) caffe_env=flags.caffe_env,
validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type]) # noqa
if flags.report and flags.round > 0: if flags.report and flags.round > 0:
tuned = is_tuned and device_type == DeviceType.GPU tuned = is_tuned and device_type == DeviceType.GPU
report_run_statistics( report_run_statistics(
......
...@@ -799,7 +799,8 @@ def validate_model(abi, ...@@ -799,7 +799,8 @@ def validate_model(abi,
phone_data_dir, phone_data_dir,
caffe_env, caffe_env,
input_file_name="model_input", input_file_name="model_input",
output_file_name="model_out"): output_file_name="model_out",
validation_threshold=0.9):
print("* Validate with %s" % platform) print("* Validate with %s" % platform)
if abi != "host": if abi != "host":
for output_name in output_nodes: for output_name in output_nodes:
...@@ -816,7 +817,8 @@ def validate_model(abi, ...@@ -816,7 +817,8 @@ def validate_model(abi,
"%s/%s" % (model_output_dir, input_file_name), "%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), device_type, "%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes)) ",".join(input_nodes), ",".join(output_nodes),
validation_threshold)
elif platform == "caffe": elif platform == "caffe":
image_name = "mace-caffe:latest" image_name = "mace-caffe:latest"
container_name = "mace_caffe_validator" container_name = "mace_caffe_validator"
...@@ -832,7 +834,8 @@ def validate_model(abi, ...@@ -832,7 +834,8 @@ def validate_model(abi,
"%s/%s" % (model_output_dir, output_file_name), "%s/%s" % (model_output_dir, output_file_name),
device_type, device_type,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes)) ",".join(input_nodes), ",".join(output_nodes),
validation_threshold)
elif caffe_env == common.CaffeEnvType.DOCKER: elif caffe_env == common.CaffeEnvType.DOCKER:
docker_image_id = sh.docker("images", "-q", image_name) docker_image_id = sh.docker("images", "-q", image_name)
if not docker_image_id: if not docker_image_id:
...@@ -896,6 +899,7 @@ def validate_model(abi, ...@@ -896,6 +899,7 @@ def validate_model(abi,
"--output_node=%s" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes), "--output_shape=%s" % ":".join(output_shapes),
"--validation_threshold=%f" % validation_threshold,
_fg=True) _fg=True)
print("Validation done!\n") print("Validation done!\n")
......
...@@ -35,6 +35,7 @@ import common ...@@ -35,6 +35,7 @@ import common
# --output_node output_node \ # --output_node output_node \
# --input_shape 1,64,64,3 \ # --input_shape 1,64,64,3 \
# --output_shape 1,64,64,2 # --output_shape 1,64,64,2
# --validation_threshold 0.995
VALIDATION_MODULE = 'VALIDATION' VALIDATION_MODULE = 'VALIDATION'
...@@ -47,7 +48,7 @@ def load_data(file): ...@@ -47,7 +48,7 @@ def load_data(file):
def compare_output(platform, device_type, output_name, mace_out_value, def compare_output(platform, device_type, output_name, mace_out_value,
out_value): out_value, validation_threshold):
if mace_out_value.size != 0: if mace_out_value.size != 0:
out_value = out_value.reshape(-1) out_value = out_value.reshape(-1)
mace_out_value = mace_out_value.reshape(-1) mace_out_value = mace_out_value.reshape(-1)
...@@ -56,9 +57,7 @@ def compare_output(platform, device_type, output_name, mace_out_value, ...@@ -56,9 +57,7 @@ def compare_output(platform, device_type, output_name, mace_out_value,
common.MaceLogger.summary( common.MaceLogger.summary(
output_name + ' MACE VS ' + platform.upper() output_name + ' MACE VS ' + platform.upper()
+ ' similarity: ' + str(similarity)) + ' similarity: ' + str(similarity))
if (device_type == "CPU" and similarity > 0.999) or \ if similarity > validation_threshold:
(device_type == "GPU" and similarity > 0.995) or \
(device_type == "HEXAGON" and similarity > 0.930):
common.MaceLogger.summary( common.MaceLogger.summary(
common.StringFormatter.block("Similarity Test Passed")) common.StringFormatter.block("Similarity Test Passed"))
else: else:
...@@ -78,7 +77,8 @@ def normalize_tf_tensor_name(name): ...@@ -78,7 +77,8 @@ def normalize_tf_tensor_name(name):
def validate_tf_model(platform, device_type, model_file, input_file, def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, output_names): mace_out_file, input_names, input_shapes,
output_names, validation_threshold):
import tensorflow as tf import tensorflow as tf
if not os.path.isfile(model_file): if not os.path.isfile(model_file):
common.MaceLogger.error( common.MaceLogger.error(
...@@ -115,12 +115,13 @@ def validate_tf_model(platform, device_type, model_file, input_file, ...@@ -115,12 +115,13 @@ def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, output_names[i]) mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
compare_output(platform, device_type, output_names[i], compare_output(platform, device_type, output_names[i],
mace_out_value, output_values[i]) mace_out_value, output_values[i],
validation_threshold)
def validate_caffe_model(platform, device_type, model_file, input_file, def validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file, weight_file, input_names, input_shapes, mace_out_file, weight_file, input_names, input_shapes,
output_names, output_shapes): output_names, output_shapes, validation_threshold):
os.environ['GLOG_minloglevel'] = '1' # suprress Caffe verbose prints os.environ['GLOG_minloglevel'] = '1' # suprress Caffe verbose prints
import caffe import caffe
if not os.path.isfile(model_file): if not os.path.isfile(model_file):
...@@ -162,11 +163,12 @@ def validate_caffe_model(platform, device_type, model_file, input_file, ...@@ -162,11 +163,12 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file, output_names[i]) mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
compare_output(platform, device_type, output_names[i], mace_out_value, compare_output(platform, device_type, output_names[i], mace_out_value,
value) value, validation_threshold)
def validate(platform, model_file, weight_file, input_file, mace_out_file, def validate(platform, model_file, weight_file, input_file, mace_out_file,
device_type, input_shape, output_shape, input_node, output_node): device_type, input_shape, output_shape, input_node, output_node,
validation_threshold):
input_names = [name for name in input_node.split(',')] input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')] input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')] input_shapes = [[int(x) for x in shape.split(',')]
...@@ -177,14 +179,15 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file, ...@@ -177,14 +179,15 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
if platform == 'tensorflow': if platform == 'tensorflow':
validate_tf_model(platform, device_type, model_file, input_file, validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, mace_out_file, input_names, input_shapes,
output_names) output_names, validation_threshold)
elif platform == 'caffe': elif platform == 'caffe':
output_shape_strs = [shape for shape in output_shape.split(':')] output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')] output_shapes = [[int(x) for x in shape.split(',')]
for shape in output_shape_strs] for shape in output_shape_strs]
validate_caffe_model(platform, device_type, model_file, input_file, validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file, weight_file, input_names, mace_out_file, weight_file, input_names,
input_shapes, output_names, output_shapes) input_shapes, output_names, output_shapes,
validation_threshold)
def parse_args(): def parse_args():
...@@ -219,6 +222,9 @@ def parse_args(): ...@@ -219,6 +222,9 @@ def parse_args():
"--input_node", type=str, default="input_node", help="input node") "--input_node", type=str, default="input_node", help="input node")
parser.add_argument( parser.add_argument(
"--output_node", type=str, default="output_node", help="output node") "--output_node", type=str, default="output_node", help="output node")
parser.add_argument(
"--validation_threshold", type=float, default=0.995,
help="validation similarity threshold")
return parser.parse_known_args() return parser.parse_known_args()
...@@ -234,4 +240,5 @@ if __name__ == '__main__': ...@@ -234,4 +240,5 @@ if __name__ == '__main__':
FLAGS.input_shape, FLAGS.input_shape,
FLAGS.output_shape, FLAGS.output_shape,
FLAGS.input_node, FLAGS.input_node,
FLAGS.output_node) FLAGS.output_node,
FLAGS.validation_threshold)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册