未验证 提交 d207622b 编写于 作者: G Guanghua Yu 提交者: GitHub

add YOLOv5s ACT demo (#1143)

* add YOLOv5s ACT demo

* fix comment

* fix docs
上级 567a97d9
......@@ -18,7 +18,7 @@
## 2.Benchmark
- PP-YOLOE模型
### PP-YOLOE
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP32</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
......@@ -28,12 +28,33 @@
- mAP的指标均在COCO val2017数据集中评测得到。
- PP-YOLOE模型在Tesla V100的GPU环境下测试,测试脚本是[benchmark demo](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/python)
### YOLOv5
| 模型 | 策略 | 输入尺寸 | mAP<sup>val<br>0.5:0.95 | 预测时延<sup><small>FP32</small><sup><br><sup>(ms) |预测时延<sup><small>FP32</small><sup><br><sup>(ms) | 预测时延<sup><small>INT8</small><sup><br><sup>(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv5s | Base模型 | 640*640 | 37.4 | 6.0 | 4.9ms | - | - | [Model](https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar) |
| YOLOv5s | 量化+蒸馏 | 640*640 | 36.5 | - | - | 4.5ms | [config](https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/auto_compression/detection/configs/yolov5s_qat_dis.yaml) | [Model](https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar) |
说明:
- mAP的指标均在COCO val2017数据集中评测得到。
- YOLOv5s模型在Tesla V100的GPU环境下测试,测试脚本是[benchmark demo](./infer.py)
- YOLOv5模型源自[ultralytics/yolov5](https://github.com/ultralytics/yolov5),通过[X2Paddle](https://github.com/PaddlePaddle/X2Paddle)工具转换YOLOv5预测模型步骤:
(1) 安装X2Paddle的1.3.6以上版本;(pip install x2paddle)
(2) 转换模型:
```
x2paddle --framework=onnx --model=yolov5s.onnx --save_dir=pd_model
cp -r pd_model/inference_model/ yolov5_inference_model
```
即可得到YOLOv5s模型的预测模型(`model.pdmodel``model.pdiparams`)。如想快速体验,可直接下载上方表格中YOLOv5s的Base预测模型。
## 3. 自动压缩流程
#### 3.1 准备环境
- PaddlePaddle >= 2.3 (可从[Paddle官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载安装)
- PaddleSlim develop版本
- PaddleDet >= 2.4
- opencv-python
安装paddlepaddle:
```shell
......
metric: COCO
num_classes: 80
# Datset configuration
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco/
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco/
worker_num: 4
# preprocess reader in test
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: True}
- Pad: {size: [640, 640], fill_value: [114., 114., 114.]}
- NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
- Permute: {}
batch_size: 1
Global:
reader_config: configs/yolov5_reader.yml
input_list: {'image': 'x2paddle_images'}
Evaluation: True
arch: 'YOLOv5'
model_dir: ./yolov5s_infer/
model_filename: model.pdmodel
params_filename: model.pdiparams
Distillation:
distill_lambda: 1.0
distill_loss: l2_loss
distill_node_pair:
- teacher_conv2d_106.tmp_1
- conv2d_106.tmp_1
- teacher_conv2d_113.tmp_1
- conv2d_113.tmp_1
- teacher_conv2d_119.tmp_1
- conv2d_119.tmp_1
merge_feed: true
teacher_model_dir: ./yolov5_inference_model/
teacher_model_filename: model.pdmodel
teacher_params_filename: model.pdiparams
Quantization:
use_pact: true
activation_bits: 8
weight_bits: 8
activation_quantize_type: 'range_abs_max'
weight_quantize_type: 'channel_wise_abs_max'
is_full_quantize: false
not_quant_pattern:
- skip_quant
quantize_op_types:
- conv2d
- depthwise_conv2d
TrainConfig:
epochs: 1
eval_iter: 1000
learning_rate: 0.00001
optimizer: SGD
optim_args:
weight_decay: 4.0e-05
target_metric: 0.365
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
import argparse
import time
from paddle.inference import Config
from paddle.inference import create_predictor
from post_process import YOLOv5PostProcess
CLASS_LABEL = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush'
]
def generate_scale(im, target_shape, keep_ratio=True):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
if keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(target_shape)
target_size_max = np.max(target_shape)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = target_shape
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
def image_preprocess(img_path, target_shape):
img = cv2.imread(img_path)
# Resize
im_scale_y, im_scale_x = generate_scale(img, target_shape)
img = cv2.resize(
img,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_LINEAR)
# Pad
im_h, im_w = img.shape[:2]
h, w = target_shape[:]
if h != im_h or w != im_w:
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array([114.0, 114.0, 114.0], dtype=np.float32)
canvas[0:im_h, 0:im_w, :] = img.astype(np.float32)
img = canvas
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img, [2, 0, 1]) / 255
img = np.expand_dims(img, 0)
scale_factor = np.array([[im_scale_y, im_scale_x]])
return img.astype(np.float32), scale_factor
def get_color_map_list(num_classes):
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
def draw_box(image_file, results, class_label, threshold=0.5):
srcimg = cv2.imread(image_file, 1)
for i in range(len(results)):
color_list = get_color_map_list(len(class_label))
clsid2color = {}
classid, conf = int(results[i, 0]), results[i, 1]
if conf < threshold:
continue
xmin, ymin, xmax, ymax = int(results[i, 2]), int(results[i, 3]), int(
results[i, 4]), int(results[i, 5])
if classid not in clsid2color:
clsid2color[classid] = color_list[classid]
color = tuple(clsid2color[classid])
cv2.rectangle(srcimg, (xmin, ymin), (xmax, ymax), color, thickness=2)
print(class_label[classid] + ': ' + str(round(conf, 3)))
cv2.putText(
srcimg,
class_label[classid] + ':' + str(round(conf, 3)), (xmin, ymin - 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.8, (0, 255, 0),
thickness=2)
return srcimg
def load_predictor(model_dir,
run_mode='paddle',
batch_size=1,
device='CPU',
min_subgraph_size=3,
use_dynamic_shape=False,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
delete_shuffle_pass=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
Used by action model.
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
ValueError: predict by TensorRT need device == 'GPU'.
"""
if device != 'GPU' and run_mode != 'paddle':
raise ValueError(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
.format(run_mode, device))
config = Config(
os.path.join(model_dir, 'model.pdmodel'),
os.path.join(model_dir, 'model.pdiparams'))
if device == 'GPU':
# initial GPU memory(M), device ID
config.enable_use_gpu(200, 0)
# optimize graph and fuse op
config.switch_ir_optim(True)
elif device == 'XPU':
config.enable_lite_engine()
config.enable_xpu(10 * 1024 * 1024)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads)
if enable_mkldnn:
try:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
if enable_mkldnn_bfloat16:
config.enable_mkldnn_bfloat16()
except Exception as e:
print(
"The current environment does not support `mkldnn`, so disable mkldnn."
)
pass
precision_map = {
'trt_int8': Config.Precision.Int8,
'trt_fp32': Config.Precision.Float32,
'trt_fp16': Config.Precision.Half
}
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=(1 << 25) * batch_size,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
use_static=False,
use_calib_mode=trt_calib_mode)
if use_dynamic_shape:
min_input_shape = {
'image': [batch_size, 3, trt_min_shape, trt_min_shape]
}
max_input_shape = {
'image': [batch_size, 3, trt_max_shape, trt_max_shape]
}
opt_input_shape = {
'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
opt_input_shape)
print('trt set dynamic shape done!')
# disable print log when predict
config.disable_glog_info()
# enable shared memory
config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False)
if delete_shuffle_pass:
config.delete_pass("shuffle_channel_detect_pass")
predictor = create_predictor(config)
return predictor
def predict_image(predictor,
image_file,
image_shape=[640, 640],
warmup=1,
repeats=1,
threshold=0.5,
arch='YOLOv5'):
img, scale_factor = image_preprocess(image_file, image_shape)
inputs = {}
if arch == 'YOLOv5':
inputs['x2paddle_images'] = img
input_names = predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
for i in range(warmup):
predictor.run()
np_boxes = None
predict_time = 0.
time_min = float("inf")
time_max = float('-inf')
for i in range(repeats):
start_time = time.time()
predictor.run()
output_names = predictor.get_output_names()
boxes_tensor = predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
end_time = time.time()
timed = end_time - start_time
time_min = min(time_min, timed)
time_max = max(time_max, timed)
predict_time += timed
time_avg = predict_time / repeats
print('Inference time(ms): min={}, max={}, avg={}'.format(
round(time_min * 1000, 2),
round(time_max * 1000, 1), round(time_avg * 1000, 1)))
postprocess = YOLOv5PostProcess(
score_threshold=0.001, nms_threshold=0.6, multi_label=True)
res = postprocess(np_boxes, scale_factor)
res_img = draw_box(
image_file, res['bbox'], CLASS_LABEL, threshold=threshold)
cv2.imwrite('result.jpg', res_img)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--image_file', type=str, default=None, help="image path")
parser.add_argument(
'--model_path', type=str, help="inference model filepath")
parser.add_argument(
'--benchmark',
type=bool,
default=False,
help="Whether run benchmark or not.")
parser.add_argument(
'--run_mode',
type=str,
default='paddle',
help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
'--device',
type=str,
default='CPU',
help="Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU"
)
parser.add_argument('--img_shape', type=int, default=640, help="input_size")
args = parser.parse_args()
predictor = load_predictor(
args.model_path, run_mode=args.run_mode, device=args.device)
warmup, repeats = 1, 1
if args.benchmark:
warmup, repeats = 50, 100
predict_image(
predictor,
args.image_file,
image_shape=[args.img_shape, args.img_shape],
warmup=warmup,
repeats=repeats)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
def box_area(boxes):
"""
Args:
boxes(np.ndarray): [N, 4]
return: [N]
"""
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
def box_iou(box1, box2):
"""
Args:
box1(np.ndarray): [N, 4]
box2(np.ndarray): [M, 4]
return: [N, M]
"""
area1 = box_area(box1)
area2 = box_area(box2)
lt = np.maximum(box1[:, np.newaxis, :2], box2[:, :2])
rb = np.minimum(box1[:, np.newaxis, 2:], box2[:, 2:])
wh = rb - lt
wh = np.maximum(0, wh)
inter = wh[:, :, 0] * wh[:, :, 1]
iou = inter / (area1[:, np.newaxis] + area2 - inter)
return iou
def nms(boxes, scores, iou_threshold):
"""
Non Max Suppression numpy implementation.
args:
boxes(np.ndarray): [N, 4]
scores(np.ndarray): [N, 1]
iou_threshold(float): Threshold of IoU.
"""
idxs = scores.argsort()
keep = []
while idxs.size > 0:
max_score_index = idxs[-1]
max_score_box = boxes[max_score_index][None, :]
keep.append(max_score_index)
if idxs.size == 1:
break
idxs = idxs[:-1]
other_boxes = boxes[idxs]
ious = box_iou(max_score_box, other_boxes)
idxs = idxs[ious[0] <= iou_threshold]
keep = np.array(keep)
return keep
class YOLOv5PostProcess(object):
"""
Post process of YOLOv5 network.
args:
score_threshold(float): Threshold to filter out bounding boxes with low
confidence score. If not provided, consider all boxes.
nms_threshold(float): The threshold to be used in NMS.
multi_label(bool): Whether keep multi label in boxes.
keep_top_k(int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step.
"""
def __init__(self,
score_threshold=0.25,
nms_threshold=0.5,
multi_label=False,
keep_top_k=300):
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.multi_label = multi_label
self.keep_top_k = keep_top_k
def _xywh2xyxy(self, x):
# Convert from [x, y, w, h] to [x1, y1, x2, y2]
y = np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
return y
def _non_max_suppression(self, prediction):
max_wh = 4096 # (pixels) minimum and maximum box width and height
nms_top_k = 30000
cand_boxes = prediction[..., 4] > self.score_threshold # candidates
output = [np.zeros((0, 6))] * prediction.shape[0]
for batch_id, boxes in enumerate(prediction):
# Apply constraints
boxes = boxes[cand_boxes[batch_id]]
if not boxes.shape[0]:
continue
# Compute conf (conf = obj_conf * cls_conf)
boxes[:, 5:] *= boxes[:, 4:5]
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
convert_box = self._xywh2xyxy(boxes[:, :4])
# Detections matrix nx6 (xyxy, conf, cls)
if self.multi_label:
i, j = (boxes[:, 5:] > self.score_threshold).nonzero()
boxes = np.concatenate(
(convert_box[i], boxes[i, j + 5, None],
j[:, None].astype(np.float32)),
axis=1)
else:
conf = np.max(boxes[:, 5:], axis=1)
j = np.argmax(boxes[:, 5:], axis=1)
re = np.array(conf.reshape(-1) > self.score_threshold)
conf = conf.reshape(-1, 1)
j = j.reshape(-1, 1)
boxes = np.concatenate((convert_box, conf, j), axis=1)[re]
num_box = boxes.shape[0]
if not num_box:
continue
elif num_box > nms_top_k:
boxes = boxes[boxes[:, 4].argsort()[::-1][:nms_top_k]]
# Batched NMS
c = boxes[:, 5:6] * max_wh
clean_boxes, scores = boxes[:, :4] + c, boxes[:, 4]
keep = nms(clean_boxes, scores, self.nms_threshold)
# limit detection box num
if keep.shape[0] > self.keep_top_k:
keep = keep[:self.keep_top_k]
output[batch_id] = boxes[keep]
return output
def __call__(self, outs, scale_factor):
preds = self._non_max_suppression(outs)
bboxs, box_nums = [], []
for i, pred in enumerate(preds):
if len(pred.shape) > 2:
pred = np.squeeze(pred)
if len(pred.shape) == 1:
pred = pred[np.newaxis, :]
pred_bboxes = pred[:, :4]
scale_factor = np.tile(scale_factor[i][::-1], (1, 2))
pred_bboxes /= scale_factor
bbox = np.concatenate(
[
pred[:, -1][:, np.newaxis], pred[:, -2][:, np.newaxis],
pred_bboxes
],
axis=-1)
bboxs.append(bbox)
box_num = bbox.shape[0]
box_nums.append(box_num)
bboxs = np.concatenate(bboxs, axis=0)
box_nums = np.array(box_nums)
return {'bbox': bboxs, 'bbox_num': box_nums}
......@@ -23,6 +23,8 @@ from ppdet.metrics import COCOMetric
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
from post_process import YOLOv5PostProcess
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
......@@ -59,8 +61,12 @@ def reader_wrapper(reader, input_list):
def gen():
for data in reader:
in_dict = {}
for input_name in input_list:
in_dict[input_name] = data[input_name]
if isinstance(input_list, list):
for input_name in input_list:
in_dict[input_name] = data[input_name]
elif isinstance(input_list, dict):
for input_name in input_list.keys():
in_dict[input_list[input_name]] = data[input_name]
yield in_dict
return gen
......@@ -80,24 +86,34 @@ def eval(config):
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, bias=0, IouType='bbox')
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
for batch_id, data in enumerate(val_loader):
data_all = {k: np.array(v) for k, v in data.items()}
data_input = {}
for k, v in data.items():
if k in config['input_list']:
data_input[k] = np.array(v)
if isinstance(config['input_list'], list):
if k in config['input_list']:
data_input[k] = np.array(v)
elif isinstance(config['input_list'], dict):
if k in config['input_list'].keys():
data_input[config['input_list'][k]] = np.array(v)
outs = exe.run(val_program,
feed=data_input,
fetch_list=fetch_targets,
return_numpy=False)
res = {}
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
if 'arch' in config and config['arch'] == 'YOLOv5':
postprocess = YOLOv5PostProcess(
score_threshold=0.001, nms_threshold=0.6, multi_label=True)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res)
if batch_id % 100 == 0:
......@@ -112,24 +128,33 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
anno_file = dataset.get_anno()
metric = COCOMetric(
anno_file=anno_file, clsid2catid=clsid2catid, bias=1, IouType='bbox')
anno_file=anno_file, clsid2catid=clsid2catid, IouType='bbox')
for batch_id, data in enumerate(val_loader):
data_all = {k: np.array(v) for k, v in data.items()}
data_input = {}
for k, v in data.items():
if k in test_feed_names:
data_input[k] = np.array(v)
if isinstance(global_config['input_list'], list):
if k in test_feed_names:
data_input[k] = np.array(v)
elif isinstance(global_config['input_list'], dict):
if k in global_config['input_list'].keys():
data_input[global_config['input_list'][k]] = np.array(v)
outs = exe.run(compiled_test_program,
feed=data_input,
fetch_list=test_fetch_list,
return_numpy=False)
res = {}
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
if 'arch' in global_config and global_config['arch'] == 'YOLOv5':
postprocess = YOLOv5PostProcess(
score_threshold=0.001, nms_threshold=0.6, multi_label=True)
res = postprocess(np.array(outs[0]), data_all['scale_factor'])
else:
for out in outs:
v = np.array(out)
if len(v.shape) > 1:
res['bbox'] = v
else:
res['bbox_num'] = v
metric.update(data_all, res)
if batch_id % 100 == 0:
......@@ -142,6 +167,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
def main():
global global_config
compress_config, train_config, global_config = load_slim_config(
FLAGS.config_path)
reader_cfg = load_config(global_config['reader_config'])
......
......@@ -571,10 +571,9 @@ class AutoCompression:
os.remove(os.path.join(self.tmp_dir, 'best_model.pdparams'))
if 'qat' in strategy:
float_program, int8_program = convert(test_program_info.program._program, self._places, self._quant_config, \
test_program, int8_program = convert(test_program, self._places, self._quant_config, \
scope=paddle.static.global_scope(), \
save_int8=True)
test_program_info.program = float_program
model_dir = os.path.join(self.tmp_dir,
'strategy_{}'.format(str(strategy_idx + 1)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册