diff --git a/example/post_training_quantization/analysis.md b/example/post_training_quantization/analysis.md index 3f7b89615adfa5d41bb1dd4bca2c722fb5c79ea7..ce5e7a75eed98016fcac4845528c15e7ce1e1ab2 100644 --- a/example/post_training_quantization/analysis.md +++ b/example/post_training_quantization/analysis.md @@ -15,13 +15,7 @@ data_loader: None save_dir: 'analysis_results' checkpoint_name: 'analysis_checkpoint.pkl' num_histogram_plots: 10 - -quantizable_op_type: ["conv2d", "depthwise_conv2d", "mul"] -weight_quantize_type: 'abs_max' -activation_quantize_type: 'moving_average_abs_max' -is_full_quantize: False -batch_size: 10 -batch_nums: 10 +ptq_config ``` - model_dir: 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可。 - model_filename: 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入。 @@ -31,18 +25,7 @@ batch_nums: 10 - save_dir:分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`。 - checkpoint_name:由于模型可能存在大量层需要分析,因此分析过程中会中间保存结果,如果程序中断会自动加载已经分析好的结果,默认为`analysis_checkpoint.pkl`。 - num_histogram_plots:需要可视化的直方分布图数量。可视化量化效果最好和最坏的该数量个权重和激活的分布图。默认为10。若不需要可视化直方图,设置为0即可。 - -注:以下参数均为需要传入离线量化中的参数,保持默认不影响模型进行量化分析。 -- quantizable_op_type:需要进行量化的OP类型。通过以下代码可输出所有支持量化的OP类型: -``` -from paddleslim.quant.quanter import TRANSFORM_PASS_OP_TYPES,QUANT_DEQUANT_PASS_OP_TYPES -print(TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES) -``` -- weight_quantize_type:参数量化方式。可选 'abs_max' , 'channel_wise_abs_max' , 'range_abs_max' , 'moving_average_abs_max' 。 默认 'abs_max' 。 -- activation_quantize_type:激活量化方式,可选 'abs_max' , 'range_abs_max' , 'moving_average_abs_max' 。默认为 'moving_average_abs_max'。 -- is_full_quantize:是否对模型进行全量化,默认为False。 -- batch_size:模型校准使用的batch size大小,默认为10。 -- batch_nums:模型校准时的总batch数量,默认为10。 +- ptq_config:可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post)。 diff --git a/example/post_training_quantization/detection/analysis.py b/example/post_training_quantization/detection/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..ba65d7291fe635cc2d2cc7eb3c37bf7c2298b84f --- /dev/null +++ b/example/post_training_quantization/detection/analysis.py @@ -0,0 +1,179 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys +import numpy as np +import argparse +from tqdm import tqdm +import paddle +from ppdet.core.workspace import load_config, merge_config +from ppdet.core.workspace import create +from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval +from keypoint_utils import keypoint_post_process +from post_process import PPYOLOEPostProcess +from paddleslim.quant.analysis import AnalysisQuant + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of analysis config.", + required=True) + parser.add_argument( + '--devices', + type=str, + default='gpu', + help="which device used to compress.") + return parser + + +def reader_wrapper(reader, input_list): + def gen(): + for data in reader: + in_dict = {} + if isinstance(input_list, list): + for input_name in input_list: + in_dict[input_name] = data[input_name] + elif isinstance(input_list, dict): + for input_name in input_list.keys(): + in_dict[input_list[input_name]] = data[input_name] + yield in_dict + + return gen + + +def convert_numpy_data(data, metric): + data_all = {} + data_all = {k: np.array(v) for k, v in data.items()} + if isinstance(metric, VOCMetric): + for k, v in data_all.items(): + if not isinstance(v[0], np.ndarray): + tmp_list = [] + for t in v: + tmp_list.append(np.array(t)) + data_all[k] = np.array(tmp_list) + else: + data_all = {k: np.array(v) for k, v in data.items()} + return data_all + + +def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): + with tqdm( + total=len(val_loader), + bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for batch_id, data in enumerate(val_loader): + data_all = convert_numpy_data(data, metric) + data_input = {} + for k, v in data.items(): + if isinstance(config['input_list'], list): + if k in test_feed_names: + data_input[k] = np.array(v) + elif isinstance(config['input_list'], dict): + if k in config['input_list'].keys(): + data_input[config['input_list'][k]] = np.array(v) + outs = exe.run(compiled_test_program, + feed=data_input, + fetch_list=test_fetch_list, + return_numpy=False) + res = {} + if 'arch' in config and config['arch'] == 'keypoint': + res = keypoint_post_process(data, data_input, exe, + compiled_test_program, + test_fetch_list, outs) + if 'arch' in config and config['arch'] == 'PPYOLOE': + postprocess = PPYOLOEPostProcess( + score_threshold=0.01, nms_threshold=0.6) + res = postprocess(np.array(outs[0]), data_all['scale_factor']) + else: + for out in outs: + v = np.array(out) + if len(v.shape) > 1: + res['bbox'] = v + else: + res['bbox_num'] = v + + metric.update(data_all, res) + t.update() + + metric.accumulate() + metric.log() + map_res = metric.get_results() + metric.reset() + map_key = 'keypoint' if 'arch' in config and config[ + 'arch'] == 'keypoint' else 'bbox' + return map_res[map_key][0] + + +def main(): + + global config + config = load_config(FLAGS.config_path) + ptq_config = config['PTQ'] + + data_loader = create('EvalReader')(config['EvalDataset'], + config['worker_num'], + return_list=True) + data_loader = reader_wrapper(data_loader, config['input_list']) + + dataset = config['EvalDataset'] + global val_loader + _eval_batch_sampler = paddle.io.BatchSampler( + dataset, batch_size=config['EvalReader']['batch_size']) + val_loader = create('EvalReader')(dataset, + config['worker_num'], + batch_sampler=_eval_batch_sampler, + return_list=True) + global metric + if 
config['metric'] == 'COCO': + clsid2catid = {v: k for k, v in dataset.catid2clsid.items()} + anno_file = dataset.get_anno() + metric = COCOMetric( + anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox') + elif config['metric'] == 'VOC': + metric = VOCMetric( + label_list=dataset.get_label_list(), + class_num=config['num_classes'], + map_type=config['map_type']) + elif config['metric'] == 'KeyPointTopDownCOCOEval': + anno_file = dataset.get_anno() + metric = KeyPointTopDownCOCOEval(anno_file, + len(dataset), 17, 'output_eval') + else: + raise ValueError("metric currently only supports COCO and VOC.") + + analyzer = AnalysisQuant( + model_dir=config["model_dir"], + model_filename=config["model_filename"], + params_filename=config["params_filename"], + eval_function=eval_function, + data_loader=data_loader, + save_dir=config['save_dir'], + ptq_config=ptq_config) + analyzer.analysis() + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + + assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu'] + paddle.set_device(FLAGS.devices) + + main() diff --git a/example/post_training_quantization/detection/configs/picodet_s_analysis.yaml b/example/post_training_quantization/detection/configs/picodet_s_analysis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c640795c44b9abae9b1bd019fb318347d1c74c9 --- /dev/null +++ b/example/post_training_quantization/detection/configs/picodet_s_analysis.yaml @@ -0,0 +1,47 @@ +input_list: ['image', 'scale_factor'] +model_dir: ./picodet_s_416_coco_lcnet/ +model_filename: model.pdmodel +params_filename: model.pdiparams +save_dir: ./analysis_results +metric: COCO +num_classes: 80 + +PTQ: + quantizable_op_type: ["conv2d", "depthwise_conv2d"] + weight_quantize_type: 'abs_max' + activation_quantize_type: 'moving_average_abs_max' + is_full_quantize: False + batch_size: 10 + batch_nums: 10 + +# Datset configuration +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: /dataset/coco/ + +EvalDataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: /dataset/coco/ + +eval_height: &eval_height 416 +eval_width: &eval_width 416 +eval_size: &eval_size [*eval_height, *eval_width] + +worker_num: 0 + +EvalReader: + inputs_def: + image_shape: [1, 3, *eval_height, *eval_width] + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_size: 32 + + + diff --git a/example/post_training_quantization/detection/configs/picodet_s_ptq.yaml b/example/post_training_quantization/detection/configs/picodet_s_ptq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..005c0d46cdb10ae19a83b572e759a194513ba445 --- /dev/null +++ b/example/post_training_quantization/detection/configs/picodet_s_ptq.yaml @@ -0,0 +1,38 @@ +input_list: ['image', 'scale_factor'] +model_dir: ./picodet_s_416_coco_lcnet/ +model_filename: model.pdmodel +params_filename: model.pdiparams +skip_tensor_list: None + +metric: COCO +num_classes: 80 + +# Datset configuration +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: /dataset/coco/ + +EvalDataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: /dataset/coco/ + +eval_height: &eval_height 416 
+eval_width: &eval_width 416 +eval_size: &eval_size [*eval_height, *eval_width] + +worker_num: 0 + +EvalReader: + inputs_def: + image_shape: [1, 3, *eval_height, *eval_width] + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: *eval_size, keep_ratio: False} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_size: 32 + diff --git a/example/post_training_quantization/detection/eval.py b/example/post_training_quantization/detection/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0c09ae46c644fea8ca6218d0f0da3544d59161 --- /dev/null +++ b/example/post_training_quantization/detection/eval.py @@ -0,0 +1,166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import numpy as np +import argparse +import paddle +from ppdet.core.workspace import load_config, merge_config +from ppdet.core.workspace import create +from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval +from paddleslim.common import load_config as load_slim_config +from keypoint_utils import keypoint_post_process +from post_process import PPYOLOEPostProcess + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + '--config_path', + type=str, + default=None, + help="path of compression strategy config.", + required=True) + parser.add_argument( + '--devices', + type=str, + default='gpu', + help="which device used to compress.") + + return parser + + +def reader_wrapper(reader, input_list): + def gen(): + for data in reader: + in_dict = {} + if isinstance(input_list, list): + for input_name in input_list: + in_dict[input_name] = data[input_name] + elif isinstance(input_list, dict): + for input_name in input_list.keys(): + in_dict[input_list[input_name]] = data[input_name] + yield in_dict + + return gen + + +def convert_numpy_data(data, metric): + data_all = {} + data_all = {k: np.array(v) for k, v in data.items()} + if isinstance(metric, VOCMetric): + for k, v in data_all.items(): + if not isinstance(v[0], np.ndarray): + tmp_list = [] + for t in v: + tmp_list.append(np.array(t)) + data_all[k] = np.array(tmp_list) + else: + data_all = {k: np.array(v) for k, v in data.items()} + return data_all + + +def eval(): + + place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() + exe = paddle.static.Executor(place) + + val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( + global_config["model_dir"].rstrip('/'), + exe, + model_filename=global_config["model_filename"], + params_filename=global_config["params_filename"]) + print('Loaded model from: {}'.format(global_config["model_dir"])) + + metric = global_config['metric'] + for batch_id, data in enumerate(val_loader): + data_all = convert_numpy_data(data, metric) + data_input = {} + for k, v in data.items(): + if isinstance(global_config['input_list'], list): + if k in 
global_config['input_list']: + data_input[k] = np.array(v) + elif isinstance(global_config['input_list'], dict): + if k in global_config['input_list'].keys(): + data_input[global_config['input_list'][k]] = np.array(v) + + outs = exe.run(val_program, + feed=data_input, + fetch_list=fetch_targets, + return_numpy=False) + res = {} + if 'arch' in global_config and global_config['arch'] == 'keypoint': + res = keypoint_post_process(data, data_input, exe, val_program, + fetch_targets, outs) + if 'arch' in global_config and global_config['arch'] == 'PPYOLOE': + postprocess = PPYOLOEPostProcess( + score_threshold=0.01, nms_threshold=0.6) + res = postprocess(np.array(outs[0]), data_all['scale_factor']) + else: + for out in outs: + v = np.array(out) + if len(v.shape) > 1: + res['bbox'] = v + else: + res['bbox_num'] = v + metric.update(data_all, res) + if batch_id % 100 == 0: + print('Eval iter:', batch_id) + metric.accumulate() + metric.log() + metric.reset() + + +def main(): + global global_config + all_config = load_slim_config(FLAGS.config_path) + global_config = all_config["Global"] + reader_cfg = load_config(global_config['reader_config']) + + dataset = reader_cfg['EvalDataset'] + global val_loader + val_loader = create('EvalReader')(reader_cfg['EvalDataset'], + reader_cfg['worker_num'], + return_list=True) + metric = None + if reader_cfg['metric'] == 'COCO': + clsid2catid = {v: k for k, v in dataset.catid2clsid.items()} + anno_file = dataset.get_anno() + metric = COCOMetric( + anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox') + elif reader_cfg['metric'] == 'VOC': + metric = VOCMetric( + label_list=dataset.get_label_list(), + class_num=reader_cfg['num_classes'], + map_type=reader_cfg['map_type']) + elif reader_cfg['metric'] == 'KeyPointTopDownCOCOEval': + anno_file = dataset.get_anno() + metric = KeyPointTopDownCOCOEval(anno_file, + len(dataset), 17, 'output_eval') + else: + raise ValueError("metric currently only supports COCO and VOC.") + global_config['metric'] = metric + + eval() + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + assert FLAGS.devices in ['cpu', 'gpu', 'xpu', 'npu'] + paddle.set_device(FLAGS.devices) + + main() diff --git a/example/post_training_quantization/detection/keypoint_utils.py b/example/post_training_quantization/detection/keypoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d17095f45f5f7bdb54129656360845fcd91dc4b4 --- /dev/null +++ b/example/post_training_quantization/detection/keypoint_utils.py @@ -0,0 +1,307 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import numpy as np +import cv2 +import copy +from paddleslim.common import get_logger + +logger = get_logger(__name__, level=logging.INFO) + +__all__ = ['keypoint_post_process'] + + +def flip_back(output_flipped, matched_parts): + assert output_flipped.ndim == 4,\ + 'output_flipped should be [batch_size, num_joints, height, width]' + + output_flipped = output_flipped[:, :, :, ::-1] + + for pair in matched_parts: + tmp = output_flipped[:, pair[0], :, :].copy() + output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :] + output_flipped[:, pair[1], :, :] = tmp + + return output_flipped + + +def get_affine_transform(center, + input_size, + rot, + output_size, + shift=(0., 0.), + inv=False): + """Get the affine transform matrix, given the center/scale/rot/output_size. + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + input_size (np.ndarray[2, ]): Size of input feature (width, height). + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ]): Size of the destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + Returns: + np.ndarray: The transform matrix. + """ + assert len(center) == 2 + assert len(output_size) == 2 + assert len(shift) == 2 + + if not isinstance(input_size, (np.ndarray, list)): + input_size = np.array([input_size, input_size], dtype=np.float32) + scale_tmp = input_size + + shift = np.array(shift) + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = rotate_point([0., src_w * -0.5], rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + src = np.zeros((3, 2), dtype=np.float32) + + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def _get_3rd_point(a, b): + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + Args: + a (np.ndarray): point(x,y) + b (np.ndarray): point(x,y) + Returns: + np.ndarray: The 3rd point. + """ + assert len( + a) == 2, 'input of _get_3rd_point should be point with length of 2' + assert len( + b) == 2, 'input of _get_3rd_point should be point with length of 2' + direction = a - b + third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32) + + return third_pt + + +def rotate_point(pt, angle_rad): + """Rotate a point by an angle. + Args: + pt (list[float]): 2 dimensional point to be rotated + angle_rad (float): rotation angle by radian + Returns: + list[float]: Rotated point. 
+ """ + assert len(pt) == 2 + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + new_x = pt[0] * cs - pt[1] * sn + new_y = pt[0] * sn + pt[1] * cs + rotated_pt = [new_x, new_y] + + return rotated_pt + + +def affine_transform(pt, t): + new_pt = np.array([pt[0], pt[1], 1.]).T + new_pt = np.dot(t, new_pt) + return new_pt[:2] + + +def transform_preds(coords, center, scale, output_size): + target_coords = np.zeros(coords.shape) + trans = get_affine_transform(center, scale * 200, 0, output_size, inv=1) + for p in range(coords.shape[0]): + target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) + return target_coords + + +class HRNetPostProcess(object): + def __init__(self, use_dark=True): + self.use_dark = use_dark + + def get_max_preds(self, heatmaps): + '''get predictions from score maps + Args: + heatmaps: numpy.ndarray([batch_size, num_joints, height, width]) + Returns: + preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords + maxvals: numpy.ndarray([batch_size, num_joints, 2]), the maximum confidence of the keypoints + ''' + assert isinstance(heatmaps, + np.ndarray), 'heatmaps should be numpy.ndarray' + assert heatmaps.ndim == 4, 'batch_images should be 4-ndim' + + batch_size = heatmaps.shape[0] + num_joints = heatmaps.shape[1] + width = heatmaps.shape[3] + heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1)) + idx = np.argmax(heatmaps_reshaped, 2) + maxvals = np.amax(heatmaps_reshaped, 2) + + maxvals = maxvals.reshape((batch_size, num_joints, 1)) + idx = idx.reshape((batch_size, num_joints, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + + preds[:, :, 0] = (preds[:, :, 0]) % width + preds[:, :, 1] = np.floor((preds[:, :, 1]) / width) + + pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)) + pred_mask = pred_mask.astype(np.float32) + + preds *= pred_mask + + return preds, maxvals + + def gaussian_blur(self, heatmap, kernel): + border = (kernel - 1) // 2 + batch_size = heatmap.shape[0] + num_joints = heatmap.shape[1] + height = heatmap.shape[2] + width = heatmap.shape[3] + for i in range(batch_size): + for j in range(num_joints): + origin_max = np.max(heatmap[i, j]) + dr = np.zeros((height + 2 * border, width + 2 * border)) + dr[border:-border, border:-border] = heatmap[i, j].copy() + dr = cv2.GaussianBlur(dr, (kernel, kernel), 0) + heatmap[i, j] = dr[border:-border, border:-border].copy() + heatmap[i, j] *= origin_max / np.max(heatmap[i, j]) + return heatmap + + def dark_parse(self, hm, coord): + heatmap_height = hm.shape[0] + heatmap_width = hm.shape[1] + px = int(coord[0]) + py = int(coord[1]) + if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2: + dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1]) + dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px]) + dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2]) + dxy = 0.25 * (hm[py+1][px+1] - hm[py-1][px+1] - hm[py+1][px-1] \ + + hm[py-1][px-1]) + dyy = 0.25 * ( + hm[py + 2 * 1][px] - 2 * hm[py][px] + hm[py - 2 * 1][px]) + derivative = np.matrix([[dx], [dy]]) + hessian = np.matrix([[dxx, dxy], [dxy, dyy]]) + if dxx * dyy - dxy**2 != 0: + hessianinv = hessian.I + offset = -hessianinv * derivative + offset = np.squeeze(np.array(offset.T), axis=0) + coord += offset + return coord + + def dark_postprocess(self, hm, coords, kernelsize): + ''' + DARK postpocessing, Zhang et al. Distribution-Aware Coordinate + Representation for Human Pose Estimation (CVPR 2020). 
+ ''' + hm = self.gaussian_blur(hm, kernelsize) + hm = np.maximum(hm, 1e-10) + hm = np.log(hm) + for n in range(coords.shape[0]): + for p in range(coords.shape[1]): + coords[n, p] = self.dark_parse(hm[n][p], coords[n][p]) + return coords + + def get_final_preds(self, heatmaps, center, scale, kernelsize=3): + """ + The highest heatvalue location with a quarter offset in the + direction from the highest response to the second highest response. + Args: + heatmaps (numpy.ndarray): The predicted heatmaps + center (numpy.ndarray): The boxes center + scale (numpy.ndarray): The scale factor + Returns: + preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords + maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints + """ + coords, maxvals = self.get_max_preds(heatmaps) + + heatmap_height = heatmaps.shape[2] + heatmap_width = heatmaps.shape[3] + + if self.use_dark: + coords = self.dark_postprocess(heatmaps, coords, kernelsize) + else: + for n in range(coords.shape[0]): + for p in range(coords.shape[1]): + hm = heatmaps[n][p] + px = int(math.floor(coords[n][p][0] + 0.5)) + py = int(math.floor(coords[n][p][1] + 0.5)) + if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1: + diff = np.array([ + hm[py][px + 1] - hm[py][px - 1], + hm[py + 1][px] - hm[py - 1][px] + ]) + coords[n][p] += np.sign(diff) * .25 + preds = coords.copy() + + # Transform back + for i in range(coords.shape[0]): + preds[i] = transform_preds(coords[i], center[i], scale[i], + [heatmap_width, heatmap_height]) + + return preds, maxvals + + def __call__(self, output, center, scale): + preds, maxvals = self.get_final_preds(np.array(output), center, scale) + outputs = [[ + np.concatenate( + (preds, maxvals), axis=-1), np.mean( + maxvals, axis=1) + ]] + return outputs + + +def keypoint_post_process(data, data_input, exe, val_program, fetch_targets, + outs): + data_input['image'] = np.flip(data_input['image'], [3]) + output_flipped = exe.run(val_program, + feed=data_input, + fetch_list=fetch_targets, + return_numpy=False) + + output_flipped = np.array(output_flipped[0]) + flip_perm = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], + [15, 16]] + output_flipped = flip_back(output_flipped, flip_perm) + output_flipped[:, :, :, 1:] = copy.copy(output_flipped)[:, :, :, 0:-1] + hrnet_outputs = (np.array(outs[0]) + output_flipped) * 0.5 + imshape = ( + np.array(data['im_shape']))[:, ::-1] if 'im_shape' in data else None + center = np.array(data['center']) if 'center' in data else np.round( + imshape / 2.) + scale = np.array(data['scale']) if 'scale' in data else imshape / 200. + post_process = HRNetPostProcess() + outputs = post_process(hrnet_outputs, center, scale) + return {'keypoint': outputs} diff --git a/example/post_training_quantization/detection/post_process.py b/example/post_training_quantization/detection/post_process.py new file mode 100644 index 0000000000000000000000000000000000000000..eea2f019548ec288a23e37b3bd2faf24f9a98935 --- /dev/null +++ b/example/post_training_quantization/detection/post_process.py @@ -0,0 +1,157 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import cv2 + + +def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200): + """ + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + iou_threshold: intersection over union threshold. + top_k: keep top_k results. If k <= 0, keep all the results. + candidate_size: only consider the candidates with the highest scores. + Returns: + picked: a list of indexes of the kept boxes + """ + scores = box_scores[:, -1] + boxes = box_scores[:, :-1] + picked = [] + indexes = np.argsort(scores) + indexes = indexes[-candidate_size:] + while len(indexes) > 0: + current = indexes[-1] + picked.append(current) + if 0 < top_k == len(picked) or len(indexes) == 1: + break + current_box = boxes[current, :] + indexes = indexes[:-1] + rest_boxes = boxes[indexes, :] + iou = iou_of( + rest_boxes, + np.expand_dims( + current_box, axis=0), ) + indexes = indexes[iou <= iou_threshold] + + return box_scores[picked, :] + + +def iou_of(boxes0, boxes1, eps=1e-5): + """Return intersection-over-union (Jaccard index) of boxes. + Args: + boxes0 (N, 4): ground truth boxes. + boxes1 (N or 1, 4): predicted boxes. + eps: a small number to avoid 0 as denominator. + Returns: + iou (N): IoU values. + """ + overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2]) + overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:]) + + overlap_area = area_of(overlap_left_top, overlap_right_bottom) + area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) + area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) + return overlap_area / (area0 + area1 - overlap_area + eps) + + +def area_of(left_top, right_bottom): + """Compute the areas of rectangles given two corners. + Args: + left_top (N, 2): left top corner. + right_bottom (N, 2): right bottom corner. + Returns: + area (N): return the area. 
+ """ + hw = np.clip(right_bottom - left_top, 0.0, None) + return hw[..., 0] * hw[..., 1] + + +class PPYOLOEPostProcess(object): + """ + Args: + input_shape (int): network input image size + scale_factor (float): scale factor of ori image + """ + + def __init__(self, + score_threshold=0.4, + nms_threshold=0.5, + nms_top_k=10000, + keep_top_k=300): + self.score_threshold = score_threshold + self.nms_threshold = nms_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + + def _non_max_suppression(self, prediction, scale_factor): + batch_size = prediction.shape[0] + out_boxes_list = [] + box_num_list = [] + for batch_id in range(batch_size): + bboxes, confidences = prediction[batch_id][..., :4], prediction[ + batch_id][..., 4:] + # nms + picked_box_probs = [] + picked_labels = [] + for class_index in range(0, confidences.shape[1]): + probs = confidences[:, class_index] + mask = probs > self.score_threshold + probs = probs[mask] + if probs.shape[0] == 0: + continue + subset_boxes = bboxes[mask, :] + box_probs = np.concatenate( + [subset_boxes, probs.reshape(-1, 1)], axis=1) + box_probs = hard_nms( + box_probs, + iou_threshold=self.nms_threshold, + top_k=self.nms_top_k) + picked_box_probs.append(box_probs) + picked_labels.extend([class_index] * box_probs.shape[0]) + + if len(picked_box_probs) == 0: + out_boxes_list.append(np.empty((0, 4))) + + else: + picked_box_probs = np.concatenate(picked_box_probs) + # resize output boxes + picked_box_probs[:, 0] /= scale_factor[batch_id][1] + picked_box_probs[:, 2] /= scale_factor[batch_id][1] + picked_box_probs[:, 1] /= scale_factor[batch_id][0] + picked_box_probs[:, 3] /= scale_factor[batch_id][0] + + # clas score box + out_box = np.concatenate( + [ + np.expand_dims( + np.array(picked_labels), axis=-1), np.expand_dims( + picked_box_probs[:, 4], axis=-1), + picked_box_probs[:, :4] + ], + axis=1) + if out_box.shape[0] > self.keep_top_k: + out_box = out_box[out_box[:, 1].argsort()[::-1] + [:self.keep_top_k]] + out_boxes_list.append(out_box) + box_num_list.append(out_box.shape[0]) + + out_boxes_list = np.concatenate(out_boxes_list, axis=0) + box_num_list = np.array(box_num_list) + return out_boxes_list, box_num_list + + def __call__(self, outs, scale_factor): + out_boxes_list, box_num_list = self._non_max_suppression(outs, + scale_factor) + return {'bbox': out_boxes_list, 'bbox_num': box_num_list} diff --git a/example/auto_compression/detection/post_quant.py b/example/post_training_quantization/detection/post_quant.py similarity index 78% rename from example/auto_compression/detection/post_quant.py rename to example/post_training_quantization/detection/post_quant.py index edc7d2fea66dfb16e51b8ad16a5e61b75294b895..a0c010364dd1b47ce33131814fd95942da7d96b0 100644 --- a/example/auto_compression/detection/post_quant.py +++ b/example/post_training_quantization/detection/post_quant.py @@ -19,7 +19,6 @@ import argparse import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import create -from paddleslim.common import load_config as load_slim_config from paddleslim.quant import quant_post_static @@ -63,33 +62,32 @@ def reader_wrapper(reader, input_list): def main(): - global global_config - all_config = load_slim_config(FLAGS.config_path) - assert "Global" in all_config, f"Key 'Global' not found in config file. 
\n{all_config}" - global_config = all_config["Global"] - reader_cfg = load_config(global_config['reader_config']) + global config + config = load_config(FLAGS.config_path) - train_loader = create('EvalReader')(reader_cfg['TrainDataset'], - reader_cfg['worker_num'], + train_loader = create('EvalReader')(config['TrainDataset'], + config['worker_num'], return_list=True) - train_loader = reader_wrapper(train_loader, global_config['input_list']) + train_loader = reader_wrapper(train_loader, config['input_list']) place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() exe = paddle.static.Executor(place) quant_post_static( executor=exe, - model_dir=global_config["model_dir"], + model_dir=config["model_dir"], quantize_model_path=FLAGS.save_dir, data_loader=train_loader, - model_filename=global_config["model_filename"], - params_filename=global_config["params_filename"], + model_filename=config["model_filename"], + params_filename=config["params_filename"], batch_size=4, batch_nums=64, algo=FLAGS.algo, hist_percent=0.999, is_full_quantize=False, bias_correction=False, - onnx_format=False) + onnx_format=False, + skip_tensor_list=config['skip_tensor_list'] + if 'skip_tensor_list' in config else None) if __name__ == '__main__': diff --git a/example/post_training_quantization/pytorch_yolo_series/analysis.py b/example/post_training_quantization/pytorch_yolo_series/analysis.py index d9e5629fe5a5581fbbadf9f62cd599319f1a7d3a..1c8cbcb5b9a969bb485ca2fdaf2610d7222d2ad2 100644 --- a/example/post_training_quantization/pytorch_yolo_series/analysis.py +++ b/example/post_training_quantization/pytorch_yolo_series/analysis.py @@ -68,6 +68,7 @@ def main(): global config config = load_config(FLAGS.config_path) + ptq_config = config['PTQ'] input_name = 'x2paddle_image_arrays' if config[ 'arch'] == 'YOLOv6' else 'x2paddle_images' @@ -97,13 +98,9 @@ def main(): model_filename='model.pdmodel', params_filename='model.pdiparams', eval_function=eval_function, - quantizable_op_type=config['quantizable_op_type'], - weight_quantize_type=config['weight_quantize_type'], - activation_quantize_type=config['activation_quantize_type'], - is_full_quantize=config['is_full_quantize'], data_loader=data_loader, - batch_size=config['batch_size'], - save_dir=config['save_dir'], ) + save_dir=config['save_dir'], + ptq_config=ptq_config) analyzer.analysis() diff --git a/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analysis.yaml b/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analysis.yaml index 7e6d9785cb0043848cc57935f59dd07202218baa..a99198a444cfccf90bde0f874fc90edb1a75b92e 100644 --- a/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analysis.yaml +++ b/example/post_training_quantization/pytorch_yolo_series/configs/yolov6s_analysis.yaml @@ -1,11 +1,15 @@ arch: YOLOv6 model_dir: ./yolov6s.onnx save_dir: ./analysis_results -quantizable_op_type: ["conv2d", "depthwise_conv2d"] -weight_quantize_type: 'abs_max' -activation_quantize_type: 'moving_average_abs_max' -is_full_quantize: False dataset_dir: /dataset/coco/ val_image_dir: val2017 val_anno_path: annotations/instances_val2017.json -batch_size: 10 + +PTQ: + quantizable_op_type: ["conv2d", "depthwise_conv2d"] + weight_quantize_type: 'abs_max' + activation_quantize_type: 'moving_average_abs_max' + is_full_quantize: False + batch_size: 10 + batch_nums: 10 + diff --git a/paddleslim/quant/analysis.py b/paddleslim/quant/analysis.py index 
848e183f4d6eecce9089c95e557bae4549635db8..d81e80f8eb52fa2ccbb8549a7de4a04b8a91af13 100644 --- a/paddleslim/quant/analysis.py +++ b/paddleslim/quant/analysis.py @@ -39,22 +39,16 @@ __all__ = ["AnalysisQuant"] class AnalysisQuant(object): - def __init__( - self, - model_dir, - model_filename=None, - params_filename=None, - eval_function=None, - data_loader=None, - save_dir='analysis_results', - checkpoint_name='analysis_checkpoint.pkl', - num_histogram_plots=10, - quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], - weight_quantize_type='abs_max', - activation_quantize_type='moving_average_abs_max', - is_full_quantize=False, - batch_size=10, - batch_nums=10, ): + def __init__(self, + model_dir, + model_filename=None, + params_filename=None, + eval_function=None, + data_loader=None, + save_dir='analysis_results', + checkpoint_name='analysis_checkpoint.pkl', + num_histogram_plots=10, + ptq_config=None): """ AnalysisQuant provides to analysis the sensitivity of each op in the model. @@ -68,13 +62,8 @@ class AnalysisQuant(object): return a batch every time save_dir(str, optional): the output dir that stores the analyzed information checkpoint_name(str, optional): the name of checkpoint file that saves analyzed information and avoids break off while ananlyzing - num_histogram_plots: the number histogram plots you want to visilize, the plots will show in four PDF files for both best and worst and for both weight and act ops in the save_dir - quantizable_op_type(list): op types that can be quantized - weight_quantize_type(str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max' - activation_quantize_type(str): quantization type for activation, now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max' - is_full_quantize(bool): if True, apply quantization to all supported quantizable op type. If False, only apply quantization to the input quantizable_op_type. Default is False. 
- batch_size(int, optional): the batch size of DataLoader, default is 10 - batch_nums(int, optional): the number of calibrate data is 'batch_size*batch_nums' + ptq_config(dict, optional): the args that can initialize PostTrainingQuantization + """ if model_filename is None: model_filename = 'model.pdmodel' @@ -83,20 +72,16 @@ class AnalysisQuant(object): self.model_dir = model_dir self.model_filename = model_filename self.params_filename = params_filename - self.batch_nums = batch_nums - self.quantizable_op_type = quantizable_op_type - self.weight_quantize_type = weight_quantize_type - self.activation_quantize_type = activation_quantize_type - self.is_full_quantize = is_full_quantize self.histogram_bins = 1000 self.save_dir = save_dir self.eval_function = eval_function self.quant_layer_names = [] self.checkpoint_name = os.path.join(save_dir, checkpoint_name) self.quant_layer_metrics = {} - self.batch_size = batch_size - self.batch_nums = batch_nums self.num_histogram_plots = num_histogram_plots + self.ptq_config = ptq_config + self.batch_nums = ptq_config[ + 'batch_nums'] if 'batch_nums' in ptq_config else 10 if not os.path.exists(self.save_dir): os.mkdir(self.save_dir) @@ -130,14 +115,9 @@ class AnalysisQuant(object): model_dir=self.model_dir, model_filename=self.model_filename, params_filename=self.params_filename, - batch_size=self.batch_size, - batch_nums=self.batch_nums, - algo='avg', # fastest - quantizable_op_type=self.quantizable_op_type, - weight_quantize_type=self.weight_quantize_type, - activation_quantize_type=self.activation_quantize_type, - is_full_quantize=self.is_full_quantize, - skip_tensor_list=None, ) + skip_tensor_list=None, + algo='avg', #fastest + **self.ptq_config) program = post_training_quantization.quantize() self.quant_metric = self.eval_function(executor, program, self.feed_list, self.fetch_list) @@ -208,14 +188,9 @@ class AnalysisQuant(object): model_dir=self.model_dir, model_filename=self.model_filename, params_filename=self.params_filename, - batch_size=self.batch_size, - batch_nums=self.batch_nums, - algo='avg', # fastest - quantizable_op_type=self.quantizable_op_type, - weight_quantize_type=self.weight_quantize_type, - activation_quantize_type=self.activation_quantize_type, - is_full_quantize=self.is_full_quantize, - skip_tensor_list=skip_list, ) + skip_tensor_list=skip_list, + algo='avg', #fastest + **self.ptq_config) program = post_training_quantization.quantize() _logger.info('Evaluating...')