diff --git a/docs/user_guide/advanced_usage.rst b/docs/user_guide/advanced_usage.rst index dfd69cca91ef8ac90f35d1aa3dc6a4a9d8f832ac..8c452d08b93e14c16b6c3b771dec7095a293cf89 100644 --- a/docs/user_guide/advanced_usage.rst +++ b/docs/user_guide/advanced_usage.rst @@ -76,6 +76,8 @@ in one deployment file. - The numerical range of the input tensors' data, default [-1, 1]. It is only for test. * - validation_inputs_data - [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used. + * - accuracy_validation_script + - [optional] Specify the accuracy validation script as a plugin to test accuracy, see `doc <#validate-accuracy-of-mace-model>`__. * - validation_threshold - [optional] Specify the similarity threshold for validation. A dict with key in 'CPU', 'GPU' and/or 'HEXAGON' and value <= 1.0. * - backend @@ -358,6 +360,19 @@ Tuning for specific SoC's GPU // ... Same with the code in basic usage. +Validate accuracy of MACE model +------------------------------- + +MACE supports a **Python validation script** as a plugin to test accuracy. The plugin script can be used for the following two purposes. + +1. Test the **accuracy (like Top-1)** of a MACE model (specifically a quantized model) converted from another framework (like TensorFlow). +2. Show some real output if you want to see it. + +The script defines interfaces like `preprocess` and `postprocess` to handle input/output and calculate the accuracy; +you can refer to the sample code in ``tools/accuracy_validator.py`` for details. +The sample code shows how to calculate the Top-1 accuracy with the ImageNet validation dataset. 
import os.path

import numpy as np


class AccuracyValidator(object):
    """Accuracy validator plugin: Top-1 accuracy on an ImageNet-style set.

    Usage: this script is used to calculate the accuracy (like Top-1) of a
    MACE model. Users may replace it with their own validator (for example
    MIOU for segmentation); the replacement must expose exactly the same
    interface as this class: ``sample_size``, ``batch_size``,
    ``preprocess``, ``postprocess`` and ``result``.

    Warning: do not use relative paths in this script.
    """

    def __init__(self, **kwargs):
        """Load class names, ground-truth labels, black list and samples.

        Every path below keeps its original hard-coded default and can be
        overridden with a keyword argument of the same name, so the class
        can still be constructed with no arguments.

        :param kwargs: optional overrides:
            ``validation_set_image_dir``: directory holding ``*.JPEG``
            files whose names end in ``_<image id>``;
            ``validation_set_label_file_path``: one class name per line,
            line N being the label of image id N (ids start from 1);
            ``black_list_file_path``: one image id per line to skip;
            ``imagenet_classes_file``: one class name per line, the line
            index being the class index.
        """
        # absolute path
        validation_set_image_dir = kwargs.get(
            'validation_set_image_dir',
            '/path/to/your/validation/set/directory')
        validation_set_label_file_path = kwargs.get(
            'validation_set_label_file_path',
            '/path/to/imagenet_groundtruth_labels.txt')
        black_list_file_path = kwargs.get(
            'black_list_file_path',
            '/path/to/imagenet_blacklist.txt')
        imagenet_classes_file = kwargs.get(
            'imagenet_classes_file',
            '/path/to/imagenet_classes.txt')

        self._imagenet_classes = self._read_lines(imagenet_classes_file)
        class_to_index = {
            name: idx for idx, name in enumerate(self._imagenet_classes)}
        # A set gives O(1) membership tests while scanning the image dir;
        # blank lines (e.g. a trailing newline) are ignored.
        black_list = set(
            int(line)
            for line in self._read_lines(black_list_file_path)
            if line.strip())

        self._correct_count = 0
        self._samples = []
        # sorted() only makes the evaluation order deterministic.
        for img_file in sorted(os.listdir(validation_set_image_dir)):
            if not img_file.endswith(".JPEG"):
                continue
            if self._image_id(img_file) not in black_list:
                self._samples.append(
                    os.path.join(validation_set_image_dir, img_file))

        # Image ids start from 1, so index 0 is a placeholder.
        self._labels = [0]
        self._labels.extend(
            class_to_index[label]
            for label in self._read_lines(validation_set_label_file_path))

    @staticmethod
    def _read_lines(path):
        """Return the lines of ``path`` without their trailing newline."""
        with open(path) as text_file:
            return [line.rstrip('\n') for line in text_file]

    @staticmethod
    def _image_id(file_path):
        """Return the integer after the last ``_`` of the file's stem."""
        return int(os.path.splitext(file_path)[0].split('_')[-1])

    def sample_size(self):
        """
        :return: the size of samples in validation set
        """
        return len(self._samples)

    def batch_size(self):
        """
        batch size to do validation to speed up validation.
        Keep same with batch size of input_shapes
        in model deployment file(.yml). do not set too large
        :return: batch size
        """
        return 1

    def preprocess(self, sample_idx_start, sample_idx_end, **kwargs):
        """
        pre-process the input sample
        :param sample_idx_start: start index of the sample.
        :param sample_idx_end: end index of the sample(not include).
        :param kwargs: other parameters.
        :return: the batched inputs' map(name: data) feed into your model
        """
        # Imported lazily so the module can be loaded (and the accuracy
        # bookkeeping reused) on hosts without Pillow installed.
        from PIL import Image

        sample_idx_end = min(sample_idx_end, self.sample_size())
        batch_sample_data = []
        for sample_idx in range(sample_idx_start, sample_idx_end):
            with Image.open(self._samples[sample_idx]) as sample_img:
                # Force 3 channels: grayscale/CMYK JPEGs would otherwise
                # yield a tensor whose shape differs from RGB samples.
                resized = sample_img.convert('RGB').resize((224, 224))
            sample_data = np.asarray(resized, dtype=np.float32)
            # Scale [0, 255] pixel values into [-1, 1].
            sample_data = (2.0 / 255.0) * sample_data - 1.0
            batch_sample_data.append(sample_data.tolist())
        return {"input": batch_sample_data}

    def postprocess(self,
                    sample_idx_start,
                    sample_idx_end,
                    output_map,
                    **kwargs):
        """
        post-process the outputs of your model and calculate the accuracy
        :param sample_idx_start: start index of input sample
        :param sample_idx_end: end index of input sample
        :param output_map: output map of the model
        :param kwargs: other parameters.
        :return: None
        """
        sample_idx_end = min(sample_idx_end, self.sample_size())
        batch_size = sample_idx_end - sample_idx_start
        if batch_size <= 0:
            # Nothing to score; reshape((0, -1)) would raise.
            return
        output = output_map['MobilenetV2/Predictions/Reshape_1']
        predictions = np.argmax(
            np.array(output).reshape((batch_size, -1)), axis=-1)
        sample_indices = range(sample_idx_start, sample_idx_end)
        for sample_idx, predicted in zip(sample_indices, predictions):
            img_id = self._image_id(self._samples[sample_idx])
            expected = self._labels[img_id]
            if predicted == expected:
                self._correct_count += 1
            else:
                print(img_id, 'predict %s vs gt %s' %
                      (self._imagenet_classes[predicted],
                       self._imagenet_classes[expected]))

    def result(self):
        """
        print or show the result
        :return: None
        """
        total = self.sample_size()
        # An empty validation set reports 0.0 instead of dividing by zero.
        accuracy = self._correct_count * 1.0 / total if total else 0.0
        print("==========================================")
        print("Top 1 accuracy: %f" % accuracy)
        print("==========================================")


if __name__ == '__main__':
    # sample usage code
    validator = AccuracyValidator()
    sample_size = validator.sample_size()
    val_batch_size = validator.batch_size()
    for i in range(0, sample_size, val_batch_size):
        inputs = validator.preprocess(i, i+val_batch_size)
        print(np.array(inputs['input']).shape)

        output_map = {
            'MobilenetV2/Predictions/Reshape_1': np.array([[0, 1], [1, 0]])
        }
        validator.postprocess(i, i+val_batch_size, output_map)

    validator.result()
specific language governing permissions and # limitations under the License. +import numpy as np import os import sys import socket @@ -387,7 +388,7 @@ class DeviceWrapper: subgraphs = model_config[YAMLKeyword.subgraphs] # generate input data - sh_commands.gen_random_input( + sh_commands.gen_input( model_output_dir, subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_shapes], @@ -460,18 +461,14 @@ class DeviceWrapper: return output_configs - def run_specify_abi(self, flags, configs, target_abi): - if target_abi not in self.target_abis: - six.print_('The device %s with soc %s do not support the abi %s' % - (self.device_name, self.target_socs, target_abi)) - return + def run_model(self, flags, configs, target_abi, + model_name, output_config, runtime, tuning): library_name = configs[YAMLKeyword.library_name] - mace_lib_type = flags.mace_lib_type embed_model_data = \ configs[YAMLKeyword.model_data_format] == ModelFormat.code build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi) - # get target name for run + mace_lib_type = flags.mace_lib_type if flags.example: if mace_lib_type == MACELibType.static: target_name = EXAMPLE_STATIC_NAME @@ -483,6 +480,114 @@ class DeviceWrapper: else: target_name = MACE_RUN_DYNAMIC_NAME link_dynamic = mace_lib_type == MACELibType.dynamic + + if target_abi != ABIType.host: + self.clear_data_dir() + + model_config = configs[YAMLKeyword.models][model_name] + subgraphs = model_config[YAMLKeyword.subgraphs] + + model_output_base_dir, model_output_dir, mace_model_dir = \ + get_build_model_dirs( + library_name, model_name, target_abi, self, + model_config[YAMLKeyword.model_file_path]) + + model_opencl_output_bin_path = '' + model_opencl_parameter_path = '' + if tuning: + model_opencl_output_bin_path = \ + '{}/{}/{}'.format(model_output_dir, + BUILD_TMP_OPENCL_BIN_DIR, + CL_COMPILED_BINARY_FILE_NAME) + model_opencl_parameter_path = \ + '{}/{}/{}'.format(model_output_dir, + BUILD_TMP_OPENCL_BIN_DIR, + 
def get_output_map(self,
                   target_abi,
                   output_nodes,
                   output_shapes,
                   model_output_dir):
    """Load the model's output tensors from ``model_output_dir``.

    For every name in ``output_nodes`` the matching "model_out" file is
    read as float32 and reshaped to the shape parsed from the entry at the
    same position in ``output_shapes``. When ``target_abi`` is not
    "host", any existing local copy is removed and the file is fetched
    with ``self.pull_from_data_dir`` first (presumably from the device's
    data directory - confirm against that method).

    :param target_abi: ABI the model ran on; "host" skips the pull.
    :param output_nodes: output tensor names.
    :param output_shapes: shape strings, parallel to ``output_nodes``.
    :param model_output_dir: local directory holding the output files.
    :return: dict mapping output name to a numpy float32 array.
    """
    needs_pull = target_abi != "host"
    output_map = {}
    for idx, output_name in enumerate(output_nodes):
        formatted_name = common.formatted_file_name(
            "model_out", output_name)
        if needs_pull:
            stale_copy = "%s/%s" % (model_output_dir, formatted_name)
            # Drop a stale local copy before pulling a fresh one.
            if os.path.exists(stale_copy):
                sh.rm("-rf", stale_copy)
            self.pull_from_data_dir(formatted_name, model_output_dir)
        output_file_path = os.path.join(model_output_dir, formatted_name)
        dims = common.split_shape(output_shapes[idx])
        output_shape = [int(dim) for dim in dims]
        raw_values = np.fromfile(output_file_path, dtype=np.float32)
        output_map[output_name] = raw_values.reshape(output_shape)
    return output_map
get_opencl_binary_output_path( - library_name, target_abi, self - ) - model_opencl_parameter_path = get_opencl_parameter_output_path( - library_name, target_abi, self - ) - sh_commands.gen_random_input( - model_output_dir, - subgraphs[0][YAMLKeyword.input_tensors], - subgraphs[0][YAMLKeyword.input_shapes], - subgraphs[0][YAMLKeyword.validation_inputs_data], - input_ranges=subgraphs[0][YAMLKeyword.input_ranges], - input_data_types=subgraphs[0][YAMLKeyword.input_data_types] - ) + tuning = True + + accuracy_validation_script = \ + subgraphs[0][YAMLKeyword.accuracy_validation_script] + output_configs = [] + if not accuracy_validation_script and flags.layers != "-1": + mace_check(configs[YAMLKeyword.model_graph_format] == + ModelFormat.file and + configs[YAMLKeyword.model_data_format] == + ModelFormat.file, "Device", + "'--layers' only supports model format 'file'.") + output_configs = self.get_layers(mace_model_dir, + model_name, + flags.layers) + # run for specified soc + if not subgraphs[0][YAMLKeyword.check_tensors]: + output_nodes = subgraphs[0][YAMLKeyword.output_tensors] + output_shapes = subgraphs[0][YAMLKeyword.output_shapes] + else: + output_nodes = subgraphs[0][YAMLKeyword.check_tensors] + output_shapes = subgraphs[0][YAMLKeyword.check_shapes] + model_path = "%s/%s.pb" % (mace_model_dir, model_name) + output_config = {YAMLKeyword.model_file_path: model_path, + YAMLKeyword.output_tensors: output_nodes, + YAMLKeyword.output_shapes: output_shapes} + output_configs.append(output_config) + runtime_list = [] if target_abi == ABIType.host: runtime_list.append(RuntimeType.cpu) @@ -559,143 +665,128 @@ class DeviceWrapper: runtime_list.extend([RuntimeType.cpu, RuntimeType.gpu]) else: runtime_list.append(model_runtime) - for runtime in runtime_list: - device_type = parse_device_type(runtime) - # run for specified soc - if not subgraphs[0][YAMLKeyword.check_tensors]: - output_nodes = subgraphs[0][YAMLKeyword.output_tensors] - output_shapes = 
subgraphs[0][YAMLKeyword.output_shapes] - else: - output_nodes = subgraphs[0][YAMLKeyword.check_tensors] - output_shapes = subgraphs[0][YAMLKeyword.check_shapes] - output_configs = [] - log_file = "" - if flags.layers != "-1": - mace_check(configs[YAMLKeyword.model_graph_format] == - ModelFormat.file and - configs[YAMLKeyword.model_data_format] == - ModelFormat.file, "Device", - "'--layers' only supports model format 'file'.") - output_configs = self.get_layers(mace_model_dir, - model_name, - flags.layers) - log_dir = mace_model_dir + "/" + runtime - if os.path.exists(log_dir): - sh.rm('-rf', log_dir) - os.makedirs(log_dir) - log_file = log_dir + "/log.csv" - model_path = "%s/%s.pb" % (mace_model_dir, model_name) - output_config = {YAMLKeyword.model_file_path: model_path, - YAMLKeyword.output_tensors: output_nodes, - YAMLKeyword.output_shapes: output_shapes} - output_configs.append(output_config) - for output_config in output_configs: - run_output = self.tuning_run( - abi=target_abi, - target_dir=build_tmp_binary_dir, - target_name=target_name, - vlog_level=flags.vlog_level, - embed_model_data=embed_model_data, - model_output_dir=model_output_dir, - input_nodes=subgraphs[0][YAMLKeyword.input_tensors], - output_nodes=output_config[ - YAMLKeyword.output_tensors], - input_shapes=subgraphs[0][YAMLKeyword.input_shapes], - output_shapes=output_config[YAMLKeyword.output_shapes], - input_data_formats=subgraphs[0][ - YAMLKeyword.input_data_formats], - output_data_formats=subgraphs[0][ - YAMLKeyword.output_data_formats], - mace_model_dir=mace_model_dir, - model_tag=model_name, - device_type=device_type, - running_round=flags.round, - restart_round=flags.restart_round, - limit_opencl_kernel_time=model_config[ - YAMLKeyword.limit_opencl_kernel_time], - tuning=False, - out_of_range_check=flags.gpu_out_of_range_check, - model_graph_format=configs[ - YAMLKeyword.model_graph_format], - omp_num_threads=flags.omp_num_threads, - cpu_affinity_policy=flags.cpu_affinity_policy, - 
gpu_perf_hint=flags.gpu_perf_hint, - gpu_priority_hint=flags.gpu_priority_hint, - runtime_failure_ratio=flags.runtime_failure_ratio, - address_sanitizer=flags.address_sanitizer, - opencl_binary_file=model_opencl_output_bin_path, - opencl_parameter_file=model_opencl_parameter_path, - libmace_dynamic_library_path=LIBMACE_DYNAMIC_PATH, - link_dynamic=link_dynamic, - quantize_stat=flags.quantize_stat, - input_dir=flags.input_dir, - output_dir=flags.output_dir, - layers_validate_file=output_config[ - YAMLKeyword.model_file_path] - ) - if flags.validate: - model_file_path, weight_file_path = get_model_files( - model_config[YAMLKeyword.model_file_path], - model_config[YAMLKeyword.model_sha256_checksum], - BUILD_DOWNLOADS_DIR, - model_config[YAMLKeyword.weight_file_path], - model_config[YAMLKeyword.weight_sha256_checksum] - ) - validate_type = device_type - if model_config[YAMLKeyword.quantize] == 1: - validate_type = device_type + '_QUANTIZE' - - dockerfile_path, docker_image_tag = \ - get_dockerfile_info( - model_config.get(YAMLKeyword.dockerfile_path), - model_config.get( - YAMLKeyword.dockerfile_sha256_checksum), - model_config.get(YAMLKeyword.docker_image_tag) - ) if YAMLKeyword.dockerfile_path in model_config \ - else ("third_party/caffe", "lastest") - - sh_commands.validate_model( - abi=target_abi, - device=self, - model_file_path=model_file_path, - weight_file_path=weight_file_path, - docker_image_tag=docker_image_tag, - dockerfile_path=dockerfile_path, - platform=model_config[YAMLKeyword.platform], - device_type=device_type, - input_nodes=subgraphs[0][ - YAMLKeyword.input_tensors], - output_nodes=output_config[ - YAMLKeyword.output_tensors], - input_shapes=subgraphs[0][ - YAMLKeyword.input_shapes], - output_shapes=output_config[ - YAMLKeyword.output_shapes], - input_data_formats=subgraphs[0][ - YAMLKeyword.input_data_formats], - output_data_formats=subgraphs[0][ - YAMLKeyword.output_data_formats], - model_output_dir=model_output_dir, - 
input_data_types=subgraphs[0][ - YAMLKeyword.input_data_types], - caffe_env=flags.caffe_env, - validation_threshold=subgraphs[0][ - YAMLKeyword.validation_threshold][ - validate_type], - backend=subgraphs[0][YAMLKeyword.backend], - validation_outputs_data=subgraphs[0][ - YAMLKeyword.validation_outputs_data], - log_file=log_file, - ) - if flags.report and flags.round > 0: - tuned = is_tuned and device_type == DeviceType.GPU - self.report_run_statistics( - target_abi=target_abi, - model_name=model_name, - device_type=device_type, - output_dir=flags.report_dir, - tuned=tuned - ) + if accuracy_validation_script: + flags.validate = False + flags.report = False + + import imp + accuracy_val_module = imp.load_source( + 'accuracy_val_module', + accuracy_validation_script) + for runtime in runtime_list: + accuracy_validator = \ + accuracy_val_module.AccuracyValidator() + sample_size = accuracy_validator.sample_size() + val_batch_size = accuracy_validator.batch_size() + for i in range(0, sample_size, val_batch_size): + inputs = accuracy_validator.preprocess( + i, i + val_batch_size) + sh_commands.gen_input( + model_output_dir, + subgraphs[0][YAMLKeyword.input_tensors], + subgraphs[0][YAMLKeyword.input_shapes], + input_data_types=subgraphs[0][YAMLKeyword.input_data_types], # noqa + input_data_map=inputs) + + self.run_model(flags, configs, target_abi, model_name, + output_configs[-1], runtime, tuning) + accuracy_validator.postprocess( + i, i + val_batch_size, + self.get_output_map( + target_abi, + output_nodes, + subgraphs[0][YAMLKeyword.output_shapes], + model_output_dir)) + accuracy_validator.result() + else: + sh_commands.gen_input( + model_output_dir, + subgraphs[0][YAMLKeyword.input_tensors], + subgraphs[0][YAMLKeyword.input_shapes], + subgraphs[0][YAMLKeyword.validation_inputs_data], + input_ranges=subgraphs[0][YAMLKeyword.input_ranges], + input_data_types=subgraphs[0][YAMLKeyword.input_data_types] + ) + for runtime in runtime_list: + device_type = 
parse_device_type(runtime) + for output_config in output_configs: + self.run_model(flags, configs, target_abi, model_name, + output_config, runtime, tuning) + if flags.validate: + log_file = "" + if flags.layers != "-1": + log_dir = mace_model_dir + "/" + runtime + if os.path.exists(log_dir): + sh.rm('-rf', log_dir) + os.makedirs(log_dir) + log_file = log_dir + "/log.csv" + model_file_path, weight_file_path = \ + get_model_files( + model_config[YAMLKeyword.model_file_path], + model_config[ + YAMLKeyword.model_sha256_checksum], + BUILD_DOWNLOADS_DIR, + model_config[YAMLKeyword.weight_file_path], + model_config[ + YAMLKeyword.weight_sha256_checksum]) + validate_type = device_type + if model_config[YAMLKeyword.quantize] == 1: + validate_type = device_type + '_QUANTIZE' + + dockerfile_path, docker_image_tag = \ + get_dockerfile_info( + model_config.get( + YAMLKeyword.dockerfile_path), + model_config.get( + YAMLKeyword.dockerfile_sha256_checksum), # noqa + model_config.get( + YAMLKeyword.docker_image_tag) + ) if YAMLKeyword.dockerfile_path \ + in model_config \ + else ("third_party/caffe", "lastest") + + sh_commands.validate_model( + abi=target_abi, + device=self, + model_file_path=model_file_path, + weight_file_path=weight_file_path, + docker_image_tag=docker_image_tag, + dockerfile_path=dockerfile_path, + platform=model_config[YAMLKeyword.platform], + device_type=device_type, + input_nodes=subgraphs[0][ + YAMLKeyword.input_tensors], + output_nodes=output_config[ + YAMLKeyword.output_tensors], + input_shapes=subgraphs[0][ + YAMLKeyword.input_shapes], + output_shapes=output_config[ + YAMLKeyword.output_shapes], + input_data_formats=subgraphs[0][ + YAMLKeyword.input_data_formats], + output_data_formats=subgraphs[0][ + YAMLKeyword.output_data_formats], + model_output_dir=model_output_dir, + input_data_types=subgraphs[0][ + YAMLKeyword.input_data_types], + caffe_env=flags.caffe_env, + validation_threshold=subgraphs[0][ + YAMLKeyword.validation_threshold][ + validate_type], 
+ backend=subgraphs[0][YAMLKeyword.backend], + validation_outputs_data=subgraphs[0][ + YAMLKeyword.validation_outputs_data], + log_file=log_file, + ) + if flags.report and flags.round > 0: + tuned = tuning and device_type == DeviceType.GPU + self.report_run_statistics( + target_abi=target_abi, + model_name=model_name, + device_type=device_type, + output_dir=flags.report_dir, + tuned=tuned) + if model_output_dirs: opencl_output_bin_path = get_opencl_binary_output_path( library_name, target_abi, self @@ -956,7 +1047,7 @@ class DeviceWrapper: if target_abi != ABIType.host: self.clear_data_dir() - sh_commands.gen_random_input( + sh_commands.gen_input( model_output_dir, subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_shapes], diff --git a/tools/sh_commands.py b/tools/sh_commands.py index a5c50fc4ca7dd8fa51352b5e428ab803380d17cc..e67a7b22c95270f4d8d148b4ccc506f856e7cb5d 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -549,42 +549,64 @@ def gen_model_code(model_codegen_dir, _fg=True) -def gen_random_input(model_output_dir, - input_nodes, - input_shapes, - input_files, - input_ranges, - input_data_types, - input_file_name="model_input"): +def gen_input(model_output_dir, + input_nodes, + input_shapes, + input_files=None, + input_ranges=None, + input_data_types=None, + input_data_map=None, + input_file_name="model_input"): for input_name in input_nodes: formatted_name = common.formatted_file_name( input_file_name, input_name) if os.path.exists("%s/%s" % (model_output_dir, formatted_name)): sh.rm("%s/%s" % (model_output_dir, formatted_name)) - input_nodes_str = ",".join(input_nodes) - input_shapes_str = ":".join(input_shapes) - input_ranges_str = ":".join(input_ranges) - input_data_types_str = ",".join(input_data_types) - generate_input_data("%s/%s" % (model_output_dir, input_file_name), - input_nodes_str, - input_shapes_str, - input_ranges_str, - input_data_types_str) - input_file_list = [] if isinstance(input_files, list): 
input_file_list.extend(input_files) else: input_file_list.append(input_files) - if len(input_file_list) != 0: + if input_data_map: + for i in range(len(input_nodes)): + dst_input_file = model_output_dir + '/' + \ + common.formatted_file_name(input_file_name, + input_nodes[i]) + input_name = input_nodes[i] + common.mace_check(input_name in input_data_map, + common.ModuleName.RUN, + "The preprocessor API in PrecisionValidator" + " script should return all inputs of model") + if input_data_types[i] == 'float32': + input_data = np.array(input_data_map[input_name], + dtype=np.float32) + elif input_data_types[i] == 'int32': + input_data = np.array(input_data_map[input_name], + dtype=np.int32) + else: + common.mace_check( + False, + common.ModuleName.RUN, + 'Do not support input data type %s' % input_data_types[i]) + common.mace_check( + list(map(int, common.split_shape(input_shapes[i]))) + == list(input_data.shape), + common.ModuleName.RUN, + "The shape return from preprocessor API of" + " PrecisionValidator script is not same with" + " model deployment file. 
%s vs %s" + % (str(input_shapes[i]), str(input_data.shape))) + input_data.tofile(dst_input_file) + elif len(input_file_list) != 0: input_name_list = [] if isinstance(input_nodes, list): input_name_list.extend(input_nodes) else: input_name_list.append(input_nodes) - if len(input_file_list) != len(input_name_list): - raise Exception('If input_files set, the input files should ' - 'match the input names.') + common.mace_check(len(input_file_list) == len(input_name_list), + common.ModuleName.RUN, + 'If input_files set, the input files should ' + 'match the input names.') for i in range(len(input_file_list)): if input_file_list[i] is not None: dst_input_file = model_output_dir + '/' + \ @@ -596,6 +618,17 @@ def gen_random_input(model_output_dir, dst_input_file) else: sh.cp("-f", input_file_list[i], dst_input_file) + else: + # generate random input files + input_nodes_str = ",".join(input_nodes) + input_shapes_str = ":".join(input_shapes) + input_ranges_str = ":".join(input_ranges) + input_data_types_str = ",".join(input_data_types) + generate_input_data("%s/%s" % (model_output_dir, input_file_name), + input_nodes_str, + input_shapes_str, + input_ranges_str, + input_data_types_str) def gen_opencl_binary_cpps(opencl_bin_file_path,