“1d0f9cc12f9df7a7efaa179c8e8d111018b934fd”上不存在“tools/python/transform/caffe_converter.py”
提交 437da1ea 编写于 作者: L liuqi

Support accuracy validation using python script as a plugin.

上级 3e9bb73e
...@@ -76,6 +76,8 @@ in one deployment file. ...@@ -76,6 +76,8 @@ in one deployment file.
- The numerical range of the input tensors' data, default [-1, 1]. It is only for test. - The numerical range of the input tensors' data, default [-1, 1]. It is only for test.
* - validation_inputs_data * - validation_inputs_data
- [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used. - [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used.
* - accuracy_validation_script
- [optional] Specify the accuracy validation script as a plugin to test accuracy, see `doc <#validate-accuracy-of-mace-model>`__.
* - validation_threshold * - validation_threshold
- [optional] Specify the similarity threshold for validation. A dict with key in 'CPU', 'GPU' and/or 'HEXAGON' and value <= 1.0. - [optional] Specify the similarity threshold for validation. A dict with key in 'CPU', 'GPU' and/or 'HEXAGON' and value <= 1.0.
* - backend * - backend
...@@ -358,6 +360,19 @@ Tuning for specific SoC's GPU ...@@ -358,6 +360,19 @@ Tuning for specific SoC's GPU
// ... Same with the code in basic usage. // ... Same with the code in basic usage.
Validate accuracy of MACE model
-------------------------------
MACE supports a **Python validation script** as a plugin to test accuracy. The plugin script can be used for the following two purposes.
1. Test the **accuracy (like Top-1)** of a MACE model (specifically a quantized model) converted from another framework (like TensorFlow).
2. Show some real output if you want to see it.
The script defines some interfaces like `preprocess` and `postprocess` to deal with input/output and calculate the accuracy;
you could refer to the `sample code <https://github.com/XiaoMi/mace/tree/master/tools/accuracy_validator.py>`__ for details.
The sample code shows how to calculate the Top-1 accuracy with the ImageNet validation dataset.
Useful Commands Useful Commands
--------------- ---------------
* **run the model** * **run the model**
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
library_name: mobile_squeeze library_name: mobile_squeeze
# host, armeabi-v7a or arm64-v8a # host, armeabi-v7a or arm64-v8a
target_abis: [arm64-v8a] target_abis: [arm64-v8a]
# soc's name or all
target_socs: [all]
# The build mode for model(s). # The build mode for model(s).
# 'code' for transferring model(s) into cpp code, 'file' for keeping model(s) in protobuf file(s) (.pb). # 'code' for transferring model(s) into cpp code, 'file' for keeping model(s) in protobuf file(s) (.pb).
model_graph_format: code model_graph_format: code
...@@ -43,6 +45,8 @@ models: ...@@ -43,6 +45,8 @@ models:
- prob - prob
output_shapes: output_shapes:
- 1,1,1,1000 - 1,1,1,1000
accuracy_validation_script:
- path/to/your/script
runtime: cpu+gpu runtime: cpu+gpu
limit_opencl_kernel_time: 0 limit_opencl_kernel_time: 0
obfuscate: 0 obfuscate: 0
......
# Copyright 2019 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import numpy as np
from PIL import Image
class AccuracyValidator(object):
    """Accuracy Validator Plugin.

    Usage: This script is used to calculate the accuracy (like Top-1)
           of a MACE model.
           Users can replace this validator script to do other accuracy
           validation (like MIOU for segmentation); the new script's
           interface must be exactly the same as this AccuracyValidator
           (``sample_size``, ``batch_size``, ``preprocess``,
           ``postprocess`` and ``result``).
    Warning: Do not use relative paths in this script.
    """

    # Key of the model input produced by ``preprocess``.
    _INPUT_NAME = "input"
    # Key of the model output consumed by ``postprocess``.
    _OUTPUT_NAME = 'MobilenetV2/Predictions/Reshape_1'
    # (width, height) every image is resized to before being fed in.
    _INPUT_SIZE = (224, 224)

    def __init__(self, **kwargs):
        """Load class names, ground-truth labels and the sample list.

        :param kwargs: optional overrides for the data locations. The
            recognised keys are ``validation_set_image_dir``,
            ``validation_set_label_file_path``, ``black_list_file_path``
            and ``imagenet_classes_file``. When a key is absent the
            hard-coded absolute path below is used.
        """
        # absolute path
        validation_set_image_dir = kwargs.get(
            'validation_set_image_dir',
            '/path/to/your/validation/set/directory')
        validation_set_label_file_path = kwargs.get(
            'validation_set_label_file_path',
            '/path/to/imagenet_groundtruth_labels.txt')
        black_list_file_path = kwargs.get(
            'black_list_file_path',
            '/path/to/imagenet_blacklist.txt')
        imagenet_classes_file = kwargs.get(
            'imagenet_classes_file',
            '/path/to/imagenet_classes.txt')

        # Class names, one per line; the line position is the class index.
        self._imagenet_classes = self._read_lines(imagenet_classes_file)
        imagenet_classes_map = {
            name: idx for idx, name in enumerate(self._imagenet_classes)}

        # A set gives O(1) membership tests while filtering the images.
        # Blank lines are ignored because ordering is irrelevant here.
        black_list = set(
            int(line) for line in self._read_lines(black_list_file_path)
            if line.strip())

        self._samples = []
        self._labels = [0]  # image id start from 1
        self._correct_count = 0

        # Sorted so that the sample order (and therefore the batches) is
        # reproducible; os.listdir returns entries in arbitrary order.
        for img_file in sorted(os.listdir(validation_set_image_dir)):
            if not img_file.endswith(".JPEG"):
                continue
            if self._image_id(img_file) not in black_list:
                self._samples.append(
                    os.path.join(validation_set_image_dir, img_file))

        # Line k (1-based) of the label file is the class name of image k.
        for label in self._read_lines(validation_set_label_file_path):
            self._labels.append(imagenet_classes_map[label])

    @staticmethod
    def _read_lines(file_path):
        """Return the lines of ``file_path`` without trailing newlines."""
        with open(file_path) as text_file:
            return [line.rstrip('\n') for line in text_file]

    @staticmethod
    def _image_id(file_path):
        """Return the integer after the last '_' of the file's base name.

        Only the base name is inspected so that underscores in directory
        names cannot corrupt the parsed id.
        """
        stem = os.path.splitext(os.path.basename(file_path))[0]
        return int(stem.split('_')[-1])

    def sample_size(self):
        """
        :return: the size of samples in validation set
        """
        return len(self._samples)

    def batch_size(self):
        """
        batch size to do validation to speed up validation.
        Keep same with batch size of input_shapes
        in model deployment file(.yml). do not set too large
        :return: batch size
        """
        return 1

    def preprocess(self, sample_idx_start, sample_idx_end, **kwargs):
        """
        pre-process the input sample
        :param sample_idx_start: start index of the sample.
        :param sample_idx_end: end index of the sample(not include).
        :param kwargs: other parameters.
        :return: the batched inputs' map(name: data) feed into your model
        """
        batch_sample_data = []
        # Clamp so the last (possibly partial) batch stays in range.
        sample_idx_end = min(sample_idx_end, self.sample_size())
        for sample_idx in range(sample_idx_start, sample_idx_end):
            # The context manager closes the image file promptly.
            with Image.open(self._samples[sample_idx]) as raw_img:
                # Force three channels: greyscale/CMYK JPEGs would
                # otherwise produce a tensor of a different shape.
                sample_img = raw_img.convert('RGB').resize(self._INPUT_SIZE)
            sample_data = np.asarray(sample_img, dtype=np.float32)
            # Map pixel values from [0, 255] to [-1, 1].
            sample_data = (2.0 / 255.0) * sample_data - 1.0
            batch_sample_data.append(sample_data.tolist())
        return {self._INPUT_NAME: batch_sample_data}

    def postprocess(self,
                    sample_idx_start,
                    sample_idx_end,
                    output_map,
                    **kwargs):
        """
        post-process the outputs of your model and calculate the accuracy
        :param sample_idx_start: start index of input sample
        :param sample_idx_end: end index of input sample
        :param output_map: output map of the model
        :param kwargs: other parameters.
        :return: None
        """
        output = output_map[self._OUTPUT_NAME]
        sample_idx_end = min(sample_idx_end, self.sample_size())
        batch_size = sample_idx_end - sample_idx_start
        if batch_size <= 0:
            # Window lies past the last sample: nothing to score.
            return
        # One row per sample; the arg-max of each row is the predicted class.
        predictions = np.argmax(
            np.array(output).reshape((batch_size, -1)), axis=-1)
        sample_indices = range(sample_idx_start, sample_idx_end)
        for predicted, sample_idx in zip(predictions, sample_indices):
            img_id = self._image_id(self._samples[sample_idx])
            expected = self._labels[img_id]
            if predicted == expected:
                self._correct_count += 1
            else:
                print(img_id, 'predict %s vs gt %s' %
                      (self._imagenet_classes[predicted],
                       self._imagenet_classes[expected]))

    def result(self):
        """
        print or show the result
        :return: None
        """
        total = self.sample_size()
        # Guard against an empty validation set (would divide by zero).
        accuracy = self._correct_count * 1.0 / total if total else 0.0
        print("==========================================")
        print("Top 1 accuracy: %f" % accuracy)
        print("==========================================")
if __name__ == '__main__':
    # Sample usage: walk the validation set batch by batch, feeding a
    # fixed fake model output to postprocess, then print the accuracy.
    demo_validator = AccuracyValidator()
    total_samples = demo_validator.sample_size()
    step = demo_validator.batch_size()
    output_key = 'MobilenetV2/Predictions/Reshape_1'
    for begin in range(0, total_samples, step):
        end = begin + step
        batch_inputs = demo_validator.preprocess(begin, end)
        print(np.array(batch_inputs['input']).shape)
        fake_outputs = {output_key: np.array([[0, 1], [1, 0]])}
        demo_validator.postprocess(begin, end, fake_outputs)
    demo_validator.result()
...@@ -413,6 +413,7 @@ class YAMLKeyword(object): ...@@ -413,6 +413,7 @@ class YAMLKeyword(object):
cl_mem_type = 'cl_mem_type' cl_mem_type = 'cl_mem_type'
backend = 'backend' backend = 'backend'
validation_outputs_data = 'validation_outputs_data' validation_outputs_data = 'validation_outputs_data'
accuracy_validation_script = 'accuracy_validation_script'
docker_image_tag = 'docker_image_tag' docker_image_tag = 'docker_image_tag'
dockerfile_path = 'dockerfile_path' dockerfile_path = 'dockerfile_path'
dockerfile_sha256_checksum = 'dockerfile_sha256_checksum' dockerfile_sha256_checksum = 'dockerfile_sha256_checksum'
......
...@@ -532,6 +532,16 @@ def format_model_config(flags): ...@@ -532,6 +532,16 @@ def format_model_config(flags):
subgraph[YAMLKeyword.input_ranges] = \ subgraph[YAMLKeyword.input_ranges] = \
[str(v) for v in subgraph[YAMLKeyword.input_ranges]] [str(v) for v in subgraph[YAMLKeyword.input_ranges]]
accuracy_validation_script = subgraph.get(
YAMLKeyword.accuracy_validation_script, "")
if isinstance(accuracy_validation_script, list):
mace_check(len(accuracy_validation_script) == 1,
ModuleName.YAML_CONFIG,
"Only support one accuracy validation script")
accuracy_validation_script = accuracy_validation_script[0]
subgraph[YAMLKeyword.accuracy_validation_script] = \
accuracy_validation_script
for key in [YAMLKeyword.limit_opencl_kernel_time, for key in [YAMLKeyword.limit_opencl_kernel_time,
YAMLKeyword.nnlib_graph_mode, YAMLKeyword.nnlib_graph_mode,
YAMLKeyword.obfuscate, YAMLKeyword.obfuscate,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
import os import os
import sys import sys
import socket import socket
...@@ -387,7 +388,7 @@ class DeviceWrapper: ...@@ -387,7 +388,7 @@ class DeviceWrapper:
subgraphs = model_config[YAMLKeyword.subgraphs] subgraphs = model_config[YAMLKeyword.subgraphs]
# generate input data # generate input data
sh_commands.gen_random_input( sh_commands.gen_input(
model_output_dir, model_output_dir,
subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.input_shapes],
...@@ -460,18 +461,14 @@ class DeviceWrapper: ...@@ -460,18 +461,14 @@ class DeviceWrapper:
return output_configs return output_configs
def run_specify_abi(self, flags, configs, target_abi): def run_model(self, flags, configs, target_abi,
if target_abi not in self.target_abis: model_name, output_config, runtime, tuning):
six.print_('The device %s with soc %s do not support the abi %s' %
(self.device_name, self.target_socs, target_abi))
return
library_name = configs[YAMLKeyword.library_name] library_name = configs[YAMLKeyword.library_name]
mace_lib_type = flags.mace_lib_type
embed_model_data = \ embed_model_data = \
configs[YAMLKeyword.model_data_format] == ModelFormat.code configs[YAMLKeyword.model_data_format] == ModelFormat.code
build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi) build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi)
# get target name for run # get target name for run
mace_lib_type = flags.mace_lib_type
if flags.example: if flags.example:
if mace_lib_type == MACELibType.static: if mace_lib_type == MACELibType.static:
target_name = EXAMPLE_STATIC_NAME target_name = EXAMPLE_STATIC_NAME
...@@ -483,6 +480,114 @@ class DeviceWrapper: ...@@ -483,6 +480,114 @@ class DeviceWrapper:
else: else:
target_name = MACE_RUN_DYNAMIC_NAME target_name = MACE_RUN_DYNAMIC_NAME
link_dynamic = mace_lib_type == MACELibType.dynamic link_dynamic = mace_lib_type == MACELibType.dynamic
if target_abi != ABIType.host:
self.clear_data_dir()
model_config = configs[YAMLKeyword.models][model_name]
subgraphs = model_config[YAMLKeyword.subgraphs]
model_output_base_dir, model_output_dir, mace_model_dir = \
get_build_model_dirs(
library_name, model_name, target_abi, self,
model_config[YAMLKeyword.model_file_path])
model_opencl_output_bin_path = ''
model_opencl_parameter_path = ''
if tuning:
model_opencl_output_bin_path = \
'{}/{}/{}'.format(model_output_dir,
BUILD_TMP_OPENCL_BIN_DIR,
CL_COMPILED_BINARY_FILE_NAME)
model_opencl_parameter_path = \
'{}/{}/{}'.format(model_output_dir,
BUILD_TMP_OPENCL_BIN_DIR,
CL_TUNED_PARAMETER_FILE_NAME)
elif target_abi != ABIType.host and self.target_socs:
model_opencl_output_bin_path = get_opencl_binary_output_path(
library_name, target_abi, self
)
model_opencl_parameter_path = get_opencl_parameter_output_path(
library_name, target_abi, self
)
# run for specified soc
device_type = parse_device_type(runtime)
self.tuning_run(
abi=target_abi,
target_dir=build_tmp_binary_dir,
target_name=target_name,
vlog_level=flags.vlog_level,
embed_model_data=embed_model_data,
model_output_dir=model_output_dir,
input_nodes=subgraphs[0][YAMLKeyword.input_tensors],
output_nodes=output_config[
YAMLKeyword.output_tensors],
input_shapes=subgraphs[0][YAMLKeyword.input_shapes],
output_shapes=output_config[YAMLKeyword.output_shapes],
input_data_formats=subgraphs[0][
YAMLKeyword.input_data_formats],
output_data_formats=subgraphs[0][
YAMLKeyword.output_data_formats],
mace_model_dir=mace_model_dir,
model_tag=model_name,
device_type=device_type,
running_round=flags.round,
restart_round=flags.restart_round,
limit_opencl_kernel_time=model_config[
YAMLKeyword.limit_opencl_kernel_time],
tuning=False,
out_of_range_check=flags.gpu_out_of_range_check,
model_graph_format=configs[
YAMLKeyword.model_graph_format],
omp_num_threads=flags.omp_num_threads,
cpu_affinity_policy=flags.cpu_affinity_policy,
gpu_perf_hint=flags.gpu_perf_hint,
gpu_priority_hint=flags.gpu_priority_hint,
runtime_failure_ratio=flags.runtime_failure_ratio,
address_sanitizer=flags.address_sanitizer,
opencl_binary_file=model_opencl_output_bin_path,
opencl_parameter_file=model_opencl_parameter_path,
libmace_dynamic_library_path=LIBMACE_DYNAMIC_PATH,
link_dynamic=link_dynamic,
quantize_stat=flags.quantize_stat,
input_dir=flags.input_dir,
output_dir=flags.output_dir,
layers_validate_file=output_config[
YAMLKeyword.model_file_path]
)
def get_output_map(self,
target_abi,
output_nodes,
output_shapes,
model_output_dir):
output_map = {}
for i in range(len(output_nodes)):
output_name = output_nodes[i]
formatted_name = common.formatted_file_name(
"model_out", output_name)
if target_abi != "host":
if os.path.exists("%s/%s" % (model_output_dir,
formatted_name)):
sh.rm("-rf", "%s/%s" % (model_output_dir,
formatted_name))
self.pull_from_data_dir(formatted_name,
model_output_dir)
output_file_path = os.path.join(model_output_dir,
formatted_name)
output_shape = [
int(x) for x in common.split_shape(output_shapes[i])]
output_map[output_name] = np.fromfile(
output_file_path, dtype=np.float32).reshape(output_shape)
return output_map
def run_specify_abi(self, flags, configs, target_abi):
if target_abi not in self.target_abis:
six.print_('The device %s with soc %s do not support the abi %s' %
(self.device_name, self.target_socs, target_abi))
return
library_name = configs[YAMLKeyword.library_name]
model_output_dirs = [] model_output_dirs = []
for model_name in configs[YAMLKeyword.models]: for model_name in configs[YAMLKeyword.models]:
...@@ -510,9 +615,7 @@ class DeviceWrapper: ...@@ -510,9 +615,7 @@ class DeviceWrapper:
sh.rm('-rf', model_output_dir) sh.rm('-rf', model_output_dir)
os.makedirs(model_output_dir) os.makedirs(model_output_dir)
is_tuned = False tuning = False
model_opencl_output_bin_path = ''
model_opencl_parameter_path = ''
if not flags.address_sanitizer \ if not flags.address_sanitizer \
and not flags.example \ and not flags.example \
and target_abi != ABIType.host \ and target_abi != ABIType.host \
...@@ -525,33 +628,36 @@ class DeviceWrapper: ...@@ -525,33 +628,36 @@ class DeviceWrapper:
self.tuning(library_name, model_name, model_config, self.tuning(library_name, model_name, model_config,
configs[YAMLKeyword.model_graph_format], configs[YAMLKeyword.model_graph_format],
configs[YAMLKeyword.model_data_format], configs[YAMLKeyword.model_data_format],
target_abi, mace_lib_type) target_abi, flags.mace_lib_type)
model_output_dirs.append(model_output_dir) model_output_dirs.append(model_output_dir)
model_opencl_output_bin_path = \
'{}/{}/{}'.format(model_output_dir,
BUILD_TMP_OPENCL_BIN_DIR,
CL_COMPILED_BINARY_FILE_NAME)
model_opencl_parameter_path = \
'{}/{}/{}'.format(model_output_dir,
BUILD_TMP_OPENCL_BIN_DIR,
CL_TUNED_PARAMETER_FILE_NAME)
self.clear_data_dir() self.clear_data_dir()
is_tuned = True tuning = True
elif target_abi != ABIType.host and self.target_socs:
model_opencl_output_bin_path = get_opencl_binary_output_path( accuracy_validation_script = \
library_name, target_abi, self subgraphs[0][YAMLKeyword.accuracy_validation_script]
) output_configs = []
model_opencl_parameter_path = get_opencl_parameter_output_path( if not accuracy_validation_script and flags.layers != "-1":
library_name, target_abi, self mace_check(configs[YAMLKeyword.model_graph_format] ==
) ModelFormat.file and
sh_commands.gen_random_input( configs[YAMLKeyword.model_data_format] ==
model_output_dir, ModelFormat.file, "Device",
subgraphs[0][YAMLKeyword.input_tensors], "'--layers' only supports model format 'file'.")
subgraphs[0][YAMLKeyword.input_shapes], output_configs = self.get_layers(mace_model_dir,
subgraphs[0][YAMLKeyword.validation_inputs_data], model_name,
input_ranges=subgraphs[0][YAMLKeyword.input_ranges], flags.layers)
input_data_types=subgraphs[0][YAMLKeyword.input_data_types] # run for specified soc
) if not subgraphs[0][YAMLKeyword.check_tensors]:
output_nodes = subgraphs[0][YAMLKeyword.output_tensors]
output_shapes = subgraphs[0][YAMLKeyword.output_shapes]
else:
output_nodes = subgraphs[0][YAMLKeyword.check_tensors]
output_shapes = subgraphs[0][YAMLKeyword.check_shapes]
model_path = "%s/%s.pb" % (mace_model_dir, model_name)
output_config = {YAMLKeyword.model_file_path: model_path,
YAMLKeyword.output_tensors: output_nodes,
YAMLKeyword.output_shapes: output_shapes}
output_configs.append(output_config)
runtime_list = [] runtime_list = []
if target_abi == ABIType.host: if target_abi == ABIType.host:
runtime_list.append(RuntimeType.cpu) runtime_list.append(RuntimeType.cpu)
...@@ -559,143 +665,128 @@ class DeviceWrapper: ...@@ -559,143 +665,128 @@ class DeviceWrapper:
runtime_list.extend([RuntimeType.cpu, RuntimeType.gpu]) runtime_list.extend([RuntimeType.cpu, RuntimeType.gpu])
else: else:
runtime_list.append(model_runtime) runtime_list.append(model_runtime)
for runtime in runtime_list: if accuracy_validation_script:
device_type = parse_device_type(runtime) flags.validate = False
# run for specified soc flags.report = False
if not subgraphs[0][YAMLKeyword.check_tensors]:
output_nodes = subgraphs[0][YAMLKeyword.output_tensors] import imp
output_shapes = subgraphs[0][YAMLKeyword.output_shapes] accuracy_val_module = imp.load_source(
else: 'accuracy_val_module',
output_nodes = subgraphs[0][YAMLKeyword.check_tensors] accuracy_validation_script)
output_shapes = subgraphs[0][YAMLKeyword.check_shapes] for runtime in runtime_list:
output_configs = [] accuracy_validator = \
log_file = "" accuracy_val_module.AccuracyValidator()
if flags.layers != "-1": sample_size = accuracy_validator.sample_size()
mace_check(configs[YAMLKeyword.model_graph_format] == val_batch_size = accuracy_validator.batch_size()
ModelFormat.file and for i in range(0, sample_size, val_batch_size):
configs[YAMLKeyword.model_data_format] == inputs = accuracy_validator.preprocess(
ModelFormat.file, "Device", i, i + val_batch_size)
"'--layers' only supports model format 'file'.") sh_commands.gen_input(
output_configs = self.get_layers(mace_model_dir, model_output_dir,
model_name, subgraphs[0][YAMLKeyword.input_tensors],
flags.layers) subgraphs[0][YAMLKeyword.input_shapes],
log_dir = mace_model_dir + "/" + runtime input_data_types=subgraphs[0][YAMLKeyword.input_data_types], # noqa
if os.path.exists(log_dir): input_data_map=inputs)
sh.rm('-rf', log_dir)
os.makedirs(log_dir) self.run_model(flags, configs, target_abi, model_name,
log_file = log_dir + "/log.csv" output_configs[-1], runtime, tuning)
model_path = "%s/%s.pb" % (mace_model_dir, model_name) accuracy_validator.postprocess(
output_config = {YAMLKeyword.model_file_path: model_path, i, i + val_batch_size,
YAMLKeyword.output_tensors: output_nodes, self.get_output_map(
YAMLKeyword.output_shapes: output_shapes} target_abi,
output_configs.append(output_config) output_nodes,
for output_config in output_configs: subgraphs[0][YAMLKeyword.output_shapes],
run_output = self.tuning_run( model_output_dir))
abi=target_abi, accuracy_validator.result()
target_dir=build_tmp_binary_dir, else:
target_name=target_name, sh_commands.gen_input(
vlog_level=flags.vlog_level, model_output_dir,
embed_model_data=embed_model_data, subgraphs[0][YAMLKeyword.input_tensors],
model_output_dir=model_output_dir, subgraphs[0][YAMLKeyword.input_shapes],
input_nodes=subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.validation_inputs_data],
output_nodes=output_config[ input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
YAMLKeyword.output_tensors], input_data_types=subgraphs[0][YAMLKeyword.input_data_types]
input_shapes=subgraphs[0][YAMLKeyword.input_shapes], )
output_shapes=output_config[YAMLKeyword.output_shapes], for runtime in runtime_list:
input_data_formats=subgraphs[0][ device_type = parse_device_type(runtime)
YAMLKeyword.input_data_formats], for output_config in output_configs:
output_data_formats=subgraphs[0][ self.run_model(flags, configs, target_abi, model_name,
YAMLKeyword.output_data_formats], output_config, runtime, tuning)
mace_model_dir=mace_model_dir, if flags.validate:
model_tag=model_name, log_file = ""
device_type=device_type, if flags.layers != "-1":
running_round=flags.round, log_dir = mace_model_dir + "/" + runtime
restart_round=flags.restart_round, if os.path.exists(log_dir):
limit_opencl_kernel_time=model_config[ sh.rm('-rf', log_dir)
YAMLKeyword.limit_opencl_kernel_time], os.makedirs(log_dir)
tuning=False, log_file = log_dir + "/log.csv"
out_of_range_check=flags.gpu_out_of_range_check, model_file_path, weight_file_path = \
model_graph_format=configs[ get_model_files(
YAMLKeyword.model_graph_format], model_config[YAMLKeyword.model_file_path],
omp_num_threads=flags.omp_num_threads, model_config[
cpu_affinity_policy=flags.cpu_affinity_policy, YAMLKeyword.model_sha256_checksum],
gpu_perf_hint=flags.gpu_perf_hint, BUILD_DOWNLOADS_DIR,
gpu_priority_hint=flags.gpu_priority_hint, model_config[YAMLKeyword.weight_file_path],
runtime_failure_ratio=flags.runtime_failure_ratio, model_config[
address_sanitizer=flags.address_sanitizer, YAMLKeyword.weight_sha256_checksum])
opencl_binary_file=model_opencl_output_bin_path, validate_type = device_type
opencl_parameter_file=model_opencl_parameter_path, if model_config[YAMLKeyword.quantize] == 1:
libmace_dynamic_library_path=LIBMACE_DYNAMIC_PATH, validate_type = device_type + '_QUANTIZE'
link_dynamic=link_dynamic,
quantize_stat=flags.quantize_stat, dockerfile_path, docker_image_tag = \
input_dir=flags.input_dir, get_dockerfile_info(
output_dir=flags.output_dir, model_config.get(
layers_validate_file=output_config[ YAMLKeyword.dockerfile_path),
YAMLKeyword.model_file_path] model_config.get(
) YAMLKeyword.dockerfile_sha256_checksum), # noqa
if flags.validate: model_config.get(
model_file_path, weight_file_path = get_model_files( YAMLKeyword.docker_image_tag)
model_config[YAMLKeyword.model_file_path], ) if YAMLKeyword.dockerfile_path \
model_config[YAMLKeyword.model_sha256_checksum], in model_config \
BUILD_DOWNLOADS_DIR, else ("third_party/caffe", "lastest")
model_config[YAMLKeyword.weight_file_path],
model_config[YAMLKeyword.weight_sha256_checksum] sh_commands.validate_model(
) abi=target_abi,
validate_type = device_type device=self,
if model_config[YAMLKeyword.quantize] == 1: model_file_path=model_file_path,
validate_type = device_type + '_QUANTIZE' weight_file_path=weight_file_path,
docker_image_tag=docker_image_tag,
dockerfile_path, docker_image_tag = \ dockerfile_path=dockerfile_path,
get_dockerfile_info( platform=model_config[YAMLKeyword.platform],
model_config.get(YAMLKeyword.dockerfile_path), device_type=device_type,
model_config.get( input_nodes=subgraphs[0][
YAMLKeyword.dockerfile_sha256_checksum), YAMLKeyword.input_tensors],
model_config.get(YAMLKeyword.docker_image_tag) output_nodes=output_config[
) if YAMLKeyword.dockerfile_path in model_config \ YAMLKeyword.output_tensors],
else ("third_party/caffe", "lastest") input_shapes=subgraphs[0][
YAMLKeyword.input_shapes],
sh_commands.validate_model( output_shapes=output_config[
abi=target_abi, YAMLKeyword.output_shapes],
device=self, input_data_formats=subgraphs[0][
model_file_path=model_file_path, YAMLKeyword.input_data_formats],
weight_file_path=weight_file_path, output_data_formats=subgraphs[0][
docker_image_tag=docker_image_tag, YAMLKeyword.output_data_formats],
dockerfile_path=dockerfile_path, model_output_dir=model_output_dir,
platform=model_config[YAMLKeyword.platform], input_data_types=subgraphs[0][
device_type=device_type, YAMLKeyword.input_data_types],
input_nodes=subgraphs[0][ caffe_env=flags.caffe_env,
YAMLKeyword.input_tensors], validation_threshold=subgraphs[0][
output_nodes=output_config[ YAMLKeyword.validation_threshold][
YAMLKeyword.output_tensors], validate_type],
input_shapes=subgraphs[0][ backend=subgraphs[0][YAMLKeyword.backend],
YAMLKeyword.input_shapes], validation_outputs_data=subgraphs[0][
output_shapes=output_config[ YAMLKeyword.validation_outputs_data],
YAMLKeyword.output_shapes], log_file=log_file,
input_data_formats=subgraphs[0][ )
YAMLKeyword.input_data_formats], if flags.report and flags.round > 0:
output_data_formats=subgraphs[0][ tuned = tuning and device_type == DeviceType.GPU
YAMLKeyword.output_data_formats], self.report_run_statistics(
model_output_dir=model_output_dir, target_abi=target_abi,
input_data_types=subgraphs[0][ model_name=model_name,
YAMLKeyword.input_data_types], device_type=device_type,
caffe_env=flags.caffe_env, output_dir=flags.report_dir,
validation_threshold=subgraphs[0][ tuned=tuned)
YAMLKeyword.validation_threshold][
validate_type],
backend=subgraphs[0][YAMLKeyword.backend],
validation_outputs_data=subgraphs[0][
YAMLKeyword.validation_outputs_data],
log_file=log_file,
)
if flags.report and flags.round > 0:
tuned = is_tuned and device_type == DeviceType.GPU
self.report_run_statistics(
target_abi=target_abi,
model_name=model_name,
device_type=device_type,
output_dir=flags.report_dir,
tuned=tuned
)
if model_output_dirs: if model_output_dirs:
opencl_output_bin_path = get_opencl_binary_output_path( opencl_output_bin_path = get_opencl_binary_output_path(
library_name, target_abi, self library_name, target_abi, self
...@@ -956,7 +1047,7 @@ class DeviceWrapper: ...@@ -956,7 +1047,7 @@ class DeviceWrapper:
if target_abi != ABIType.host: if target_abi != ABIType.host:
self.clear_data_dir() self.clear_data_dir()
sh_commands.gen_random_input( sh_commands.gen_input(
model_output_dir, model_output_dir,
subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.input_shapes],
......
...@@ -549,42 +549,64 @@ def gen_model_code(model_codegen_dir, ...@@ -549,42 +549,64 @@ def gen_model_code(model_codegen_dir,
_fg=True) _fg=True)
def gen_random_input(model_output_dir, def gen_input(model_output_dir,
input_nodes, input_nodes,
input_shapes, input_shapes,
input_files, input_files=None,
input_ranges, input_ranges=None,
input_data_types, input_data_types=None,
input_file_name="model_input"): input_data_map=None,
input_file_name="model_input"):
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = common.formatted_file_name( formatted_name = common.formatted_file_name(
input_file_name, input_name) input_file_name, input_name)
if os.path.exists("%s/%s" % (model_output_dir, formatted_name)): if os.path.exists("%s/%s" % (model_output_dir, formatted_name)):
sh.rm("%s/%s" % (model_output_dir, formatted_name)) sh.rm("%s/%s" % (model_output_dir, formatted_name))
input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes)
input_ranges_str = ":".join(input_ranges)
input_data_types_str = ",".join(input_data_types)
generate_input_data("%s/%s" % (model_output_dir, input_file_name),
input_nodes_str,
input_shapes_str,
input_ranges_str,
input_data_types_str)
input_file_list = [] input_file_list = []
if isinstance(input_files, list): if isinstance(input_files, list):
input_file_list.extend(input_files) input_file_list.extend(input_files)
else: else:
input_file_list.append(input_files) input_file_list.append(input_files)
if len(input_file_list) != 0: if input_data_map:
for i in range(len(input_nodes)):
dst_input_file = model_output_dir + '/' + \
common.formatted_file_name(input_file_name,
input_nodes[i])
input_name = input_nodes[i]
common.mace_check(input_name in input_data_map,
common.ModuleName.RUN,
"The preprocessor API in PrecisionValidator"
" script should return all inputs of model")
if input_data_types[i] == 'float32':
input_data = np.array(input_data_map[input_name],
dtype=np.float32)
elif input_data_types[i] == 'int32':
input_data = np.array(input_data_map[input_name],
dtype=np.int32)
else:
common.mace_check(
False,
common.ModuleName.RUN,
'Do not support input data type %s' % input_data_types[i])
common.mace_check(
list(map(int, common.split_shape(input_shapes[i])))
== list(input_data.shape),
common.ModuleName.RUN,
"The shape return from preprocessor API of"
" PrecisionValidator script is not same with"
" model deployment file. %s vs %s"
% (str(input_shapes[i]), str(input_data.shape)))
input_data.tofile(dst_input_file)
elif len(input_file_list) != 0:
input_name_list = [] input_name_list = []
if isinstance(input_nodes, list): if isinstance(input_nodes, list):
input_name_list.extend(input_nodes) input_name_list.extend(input_nodes)
else: else:
input_name_list.append(input_nodes) input_name_list.append(input_nodes)
if len(input_file_list) != len(input_name_list): common.mace_check(len(input_file_list) == len(input_name_list),
raise Exception('If input_files set, the input files should ' common.ModuleName.RUN,
'match the input names.') 'If input_files set, the input files should '
'match the input names.')
for i in range(len(input_file_list)): for i in range(len(input_file_list)):
if input_file_list[i] is not None: if input_file_list[i] is not None:
dst_input_file = model_output_dir + '/' + \ dst_input_file = model_output_dir + '/' + \
...@@ -596,6 +618,17 @@ def gen_random_input(model_output_dir, ...@@ -596,6 +618,17 @@ def gen_random_input(model_output_dir,
dst_input_file) dst_input_file)
else: else:
sh.cp("-f", input_file_list[i], dst_input_file) sh.cp("-f", input_file_list[i], dst_input_file)
else:
# generate random input files
input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes)
input_ranges_str = ":".join(input_ranges)
input_data_types_str = ",".join(input_data_types)
generate_input_data("%s/%s" % (model_output_dir, input_file_name),
input_nodes_str,
input_shapes_str,
input_ranges_str,
input_data_types_str)
def gen_opencl_binary_cpps(opencl_bin_file_path, def gen_opencl_binary_cpps(opencl_bin_file_path,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册