未验证 提交 7ca2aa68 编写于 作者: C Chang Xu 提交者: GitHub

Add Picodet Analysis Demo & Simplify Analysis Args (#1356)

上级 fd85f9db
...@@ -15,13 +15,7 @@ data_loader: None ...@@ -15,13 +15,7 @@ data_loader: None
save_dir: 'analysis_results' save_dir: 'analysis_results'
checkpoint_name: 'analysis_checkpoint.pkl' checkpoint_name: 'analysis_checkpoint.pkl'
num_histogram_plots: 10 num_histogram_plots: 10
ptq_config
quantizable_op_type: ["conv2d", "depthwise_conv2d", "mul"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
batch_size: 10
batch_nums: 10
``` ```
- model_dir: 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可。 - model_dir: 必须传入的模型文件路径,可为文件夹名;若模型为ONNX类型,直接输入'.onnx'模型文件名称即可。
- model_filename: 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入。 - model_filename: 默认为None,若model_dir为文件夹名,则必须传入以'.pdmodel'结尾的模型名称,若model_dir为'.onnx'模型文件名称,则不需要传入。
...@@ -31,18 +25,7 @@ batch_nums: 10 ...@@ -31,18 +25,7 @@ batch_nums: 10
- save_dir:分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results` - save_dir:分析后保存模型精度或pdf等文件的文件夹,默认为`analysis_results`
- checkpoint_name:由于模型可能存在大量层需要分析,因此分析过程中会中间保存结果,如果程序中断会自动加载已经分析好的结果,默认为`analysis_checkpoint.pkl` - checkpoint_name:由于模型可能存在大量层需要分析,因此分析过程中会中间保存结果,如果程序中断会自动加载已经分析好的结果,默认为`analysis_checkpoint.pkl`
- num_histogram_plots:需要可视化的直方分布图数量。可视化量化效果最好和最坏的该数量个权重和激活的分布图。默认为10。若不需要可视化直方图,设置为0即可。 - num_histogram_plots:需要可视化的直方分布图数量。可视化量化效果最好和最坏的该数量个权重和激活的分布图。默认为10。若不需要可视化直方图,设置为0即可。
- ptq_config:可传入的离线量化中的参数,详细可参考[离线量化文档](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/quant/quant_post)
注:以下参数均为需要传入离线量化中的参数,保持默认不影响模型进行量化分析。
- quantizable_op_type:需要进行量化的OP类型。通过以下代码可输出所有支持量化的OP类型:
```
from paddleslim.quant.quanter import TRANSFORM_PASS_OP_TYPES,QUANT_DEQUANT_PASS_OP_TYPES
print(TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
```
- weight_quantize_type:参数量化方式。可选 'abs_max' , 'channel_wise_abs_max' , 'range_abs_max' , 'moving_average_abs_max' 。 默认 'abs_max' 。
- activation_quantize_type:激活量化方式,可选 'abs_max' , 'range_abs_max' , 'moving_average_abs_max' 。默认为 'moving_average_abs_max'。
- is_full_quantize:是否对模型进行全量化,默认为False。
- batch_size:模型校准使用的batch size大小,默认为10。
- batch_nums:模型校准时的总batch数量,默认为10。
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
from tqdm import tqdm
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from keypoint_utils import keypoint_post_process
from post_process import PPYOLOEPostProcess
from paddleslim.quant.analysis import AnalysisQuant
def argsparser():
    """Build the command-line parser for the quantization analysis script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--config_path`` (required)
        and ``--devices`` (defaults to ``'gpu'``).
    """
    cli = argparse.ArgumentParser(description=__doc__)
    option_specs = (
        ('--config_path', {
            'type': str,
            'default': None,
            'help': "path of analysis config.",
            'required': True,
        }),
        ('--devices', {
            'type': str,
            'default': 'gpu',
            'help': "which device used to compress.",
        }),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli
def reader_wrapper(reader, input_list):
    """Wrap ``reader`` so that each batch only carries the model inputs.

    Args:
        reader: iterable yielding mapping-like batches.
        input_list (list | dict): a list of keys to keep as-is, or a dict
            mapping reader keys to the feed names they should be renamed to.

    Returns:
        A zero-argument generator function yielding one feed dict per batch.
        When ``input_list`` is neither a list nor a dict the feed dict is
        empty.
    """

    def gen():
        for sample in reader:
            if isinstance(input_list, list):
                feed = {name: sample[name] for name in input_list}
            elif isinstance(input_list, dict):
                feed = {
                    feed_name: sample[reader_key]
                    for reader_key, feed_name in input_list.items()
                }
            else:
                feed = {}
            yield feed

    return gen
def convert_numpy_data(data, metric):
    """Convert every value of a batch dict to a numpy array.

    Args:
        data (dict): batch produced by the data loader.
        metric: metric object; only its type is inspected.

    Returns:
        dict: same keys as ``data`` with ``np.array`` values. For a
        ``VOCMetric``, any value whose first element is not itself an
        ``np.ndarray`` is rebuilt element by element.
    """
    converted = {key: np.array(value) for key, value in data.items()}
    if not isinstance(metric, VOCMetric):
        return converted

    # Only existing keys are reassigned, so iterating the dict stays valid.
    for key, value in converted.items():
        if not isinstance(value[0], np.ndarray):
            converted[key] = np.array([np.array(item) for item in value])
    return converted
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    """Evaluate a program on the validation set and return its headline metric.

    Reads the module-level ``val_loader``, ``metric`` and ``config`` globals
    that ``main`` sets up.

    Args:
        exe: executor used to run the program.
        compiled_test_program: program to evaluate.
        test_feed_names: feed names accepted by the program; used to filter
            batch keys when ``config['input_list']`` is a list.
        test_fetch_list: fetch targets passed to ``exe.run``.

    Returns:
        The first entry of the metric results stored under ``'keypoint'``
        (keypoint arch) or ``'bbox'`` (everything else).
    """
    arch = config['arch'] if 'arch' in config else None
    input_list = config['input_list']

    with tqdm(
            total=len(val_loader),
            bar_format='Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
            ncols=80) as t:
        for data in val_loader:
            data_all = convert_numpy_data(data, metric)

            # Keep only the tensors the program consumes, renaming them when
            # input_list is a {reader_key: feed_name} mapping.
            data_input = {}
            for k, v in data.items():
                if isinstance(input_list, list):
                    if k in test_feed_names:
                        data_input[k] = np.array(v)
                elif isinstance(input_list, dict):
                    if k in input_list:
                        data_input[input_list[k]] = np.array(v)

            outs = exe.run(compiled_test_program,
                           feed=data_input,
                           fetch_list=test_fetch_list,
                           return_numpy=False)

            # The three cases are mutually exclusive. The previous
            # `if / if / else` chain also ran the generic bbox branch for
            # keypoint models and added stray 'bbox' entries to their result.
            if arch == 'keypoint':
                res = keypoint_post_process(data, data_input, exe,
                                            compiled_test_program,
                                            test_fetch_list, outs)
            elif arch == 'PPYOLOE':
                postprocess = PPYOLOEPostProcess(
                    score_threshold=0.01, nms_threshold=0.6)
                res = postprocess(np.array(outs[0]), data_all['scale_factor'])
            else:
                res = {}
                for out in outs:
                    v = np.array(out)
                    # Outputs with more than one dimension are treated as
                    # boxes, 1-D outputs as per-image box counts.
                    if len(v.shape) > 1:
                        res['bbox'] = v
                    else:
                        res['bbox_num'] = v

            metric.update(data_all, res)
            t.update()

    metric.accumulate()
    metric.log()
    map_res = metric.get_results()
    # Reset so the shared metric object can be reused by the next call.
    metric.reset()
    map_key = 'keypoint' if arch == 'keypoint' else 'bbox'
    return map_res[map_key][0]
def main():
    """Load the config, build the readers and metric, and run the analysis.

    Publishes ``config``, ``val_loader`` and ``metric`` as module globals
    because ``eval_function`` reads them.

    Raises:
        ValueError: if ``config['metric']`` is not one of the supported
            metric names.
    """
    global config, val_loader, metric

    config = load_config(FLAGS.config_path)
    ptq_config = config['PTQ']

    dataset = config['EvalDataset']

    # Reader handed to AnalysisQuant as ``data_loader``.
    data_loader = create('EvalReader')(dataset,
                                       config['worker_num'],
                                       return_list=True)
    data_loader = reader_wrapper(data_loader, config['input_list'])

    # Separate loader with an explicit batch sampler, used for evaluation.
    _eval_batch_sampler = paddle.io.BatchSampler(
        dataset, batch_size=config['EvalReader']['batch_size'])
    val_loader = create('EvalReader')(dataset,
                                      config['worker_num'],
                                      batch_sampler=_eval_batch_sampler,
                                      return_list=True)

    metric_name = config['metric']
    if metric_name == 'COCO':
        clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
        anno_file = dataset.get_anno()
        metric = COCOMetric(
            anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
    elif metric_name == 'VOC':
        metric = VOCMetric(
            label_list=dataset.get_label_list(),
            class_num=config['num_classes'],
            map_type=config['map_type'])
    elif metric_name == 'KeyPointTopDownCOCOEval':
        anno_file = dataset.get_anno()
        metric = KeyPointTopDownCOCOEval(anno_file,
                                         len(dataset), 17, 'output_eval')
    else:
        # The message lists every branch handled above; the old text
        # omitted KeyPointTopDownCOCOEval.
        raise ValueError(
            "metric currently only supports COCO, VOC and "
            "KeyPointTopDownCOCOEval, but got {!r}.".format(metric_name))

    analyzer = AnalysisQuant(
        model_dir=config["model_dir"],
        model_filename=config["model_filename"],
        params_filename=config["params_filename"],
        eval_function=eval_function,
        data_loader=data_loader,
        save_dir=config['save_dir'],
        ptq_config=ptq_config)
    analyzer.analysis()
if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    # FLAGS stays a module-level name because main() reads it as a global.
    FLAGS = parser.parse_args()
    supported_devices = ['cpu', 'gpu', 'xpu', 'npu']
    if FLAGS.devices not in supported_devices:
        # Explicit check instead of `assert`, which is removed under
        # `python -O` and would let an invalid device through.
        parser.error("--devices must be one of {}, got {!r}".format(
            supported_devices, FLAGS.devices))
    paddle.set_device(FLAGS.devices)
    main()
input_list: ['image', 'scale_factor']
model_dir: ./picodet_s_416_coco_lcnet/
model_filename: model.pdmodel
params_filename: model.pdiparams
save_dir: ./analysis_results
metric: COCO
num_classes: 80
PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
batch_size: 10
batch_nums: 10
# Dataset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: /dataset/coco/
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: /dataset/coco/
eval_height: &eval_height 416
eval_width: &eval_width 416
eval_size: &eval_size [*eval_height, *eval_width]
worker_num: 0
EvalReader:
inputs_def:
image_shape: [1, 3, *eval_height, *eval_width]
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_size: 32
input_list: ['image', 'scale_factor']
model_dir: ./picodet_s_416_coco_lcnet/
model_filename: model.pdmodel
params_filename: model.pdiparams
skip_tensor_list: None
metric: COCO
num_classes: 80
# Dataset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: /dataset/coco/
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: /dataset/coco/
eval_height: &eval_height 416
eval_width: &eval_width 416
eval_size: &eval_size [*eval_height, *eval_width]
worker_num: 0
EvalReader:
inputs_def:
image_shape: [1, 3, *eval_height, *eval_width]
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: *eval_size, keep_ratio: False}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_size: 32
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import argparse
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.common import load_config as load_slim_config
from keypoint_utils import keypoint_post_process
from post_process import PPYOLOEPostProcess
def argsparser():
    """Build the command-line parser for the evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--config_path`` (required)
        and ``--devices`` (defaults to ``'gpu'``).
    """
    cli = argparse.ArgumentParser(description=__doc__)
    option_specs = (
        ('--config_path', {
            'type': str,
            'default': None,
            'help': "path of compression strategy config.",
            'required': True,
        }),
        ('--devices', {
            'type': str,
            'default': 'gpu',
            'help': "which device used to compress.",
        }),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli
def reader_wrapper(reader, input_list):
    """Wrap ``reader`` so that each batch only carries the model inputs.

    Args:
        reader: iterable yielding mapping-like batches.
        input_list (list | dict): a list of keys to keep as-is, or a dict
            mapping reader keys to the feed names they should be renamed to.

    Returns:
        A zero-argument generator function yielding one feed dict per batch.
        When ``input_list`` is neither a list nor a dict the feed dict is
        empty.
    """

    def gen():
        for sample in reader:
            if isinstance(input_list, list):
                feed = {name: sample[name] for name in input_list}
            elif isinstance(input_list, dict):
                feed = {
                    feed_name: sample[reader_key]
                    for reader_key, feed_name in input_list.items()
                }
            else:
                feed = {}
            yield feed

    return gen
def convert_numpy_data(data, metric):
    """Convert every value of a batch dict to a numpy array.

    Args:
        data (dict): batch produced by the data loader.
        metric: metric object; only its type is inspected.

    Returns:
        dict: same keys as ``data`` with ``np.array`` values. For a
        ``VOCMetric``, any value whose first element is not itself an
        ``np.ndarray`` is rebuilt element by element.
    """
    converted = {key: np.array(value) for key, value in data.items()}
    if not isinstance(metric, VOCMetric):
        return converted

    # Only existing keys are reassigned, so iterating the dict stays valid.
    for key, value in converted.items():
        if not isinstance(value[0], np.ndarray):
            converted[key] = np.array([np.array(item) for item in value])
    return converted
def eval():
    """Load the inference model and evaluate it on ``val_loader``.

    Reads the module-level ``FLAGS``, ``global_config`` and ``val_loader``;
    the metric object is taken from ``global_config['metric']``. Results are
    reported through ``metric.log()``; nothing is returned.

    NOTE(review): the name shadows the ``eval`` builtin; it is kept because
    ``main`` calls it by this name.
    """
    place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
    exe = paddle.static.Executor(place)

    val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
        global_config["model_dir"].rstrip('/'),
        exe,
        model_filename=global_config["model_filename"],
        params_filename=global_config["params_filename"])
    print('Loaded model from: {}'.format(global_config["model_dir"]))

    metric = global_config['metric']
    arch = global_config['arch'] if 'arch' in global_config else None
    input_list = global_config['input_list']

    for batch_id, data in enumerate(val_loader):
        data_all = convert_numpy_data(data, metric)

        # Keep only the tensors listed in input_list, renaming them when it
        # is a {reader_key: feed_name} mapping.
        data_input = {}
        for k, v in data.items():
            if isinstance(input_list, list):
                if k in input_list:
                    data_input[k] = np.array(v)
            elif isinstance(input_list, dict):
                if k in input_list:
                    data_input[input_list[k]] = np.array(v)

        outs = exe.run(val_program,
                       feed=data_input,
                       fetch_list=fetch_targets,
                       return_numpy=False)

        # The three cases are mutually exclusive. The previous
        # `if / if / else` chain also ran the generic bbox branch for
        # keypoint models and added stray 'bbox' entries to their result.
        if arch == 'keypoint':
            res = keypoint_post_process(data, data_input, exe, val_program,
                                        fetch_targets, outs)
        elif arch == 'PPYOLOE':
            postprocess = PPYOLOEPostProcess(
                score_threshold=0.01, nms_threshold=0.6)
            res = postprocess(np.array(outs[0]), data_all['scale_factor'])
        else:
            res = {}
            for out in outs:
                v = np.array(out)
                # Outputs with more than one dimension are treated as boxes,
                # 1-D outputs as per-image box counts.
                if len(v.shape) > 1:
                    res['bbox'] = v
                else:
                    res['bbox_num'] = v

        metric.update(data_all, res)
        if batch_id % 100 == 0:
            print('Eval iter:', batch_id)

    metric.accumulate()
    metric.log()
    metric.reset()
def main():
    """Load the configs, build the validation loader and metric, then evaluate.

    Publishes ``global_config`` and ``val_loader`` as module globals because
    ``eval`` reads them; the metric object is stored in
    ``global_config['metric']``.

    Raises:
        ValueError: if the reader config's ``metric`` is not one of the
            supported metric names.
    """
    global global_config, val_loader

    all_config = load_slim_config(FLAGS.config_path)
    global_config = all_config["Global"]
    reader_cfg = load_config(global_config['reader_config'])

    dataset = reader_cfg['EvalDataset']
    val_loader = create('EvalReader')(reader_cfg['EvalDataset'],
                                      reader_cfg['worker_num'],
                                      return_list=True)

    metric_name = reader_cfg['metric']
    if metric_name == 'COCO':
        clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
        anno_file = dataset.get_anno()
        metric = COCOMetric(
            anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
    elif metric_name == 'VOC':
        metric = VOCMetric(
            label_list=dataset.get_label_list(),
            class_num=reader_cfg['num_classes'],
            map_type=reader_cfg['map_type'])
    elif metric_name == 'KeyPointTopDownCOCOEval':
        anno_file = dataset.get_anno()
        metric = KeyPointTopDownCOCOEval(anno_file,
                                         len(dataset), 17, 'output_eval')
    else:
        # The message lists every branch handled above; the old text
        # omitted KeyPointTopDownCOCOEval.
        raise ValueError(
            "metric currently only supports COCO, VOC and "
            "KeyPointTopDownCOCOEval, but got {!r}.".format(metric_name))

    global_config['metric'] = metric
    eval()
if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    # FLAGS stays a module-level name because main() and eval() read it as a
    # global.
    FLAGS = parser.parse_args()
    supported_devices = ['cpu', 'gpu', 'xpu', 'npu']
    if FLAGS.devices not in supported_devices:
        # Explicit check instead of `assert`, which is removed under
        # `python -O` and would let an invalid device through.
        parser.error("--devices must be one of {}, got {!r}".format(
            supported_devices, FLAGS.devices))
    paddle.set_device(FLAGS.devices)
    main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import cv2
import copy
from paddleslim.common import get_logger
logger = get_logger(__name__, level=logging.INFO)
__all__ = ['keypoint_post_process']
def flip_back(output_flipped, matched_parts):
    """Undo a horizontal flip on heatmaps predicted from a flipped image.

    The last axis is reversed and the channels of every pair in
    ``matched_parts`` are exchanged. The result is a reversed view of the
    input, so the channel exchange is also visible through the input array.

    Args:
        output_flipped (np.ndarray): 4-D array
            [batch_size, num_joints, height, width].
        matched_parts: iterable of index pairs whose channels are swapped.

    Returns:
        np.ndarray: the flipped-back heatmaps.
    """
    assert output_flipped.ndim == 4,\
        'output_flipped should be [batch_size, num_joints, height, width]'

    mirrored = output_flipped[:, :, :, ::-1]
    for pair in matched_parts:
        first, second = pair[0], pair[1]
        # Snapshot one channel so the exchange does not read overwritten data.
        saved = mirrored[:, first, :, :].copy()
        mirrored[:, first, :, :] = mirrored[:, second, :, :]
        mirrored[:, second, :, :] = saved
    return mirrored
def get_affine_transform(center,
                         input_size,
                         rot,
                         output_size,
                         shift=(0., 0.),
                         inv=False):
    """Build the 2x3 affine matrix relating a source region to an output grid.

    Three point pairs are constructed (region centre, a point half a source
    width away along the rotated direction, and a third point derived from
    those two) and handed to ``cv2.getAffineTransform``.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        input_size (np.ndarray[2, ]): Size of input feature (width, height).
            A scalar is expanded to ``[input_size, input_size]``.
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (0-100%): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to inverse the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2

    if not isinstance(input_size, (np.ndarray, list)):
        input_size = np.array([input_size, input_size], dtype=np.float32)

    translation = input_size * np.array(shift)
    dst_w, dst_h = output_size[0], output_size[1]

    # Direction vectors pointing "up" by half a width in each space.
    src_dir = rotate_point([0., input_size[0] * -0.5], np.pi * rot / 180)
    dst_dir = np.array([0., dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + translation
    src[1, :] = center + src_dir + translation
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst_center = [dst_w * 0.5, dst_h * 0.5]
    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = dst_center
    dst[1, :] = np.array(dst_center) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    src_pts, dst_pts = np.float32(src), np.float32(dst)
    if inv:
        return cv2.getAffineTransform(dst_pts, src_pts)
    return cv2.getAffineTransform(src_pts, dst_pts)
def _get_3rd_point(a, b):
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): point(x,y)
b (np.ndarray): point(x,y)
Returns:
np.ndarray: The 3rd point.
"""
assert len(
a) == 2, 'input of _get_3rd_point should be point with length of 2'
assert len(
b) == 2, 'input of _get_3rd_point should be point with length of 2'
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
def rotate_point(pt, angle_rad):
    """Rotate a 2-D point about the origin.

    Args:
        pt (list[float]): 2 dimensional point to be rotated
        angle_rad (float): rotation angle by radian

    Returns:
        list[float]: Rotated point.
    """
    assert len(pt) == 2
    x, y = pt[0], pt[1]
    sin_a = np.sin(angle_rad)
    cos_a = np.cos(angle_rad)
    return [x * cos_a - y * sin_a, x * sin_a + y * cos_a]
def affine_transform(pt, t):
    """Apply an affine matrix to a single 2-D point.

    Args:
        pt: indexable with at least two entries (x, y).
        t (np.ndarray): affine matrix multiplied with ``[x, y, 1]``.

    Returns:
        np.ndarray: the first two components of the product.
    """
    homogeneous = np.array([pt[0], pt[1], 1.])
    return np.dot(t, homogeneous)[:2]
def transform_preds(coords, center, scale, output_size):
    """Map coordinates from the output grid back through the inverse affine.

    Args:
        coords (np.ndarray): per-point coordinates; the first two columns of
            each row are transformed.
        center: region centre forwarded to ``get_affine_transform``.
        scale: region scale; multiplied by 200 before being used as the
            input size of the transform.
        output_size: size of the grid the coordinates live in.

    Returns:
        np.ndarray: array shaped like ``coords`` holding the mapped points
        in its first two columns (remaining columns are zero).
    """
    inverse = get_affine_transform(center, scale * 200, 0, output_size, inv=1)
    mapped = np.zeros(coords.shape)
    for row, point in enumerate(coords):
        mapped[row, 0:2] = affine_transform(point[0:2], inverse)
    return mapped
class HRNetPostProcess(object):
    """Decode keypoint heatmaps into keypoint coordinates and confidences.

    Args:
        use_dark (bool): when True, refine the arg-max locations with DARK
            post-processing; otherwise shift them a quarter pixel toward the
            higher neighbouring response.
    """

    def __init__(self, use_dark=True):
        self.use_dark = use_dark

    def get_max_preds(self, heatmaps):
        '''get predictions from score maps

        Args:
            heatmaps: numpy.ndarray([batch_size, num_joints, height, width])

        Returns:
            preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
            maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
        '''
        assert isinstance(heatmaps,
                          np.ndarray), 'heatmaps should be numpy.ndarray'
        assert heatmaps.ndim == 4, 'batch_images should be 4-ndim'

        batch_size = heatmaps.shape[0]
        num_joints = heatmaps.shape[1]
        width = heatmaps.shape[3]
        heatmaps_reshaped = heatmaps.reshape((batch_size, num_joints, -1))
        idx = np.argmax(heatmaps_reshaped, 2)
        maxvals = np.amax(heatmaps_reshaped, 2)

        maxvals = maxvals.reshape((batch_size, num_joints, 1))
        idx = idx.reshape((batch_size, num_joints, 1))

        # Turn the flat arg-max index into (x, y) on the heatmap grid.
        preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
        preds[:, :, 0] = (preds[:, :, 0]) % width
        preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)

        # Zero the coordinates of joints whose peak response is not positive.
        pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
        pred_mask = pred_mask.astype(np.float32)
        preds *= pred_mask

        return preds, maxvals

    def gaussian_blur(self, heatmap, kernel):
        """Blur every heatmap in place, restoring each map's original maximum.

        Args:
            heatmap (np.ndarray): [batch_size, num_joints, height, width];
                modified in place.
            kernel (int): Gaussian kernel size passed to ``cv2.GaussianBlur``.

        Returns:
            np.ndarray: the same ``heatmap`` array.
        """
        border = (kernel - 1) // 2
        batch_size = heatmap.shape[0]
        num_joints = heatmap.shape[1]
        height = heatmap.shape[2]
        width = heatmap.shape[3]
        # Explicit end indices: `border:-border` selects nothing when
        # border == 0 (kernel == 1).
        rows = slice(border, border + height)
        cols = slice(border, border + width)
        for i in range(batch_size):
            for j in range(num_joints):
                origin_max = np.max(heatmap[i, j])
                dr = np.zeros((height + 2 * border, width + 2 * border))
                dr[rows, cols] = heatmap[i, j].copy()
                dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
                heatmap[i, j] = dr[rows, cols].copy()
                heatmap[i, j] *= origin_max / np.max(heatmap[i, j])
        return heatmap

    def dark_parse(self, hm, coord):
        """Refine one coordinate using finite differences of ``hm``.

        The coordinate is only adjusted when it lies at least two pixels from
        every border and the Hessian is non-singular.

        Args:
            hm (np.ndarray): 2-D map indexed as ``hm[y][x]``.
            coord (np.ndarray): (x, y) location; updated in place.

        Returns:
            np.ndarray: ``coord`` (refined when the conditions hold).
        """
        heatmap_height = hm.shape[0]
        heatmap_width = hm.shape[1]
        px = int(coord[0])
        py = int(coord[1])
        if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
            dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
            dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
            dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
            dxy = 0.25 * (hm[py + 1][px + 1] - hm[py - 1][px + 1] -
                          hm[py + 1][px - 1] + hm[py - 1][px - 1])
            dyy = 0.25 * (hm[py + 2][px] - 2 * hm[py][px] + hm[py - 2][px])
            if dxx * dyy - dxy**2 != 0:
                # Plain arrays instead of the discouraged np.matrix class.
                derivative = np.array([dx, dy])
                hessian = np.array([[dxx, dxy], [dxy, dyy]])
                offset = -np.dot(np.linalg.inv(hessian), derivative)
                coord += offset
        return coord

    def dark_postprocess(self, hm, coords, kernelsize):
        '''
        DARK postpocessing, Zhang et al. Distribution-Aware Coordinate
        Representation for Human Pose Estimation (CVPR 2020).
        '''
        hm = self.gaussian_blur(hm, kernelsize)
        # Clamp before the log so non-positive responses stay finite.
        hm = np.maximum(hm, 1e-10)
        hm = np.log(hm)
        for n in range(coords.shape[0]):
            for p in range(coords.shape[1]):
                coords[n, p] = self.dark_parse(hm[n][p], coords[n][p])
        return coords

    def get_final_preds(self, heatmaps, center, scale, kernelsize=3):
        """
        The highest heatvalue location with a quarter offset in the
        direction from the highest response to the second highest response.

        Args:
            heatmaps (numpy.ndarray): The predicted heatmaps
            center (numpy.ndarray): The boxes center
            scale (numpy.ndarray): The scale factor
            kernelsize (int): Gaussian kernel size used when ``use_dark`` is
                enabled.

        Returns:
            preds: numpy.ndarray([batch_size, num_joints, 2]), keypoints coords
            maxvals: numpy.ndarray([batch_size, num_joints, 1]), the maximum confidence of the keypoints
        """
        coords, maxvals = self.get_max_preds(heatmaps)

        heatmap_height = heatmaps.shape[2]
        heatmap_width = heatmaps.shape[3]

        if self.use_dark:
            coords = self.dark_postprocess(heatmaps, coords, kernelsize)
        else:
            for n in range(coords.shape[0]):
                for p in range(coords.shape[1]):
                    hm = heatmaps[n][p]
                    # np.floor rather than math.floor: `math` is not among
                    # this module's imports.
                    px = int(np.floor(coords[n][p][0] + 0.5))
                    py = int(np.floor(coords[n][p][1] + 0.5))
                    if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
                        diff = np.array([
                            hm[py][px + 1] - hm[py][px - 1],
                            hm[py + 1][px] - hm[py - 1][px]
                        ])
                        coords[n][p] += np.sign(diff) * .25
        preds = coords.copy()

        # Transform back
        for i in range(coords.shape[0]):
            preds[i] = transform_preds(coords[i], center[i], scale[i],
                                       [heatmap_width, heatmap_height])

        return preds, maxvals

    def __call__(self, output, center, scale):
        """Decode ``output`` heatmaps.

        Returns:
            list: ``[[keypoints, scores]]`` where ``keypoints`` is the
            coordinates concatenated with their confidences on the last axis
            and ``scores`` is the mean confidence over axis 1.
        """
        preds, maxvals = self.get_final_preds(np.array(output), center, scale)
        outputs = [[
            np.concatenate(
                (preds, maxvals), axis=-1), np.mean(
                    maxvals, axis=1)
        ]]
        return outputs
def keypoint_post_process(data, data_input, exe, val_program, fetch_targets,
                          outs):
    """Post-process keypoint outputs with flip-test averaging.

    The program is run a second time on the horizontally flipped image; the
    flipped heatmaps are flipped back, shifted by one pixel along the last
    axis, averaged with ``outs[0]`` and decoded by ``HRNetPostProcess``.

    Args:
        data (dict): batch; ``'center'`` and ``'scale'`` are used when
            present, otherwise both are derived from ``'im_shape'``.
        data_input (dict): feed dict of the first run; it must contain
            ``'image'`` and is left unmodified.
        exe: executor used for the second run.
        val_program: program to run.
        fetch_targets: fetch list passed to ``exe.run``.
        outs: outputs of the first (unflipped) run.

    Returns:
        dict: ``{'keypoint': outputs}``.
    """
    # Feed a shallow copy so the caller's dict keeps the unflipped image.
    flipped_input = dict(data_input)
    flipped_input['image'] = np.flip(data_input['image'], [3])
    output_flipped = exe.run(val_program,
                             feed=flipped_input,
                             fetch_list=fetch_targets,
                             return_numpy=False)

    output_flipped = np.array(output_flipped[0])
    # Channel pairs exchanged when undoing the flip.
    flip_perm = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14],
                 [15, 16]]
    output_flipped = flip_back(output_flipped, flip_perm)
    # Shift one pixel along the last axis; the copy keeps the source slice
    # from being overwritten while it is read.
    output_flipped[:, :, :, 1:] = output_flipped[:, :, :, 0:-1].copy()
    hrnet_outputs = (np.array(outs[0]) + output_flipped) * 0.5

    imshape = (
        np.array(data['im_shape']))[:, ::-1] if 'im_shape' in data else None
    center = np.array(data['center']) if 'center' in data else np.round(
        imshape / 2.)
    scale = np.array(data['scale']) if 'scale' in data else imshape / 200.

    post_process = HRNetPostProcess()
    outputs = post_process(hrnet_outputs, center, scale)
    return {'keypoint': outputs}
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
    """Greedy non-maximum suppression on scored boxes.

    Args:
        box_scores (N, 5): boxes in corner-form and probabilities.
        iou_threshold: intersection over union threshold.
        top_k: keep top_k results. If k <= 0, keep all the results.
        candidate_size: only consider the candidates with the highest scores.

    Returns:
        The rows of ``box_scores`` that were kept, in the order they were
        picked (highest score first).
    """
    scores = box_scores[:, -1]
    boxes = box_scores[:, :-1]

    # Ascending by score, so the best remaining candidate is always last.
    candidates = np.argsort(scores)[-candidate_size:]
    keep = []
    while len(candidates) > 0:
        best = candidates[-1]
        keep.append(best)
        reached_top_k = top_k > 0 and len(keep) == top_k
        if reached_top_k or len(candidates) == 1:
            break
        candidates = candidates[:-1]
        overlaps = iou_of(
            boxes[candidates, :], np.expand_dims(
                boxes[best, :], axis=0))
        # Drop every candidate overlapping the picked box too much.
        candidates = candidates[overlaps <= iou_threshold]

    return box_scores[keep, :]
def iou_of(boxes0, boxes1, eps=1e-5):
    """Compute intersection-over-union (Jaccard index) between boxes.

    Args:
        boxes0 (N, 4): ground truth boxes.
        boxes1 (N or 1, 4): predicted boxes.
        eps: a small number to avoid 0 as denominator.

    Returns:
        iou (N): IoU values.
    """
    inter_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
    inter_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
    intersection = area_of(inter_left_top, inter_right_bottom)

    union = (area_of(boxes0[..., :2], boxes0[..., 2:]) + area_of(
        boxes1[..., :2], boxes1[..., 2:]) - intersection)
    return intersection / (union + eps)
def area_of(left_top, right_bottom):
    """Areas of axis-aligned rectangles described by two opposite corners.

    Negative extents are clipped to zero, so inverted corners give area 0.

    Args:
        left_top (N, 2): left top corner.
        right_bottom (N, 2): right bottom corner.

    Returns:
        area (N): return the area.
    """
    extent = np.clip(right_bottom - left_top, 0.0, None)
    first_side = extent[..., 0]
    second_side = extent[..., 1]
    return first_side * second_side
class PPYOLOEPostProcess(object):
    """Per-class NMS post-processing of raw detector predictions.

    Args:
        score_threshold (float): predictions whose class score is not above
            this value are discarded before NMS.
        nms_threshold (float): IoU threshold handed to ``hard_nms``.
        nms_top_k (int): ``top_k`` handed to ``hard_nms`` for each class.
        keep_top_k (int): maximum number of boxes kept per image after the
            per-class results are merged.
    """

    def __init__(self,
                 score_threshold=0.4,
                 nms_threshold=0.5,
                 nms_top_k=10000,
                 keep_top_k=300):
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k

    def _non_max_suppression(self, prediction, scale_factor):
        """Run per-class NMS on every image of a batch.

        Args:
            prediction (np.ndarray): indexed as ``prediction[batch]`` with
                the first 4 entries of the last axis used as box coordinates
                and the remaining entries as per-class scores.
            scale_factor: indexed as ``scale_factor[batch][0 or 1]``; entry 1
                divides box columns 0 and 2, entry 0 divides columns 1 and 3.

        Returns:
            tuple: ``(boxes, box_nums)`` where ``boxes`` stacks the rows
            ``[class, score, box(4)]`` of all images and ``box_nums`` holds
            one row count per image (0 for images without detections).
        """
        batch_size = prediction.shape[0]
        out_boxes_list = []
        box_num_list = []
        for batch_id in range(batch_size):
            bboxes = prediction[batch_id][..., :4]
            confidences = prediction[batch_id][..., 4:]

            # nms
            picked_box_probs = []
            picked_labels = []
            for class_index in range(0, confidences.shape[1]):
                probs = confidences[:, class_index]
                mask = probs > self.score_threshold
                probs = probs[mask]
                if probs.shape[0] == 0:
                    continue
                subset_boxes = bboxes[mask, :]
                box_probs = np.concatenate(
                    [subset_boxes, probs.reshape(-1, 1)], axis=1)
                box_probs = hard_nms(
                    box_probs,
                    iou_threshold=self.nms_threshold,
                    top_k=self.nms_top_k)
                picked_box_probs.append(box_probs)
                picked_labels.extend([class_index] * box_probs.shape[0])

            if len(picked_box_probs) == 0:
                # Same 6-column layout as a real result, so the final
                # concatenate works for batches that mix empty and non-empty
                # images.
                out_box = np.empty((0, 6))
            else:
                picked_box_probs = np.concatenate(picked_box_probs)

                # resize output boxes
                picked_box_probs[:, 0] /= scale_factor[batch_id][1]
                picked_box_probs[:, 2] /= scale_factor[batch_id][1]
                picked_box_probs[:, 1] /= scale_factor[batch_id][0]
                picked_box_probs[:, 3] /= scale_factor[batch_id][0]

                # class, score, box
                out_box = np.concatenate(
                    [
                        np.expand_dims(
                            np.array(picked_labels), axis=-1),
                        np.expand_dims(
                            picked_box_probs[:, 4], axis=-1),
                        picked_box_probs[:, :4]
                    ],
                    axis=1)
                if out_box.shape[0] > self.keep_top_k:
                    # Keep the highest-scoring rows only.
                    out_box = out_box[out_box[:, 1].argsort()[::-1]
                                      [:self.keep_top_k]]

            # Record a count for every image (0 included) so box_nums stays
            # aligned with the batch.
            out_boxes_list.append(out_box)
            box_num_list.append(out_box.shape[0])

        if out_boxes_list:
            out_boxes_list = np.concatenate(out_boxes_list, axis=0)
        else:
            # Empty batch: np.concatenate rejects an empty list.
            out_boxes_list = np.empty((0, 6))
        box_num_list = np.array(box_num_list)
        return out_boxes_list, box_num_list

    def __call__(self, outs, scale_factor):
        """Post-process ``outs`` and return ``{'bbox': ..., 'bbox_num': ...}``."""
        out_boxes_list, box_num_list = self._non_max_suppression(outs,
                                                                 scale_factor)
        return {'bbox': out_boxes_list, 'bbox_num': box_num_list}
...@@ -19,7 +19,6 @@ import argparse ...@@ -19,7 +19,6 @@ import argparse
import paddle import paddle
from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create from ppdet.core.workspace import create
from paddleslim.common import load_config as load_slim_config
from paddleslim.quant import quant_post_static from paddleslim.quant import quant_post_static
...@@ -63,33 +62,32 @@ def reader_wrapper(reader, input_list): ...@@ -63,33 +62,32 @@ def reader_wrapper(reader, input_list):
def main(): def main():
global global_config global config
all_config = load_slim_config(FLAGS.config_path) config = load_config(FLAGS.config_path)
assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
global_config = all_config["Global"]
reader_cfg = load_config(global_config['reader_config'])
train_loader = create('EvalReader')(reader_cfg['TrainDataset'], train_loader = create('EvalReader')(config['TrainDataset'],
reader_cfg['worker_num'], config['worker_num'],
return_list=True) return_list=True)
train_loader = reader_wrapper(train_loader, global_config['input_list']) train_loader = reader_wrapper(train_loader, config['input_list'])
place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace() place = paddle.CUDAPlace(0) if FLAGS.devices == 'gpu' else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
quant_post_static( quant_post_static(
executor=exe, executor=exe,
model_dir=global_config["model_dir"], model_dir=config["model_dir"],
quantize_model_path=FLAGS.save_dir, quantize_model_path=FLAGS.save_dir,
data_loader=train_loader, data_loader=train_loader,
model_filename=global_config["model_filename"], model_filename=config["model_filename"],
params_filename=global_config["params_filename"], params_filename=config["params_filename"],
batch_size=4, batch_size=4,
batch_nums=64, batch_nums=64,
algo=FLAGS.algo, algo=FLAGS.algo,
hist_percent=0.999, hist_percent=0.999,
is_full_quantize=False, is_full_quantize=False,
bias_correction=False, bias_correction=False,
onnx_format=False) onnx_format=False,
skip_tensor_list=config['skip_tensor_list']
if 'skip_tensor_list' in config else None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -68,6 +68,7 @@ def main(): ...@@ -68,6 +68,7 @@ def main():
global config global config
config = load_config(FLAGS.config_path) config = load_config(FLAGS.config_path)
ptq_config = config['PTQ']
input_name = 'x2paddle_image_arrays' if config[ input_name = 'x2paddle_image_arrays' if config[
'arch'] == 'YOLOv6' else 'x2paddle_images' 'arch'] == 'YOLOv6' else 'x2paddle_images'
...@@ -97,13 +98,9 @@ def main(): ...@@ -97,13 +98,9 @@ def main():
model_filename='model.pdmodel', model_filename='model.pdmodel',
params_filename='model.pdiparams', params_filename='model.pdiparams',
eval_function=eval_function, eval_function=eval_function,
quantizable_op_type=config['quantizable_op_type'],
weight_quantize_type=config['weight_quantize_type'],
activation_quantize_type=config['activation_quantize_type'],
is_full_quantize=config['is_full_quantize'],
data_loader=data_loader, data_loader=data_loader,
batch_size=config['batch_size'], save_dir=config['save_dir'],
save_dir=config['save_dir'], ) ptq_config=ptq_config)
analyzer.analysis() analyzer.analysis()
......
arch: YOLOv6 arch: YOLOv6
model_dir: ./yolov6s.onnx model_dir: ./yolov6s.onnx
save_dir: ./analysis_results save_dir: ./analysis_results
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
dataset_dir: /dataset/coco/ dataset_dir: /dataset/coco/
val_image_dir: val2017 val_image_dir: val2017
val_anno_path: annotations/instances_val2017.json val_anno_path: annotations/instances_val2017.json
batch_size: 10
PTQ:
quantizable_op_type: ["conv2d", "depthwise_conv2d"]
weight_quantize_type: 'abs_max'
activation_quantize_type: 'moving_average_abs_max'
is_full_quantize: False
batch_size: 10
batch_nums: 10
...@@ -39,8 +39,7 @@ __all__ = ["AnalysisQuant"] ...@@ -39,8 +39,7 @@ __all__ = ["AnalysisQuant"]
class AnalysisQuant(object): class AnalysisQuant(object):
def __init__( def __init__(self,
self,
model_dir, model_dir,
model_filename=None, model_filename=None,
params_filename=None, params_filename=None,
...@@ -49,12 +48,7 @@ class AnalysisQuant(object): ...@@ -49,12 +48,7 @@ class AnalysisQuant(object):
save_dir='analysis_results', save_dir='analysis_results',
checkpoint_name='analysis_checkpoint.pkl', checkpoint_name='analysis_checkpoint.pkl',
num_histogram_plots=10, num_histogram_plots=10,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], ptq_config=None):
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max',
is_full_quantize=False,
batch_size=10,
batch_nums=10, ):
""" """
AnalysisQuant analyzes the sensitivity of each op in the model. AnalysisQuant analyzes the sensitivity of each op in the model.
...@@ -68,13 +62,8 @@ class AnalysisQuant(object): ...@@ -68,13 +62,8 @@ class AnalysisQuant(object):
return a batch every time return a batch every time
save_dir(str, optional): the output dir that stores the analyzed information save_dir(str, optional): the output dir that stores the analyzed information
checkpoint_name(str, optional): the name of the checkpoint file that saves analyzed information so that an interrupted analysis can be resumed checkpoint_name(str, optional): the name of the checkpoint file that saves analyzed information so that an interrupted analysis can be resumed
num_histogram_plots: the number of histogram plots you want to visualize, the plots will show in four PDF files for both best and worst and for both weight and act ops in the save_dir ptq_config(dict, optional): the args that can initialize PostTrainingQuantization
quantizable_op_type(list): op types that can be quantized
weight_quantize_type(str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'
activation_quantize_type(str): quantization type for activation, now support 'range_abs_max', 'moving_average_abs_max' and 'abs_max'
is_full_quantize(bool): if True, apply quantization to all supported quantizable op type. If False, only apply quantization to the input quantizable_op_type. Default is False.
batch_size(int, optional): the batch size of DataLoader, default is 10
batch_nums(int, optional): the number of calibrate data is 'batch_size*batch_nums'
""" """
if model_filename is None: if model_filename is None:
model_filename = 'model.pdmodel' model_filename = 'model.pdmodel'
...@@ -83,20 +72,16 @@ class AnalysisQuant(object): ...@@ -83,20 +72,16 @@ class AnalysisQuant(object):
self.model_dir = model_dir self.model_dir = model_dir
self.model_filename = model_filename self.model_filename = model_filename
self.params_filename = params_filename self.params_filename = params_filename
self.batch_nums = batch_nums
self.quantizable_op_type = quantizable_op_type
self.weight_quantize_type = weight_quantize_type
self.activation_quantize_type = activation_quantize_type
self.is_full_quantize = is_full_quantize
self.histogram_bins = 1000 self.histogram_bins = 1000
self.save_dir = save_dir self.save_dir = save_dir
self.eval_function = eval_function self.eval_function = eval_function
self.quant_layer_names = [] self.quant_layer_names = []
self.checkpoint_name = os.path.join(save_dir, checkpoint_name) self.checkpoint_name = os.path.join(save_dir, checkpoint_name)
self.quant_layer_metrics = {} self.quant_layer_metrics = {}
self.batch_size = batch_size
self.batch_nums = batch_nums
self.num_histogram_plots = num_histogram_plots self.num_histogram_plots = num_histogram_plots
self.ptq_config = ptq_config
self.batch_nums = ptq_config[
'batch_nums'] if 'batch_nums' in ptq_config else 10
if not os.path.exists(self.save_dir): if not os.path.exists(self.save_dir):
os.mkdir(self.save_dir) os.mkdir(self.save_dir)
...@@ -130,14 +115,9 @@ class AnalysisQuant(object): ...@@ -130,14 +115,9 @@ class AnalysisQuant(object):
model_dir=self.model_dir, model_dir=self.model_dir,
model_filename=self.model_filename, model_filename=self.model_filename,
params_filename=self.params_filename, params_filename=self.params_filename,
batch_size=self.batch_size, skip_tensor_list=None,
batch_nums=self.batch_nums, algo='avg', #fastest
algo='avg', # fastest **self.ptq_config)
quantizable_op_type=self.quantizable_op_type,
weight_quantize_type=self.weight_quantize_type,
activation_quantize_type=self.activation_quantize_type,
is_full_quantize=self.is_full_quantize,
skip_tensor_list=None, )
program = post_training_quantization.quantize() program = post_training_quantization.quantize()
self.quant_metric = self.eval_function(executor, program, self.quant_metric = self.eval_function(executor, program,
self.feed_list, self.fetch_list) self.feed_list, self.fetch_list)
...@@ -208,14 +188,9 @@ class AnalysisQuant(object): ...@@ -208,14 +188,9 @@ class AnalysisQuant(object):
model_dir=self.model_dir, model_dir=self.model_dir,
model_filename=self.model_filename, model_filename=self.model_filename,
params_filename=self.params_filename, params_filename=self.params_filename,
batch_size=self.batch_size, skip_tensor_list=skip_list,
batch_nums=self.batch_nums, algo='avg', #fastest
algo='avg', # fastest **self.ptq_config)
quantizable_op_type=self.quantizable_op_type,
weight_quantize_type=self.weight_quantize_type,
activation_quantize_type=self.activation_quantize_type,
is_full_quantize=self.is_full_quantize,
skip_tensor_list=skip_list, )
program = post_training_quantization.quantize() program = post_training_quantization.quantize()
_logger.info('Evaluating...') _logger.info('Evaluating...')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册