提交 437da1ea 编写于 作者: L liuqi

Support accuracy validation using python script as a plugin.

上级 3e9bb73e
......@@ -76,6 +76,8 @@ in one deployment file.
- The numerical range of the input tensors' data, default [-1, 1]. It is only for test.
* - validation_inputs_data
- [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used.
* - accuracy_validation_script
- [optional] Specify the accuracy validation script as a plugin to test accuracy, see `doc <#validate-accuracy-of-mace-model>`__.
* - validation_threshold
- [optional] Specify the similarity threshold for validation. A dict with key in 'CPU', 'GPU' and/or 'HEXAGON' and value <= 1.0.
* - backend
......@@ -358,6 +360,19 @@ Tuning for specific SoC's GPU
// ... Same with the code in basic usage.
Validate accuracy of MACE model
-------------------------------
MACE supports a **python validation script** as a plugin to test accuracy. The plugin script can be used for the following two purposes.
1. Test the **accuracy (like Top-1)** of a MACE model (specifically a quantized model) converted from another framework (like TensorFlow).
2. Show some real output if you want to see it.
The script defines some interfaces like `preprocess` and `postprocess` to deal with input/output and calculate the accuracy;
you can refer to the `sample code <https://github.com/XiaoMi/mace/tree/master/tools/accuracy_validator.py>`__ for details.
The sample code shows how to calculate the Top-1 accuracy with the ImageNet validation dataset.
Useful Commands
---------------
* **run the model**
......
......@@ -2,6 +2,8 @@
library_name: mobile_squeeze
# host, armeabi-v7a or arm64-v8a
target_abis: [arm64-v8a]
# soc's name or all
target_socs: [all]
# The build mode for model(s).
# 'code' for transferring model(s) into cpp code, 'file' for keeping model(s) in protobuf file(s) (.pb).
model_graph_format: code
......@@ -43,6 +45,8 @@ models:
- prob
output_shapes:
- 1,1,1,1000
accuracy_validation_script:
- path/to/your/script
runtime: cpu+gpu
limit_opencl_kernel_time: 0
obfuscate: 0
......
# Copyright 2019 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import numpy as np
from PIL import Image
class AccuracyValidator(object):
    """Accuracy validator plugin.

    Usage: this script is used to calculate the accuracy (like Top-1)
    of a MACE model.

    Users may replace this validator script with their own to do other
    accuracy validation (like MIOU for segmentation); the new script's
    interface must be exactly the same as this AccuracyValidator's.

    Warning: do not use relative paths in this script.
    """

    def __init__(self, **kwargs):
        """Load class names, the blacklist, sample paths and labels.

        :param kwargs: other parameters (not used by this sample).
        """
        # absolute path
        validation_set_image_dir = \
            '/path/to/your/validation/set/directory'
        validation_set_label_file_path = \
            '/path/to/imagenet_groundtruth_labels.txt'
        black_list_file_path = \
            '/path/to/imagenet_blacklist.txt'
        imagenet_classes_file = \
            '/path/to/imagenet_classes.txt'

        self._imagenet_classes = self._read_lines(imagenet_classes_file)
        imagenet_classes_map = {
            name: idx for idx, name in enumerate(self._imagenet_classes)}
        # A set keeps the per-image blacklist lookup constant time.
        black_list = {
            int(line) for line in self._read_lines(black_list_file_path)}

        self._samples = []
        self._labels = [0]  # image id start from 1
        self._correct_count = 0
        for img_file in os.listdir(validation_set_image_dir):
            if not img_file.endswith(".JPEG"):
                continue
            if self._image_id(img_file) not in black_list:
                self._samples.append(
                    os.path.join(validation_set_image_dir, img_file))
        for label in self._read_lines(validation_set_label_file_path):
            self._labels.append(imagenet_classes_map[label])

    @staticmethod
    def _read_lines(file_path):
        """Return the lines of a text file without trailing newlines.

        The file is closed before returning.
        """
        with open(file_path) as f:
            return [line.rstrip('\n') for line in f]

    @staticmethod
    def _image_id(sample_file_path):
        """Return the integer after the last '_' of the file's base name.

        Only the file name is parsed, so an underscore in a directory
        name cannot corrupt the id.
        """
        file_name = os.path.basename(sample_file_path)
        return int(os.path.splitext(file_name)[0].split('_')[-1])

    def sample_size(self):
        """
        :return: the size of samples in validation set
        """
        return len(self._samples)

    def batch_size(self):
        """
        batch size to do validation to speed up validation.
        Keep same with batch size of input_shapes
        in model deployment file(.yml). do not set too large
        :return: batch size
        """
        return 1

    def preprocess(self, sample_idx_start, sample_idx_end, **kwargs):
        """
        pre-process the input sample
        :param sample_idx_start: start index of the sample.
        :param sample_idx_end: end index of the sample(not include).
        :param kwargs: other parameters.
        :return: the batched inputs' map(name: data) feed into your model
        """
        inputs = {}
        batch_sample_data = []
        sample_idx_end = min(sample_idx_end, self.sample_size())
        for sample_idx in range(sample_idx_start, sample_idx_end):
            sample_file_path = self._samples[sample_idx]
            # Close the image file promptly, and force three channels so
            # grayscale/CMYK images yield the same shape as RGB ones.
            with Image.open(sample_file_path) as raw_img:
                sample_img = raw_img.convert('RGB').resize((224, 224))
            sample_data = np.asarray(sample_img, dtype=np.float32)
            # Scale pixel values from [0, 255] to [-1, 1].
            sample_data = (2.0 / 255.0) * sample_data - 1.0
            batch_sample_data.append(sample_data.tolist())
        inputs["input"] = batch_sample_data
        return inputs

    def postprocess(self,
                    sample_idx_start,
                    sample_idx_end,
                    output_map,
                    **kwargs):
        """
        post-process the outputs of your model and calculate the accuracy
        :param sample_idx_start: start index of input sample
        :param sample_idx_end: end index of input sample
        :param output_map: output map of the model
        :param kwargs: other parameters.
        :return: None
        """
        output = output_map['MobilenetV2/Predictions/Reshape_1']
        sample_idx_end = min(sample_idx_end, self.sample_size())
        batch_size = sample_idx_end - sample_idx_start
        if batch_size <= 0:
            # Nothing to score; also avoids reshaping to a zero-row array.
            return
        output = np.array(output).reshape((batch_size, -1))
        output = np.argmax(output, axis=-1)
        for output_idx, sample_idx in enumerate(
                range(sample_idx_start, sample_idx_end)):
            img_id = self._image_id(self._samples[sample_idx])
            if output[output_idx] == self._labels[img_id]:
                self._correct_count += 1
            else:
                print(img_id, 'predict %s vs gt %s' %
                      (self._imagenet_classes[output[output_idx]],
                       self._imagenet_classes[self._labels[img_id]]))

    def result(self):
        """
        print or show the result
        :return: None
        """
        total = self.sample_size()
        # Report 0 instead of dividing by zero on an empty sample set.
        accuracy = self._correct_count * 1.0 / total if total else 0.0
        print("==========================================")
        print("Top 1 accuracy: %f" % accuracy)
        print("==========================================")
if __name__ == '__main__':
    # Sample usage: walk the validation set batch by batch, feeding a
    # fixed dummy tensor to postprocess in place of real model output.
    validator = AccuracyValidator()
    total_samples = validator.sample_size()
    step = validator.batch_size()
    for batch_start in range(0, total_samples, step):
        batch_end = batch_start + step
        batch_inputs = validator.preprocess(batch_start, batch_end)
        print(np.array(batch_inputs['input']).shape)
        dummy_outputs = {
            'MobilenetV2/Predictions/Reshape_1': np.array([[0, 1], [1, 0]])
        }
        validator.postprocess(batch_start, batch_end, dummy_outputs)
    validator.result()
......@@ -413,6 +413,7 @@ class YAMLKeyword(object):
cl_mem_type = 'cl_mem_type'
backend = 'backend'
validation_outputs_data = 'validation_outputs_data'
accuracy_validation_script = 'accuracy_validation_script'
docker_image_tag = 'docker_image_tag'
dockerfile_path = 'dockerfile_path'
dockerfile_sha256_checksum = 'dockerfile_sha256_checksum'
......
......@@ -532,6 +532,16 @@ def format_model_config(flags):
subgraph[YAMLKeyword.input_ranges] = \
[str(v) for v in subgraph[YAMLKeyword.input_ranges]]
accuracy_validation_script = subgraph.get(
YAMLKeyword.accuracy_validation_script, "")
if isinstance(accuracy_validation_script, list):
mace_check(len(accuracy_validation_script) == 1,
ModuleName.YAML_CONFIG,
"Only support one accuracy validation script")
accuracy_validation_script = accuracy_validation_script[0]
subgraph[YAMLKeyword.accuracy_validation_script] = \
accuracy_validation_script
for key in [YAMLKeyword.limit_opencl_kernel_time,
YAMLKeyword.nnlib_graph_mode,
YAMLKeyword.obfuscate,
......
此差异已折叠。
......@@ -549,42 +549,64 @@ def gen_model_code(model_codegen_dir,
_fg=True)
def gen_random_input(model_output_dir,
input_nodes,
input_shapes,
input_files,
input_ranges,
input_data_types,
input_file_name="model_input"):
def gen_input(model_output_dir,
input_nodes,
input_shapes,
input_files=None,
input_ranges=None,
input_data_types=None,
input_data_map=None,
input_file_name="model_input"):
for input_name in input_nodes:
formatted_name = common.formatted_file_name(
input_file_name, input_name)
if os.path.exists("%s/%s" % (model_output_dir, formatted_name)):
sh.rm("%s/%s" % (model_output_dir, formatted_name))
input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes)
input_ranges_str = ":".join(input_ranges)
input_data_types_str = ",".join(input_data_types)
generate_input_data("%s/%s" % (model_output_dir, input_file_name),
input_nodes_str,
input_shapes_str,
input_ranges_str,
input_data_types_str)
input_file_list = []
if isinstance(input_files, list):
input_file_list.extend(input_files)
else:
input_file_list.append(input_files)
if len(input_file_list) != 0:
if input_data_map:
for i in range(len(input_nodes)):
dst_input_file = model_output_dir + '/' + \
common.formatted_file_name(input_file_name,
input_nodes[i])
input_name = input_nodes[i]
common.mace_check(input_name in input_data_map,
common.ModuleName.RUN,
"The preprocessor API in PrecisionValidator"
" script should return all inputs of model")
if input_data_types[i] == 'float32':
input_data = np.array(input_data_map[input_name],
dtype=np.float32)
elif input_data_types[i] == 'int32':
input_data = np.array(input_data_map[input_name],
dtype=np.int32)
else:
common.mace_check(
False,
common.ModuleName.RUN,
'Do not support input data type %s' % input_data_types[i])
common.mace_check(
list(map(int, common.split_shape(input_shapes[i])))
== list(input_data.shape),
common.ModuleName.RUN,
"The shape return from preprocessor API of"
" PrecisionValidator script is not same with"
" model deployment file. %s vs %s"
% (str(input_shapes[i]), str(input_data.shape)))
input_data.tofile(dst_input_file)
elif len(input_file_list) != 0:
input_name_list = []
if isinstance(input_nodes, list):
input_name_list.extend(input_nodes)
else:
input_name_list.append(input_nodes)
if len(input_file_list) != len(input_name_list):
raise Exception('If input_files set, the input files should '
'match the input names.')
common.mace_check(len(input_file_list) == len(input_name_list),
common.ModuleName.RUN,
'If input_files set, the input files should '
'match the input names.')
for i in range(len(input_file_list)):
if input_file_list[i] is not None:
dst_input_file = model_output_dir + '/' + \
......@@ -596,6 +618,17 @@ def gen_random_input(model_output_dir,
dst_input_file)
else:
sh.cp("-f", input_file_list[i], dst_input_file)
else:
# generate random input files
input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes)
input_ranges_str = ":".join(input_ranges)
input_data_types_str = ",".join(input_data_types)
generate_input_data("%s/%s" % (model_output_dir, input_file_name),
input_nodes_str,
input_shapes_str,
input_ranges_str,
input_data_types_str)
def gen_opencl_binary_cpps(opencl_bin_file_path,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册