diff --git a/deploy/serving/python/preprocess_ops.py b/deploy/serving/python/preprocess_ops.py
index eaf013744e938164f6e42a99c5517c35e0beeb83..70e7d1baa312ce2419a5455d599bf3a661a93d97 100644
--- a/deploy/serving/python/preprocess_ops.py
+++ b/deploy/serving/python/preprocess_ops.py
@@ -3,10 +3,14 @@ import cv2
 import copy
 
 
-def decode_image(im, img_info):
+def decode_image(im):
     im = np.array(im)
-    img_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
-    img_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
+    img_info = {
+        "im_shape": np.array(
+            im.shape[:2], dtype=np.float32),
+        "scale_factor": np.array(
+            [1., 1.], dtype=np.float32)
+    }
     return im, img_info
 
 
@@ -399,16 +403,10 @@
             op_type = new_op_info.pop('type')
             self.transforms.append(eval(op_type)(**new_op_info))
 
-        self.im_info = {
-            'scale_factor': np.array(
-                [1., 1.], dtype=np.float32),
-            'im_shape': None
-        }
-
     def __call__(self, img):
-        img, self.im_info = decode_image(img, self.im_info)
+        img, im_info = decode_image(img)
         for t in self.transforms:
-            img, self.im_info = t(img, self.im_info)
-        inputs = copy.deepcopy(self.im_info)
+            img, im_info = t(img, im_info)
+        inputs = copy.deepcopy(im_info)
         inputs['image'] = img
         return inputs
diff --git a/deploy/serving/python/web_service.py b/deploy/serving/python/web_service.py
index 5a7018f98d670a2f9dc3715888d9b0191091d088..12075bc87c6bc79003b6c55a0a9d8f248a8eb36e 100644
--- a/deploy/serving/python/web_service.py
+++ b/deploy/serving/python/web_service.py
@@ -132,7 +132,7 @@
         self.arch = yml_conf['arch']
         self.preprocess_infos = yml_conf['Preprocess']
         self.min_subgraph_size = yml_conf['min_subgraph_size']
-        self.labels = yml_conf['label_list']
+        self.label_list = yml_conf['label_list']
         self.use_dynamic_shape = yml_conf['use_dynamic_shape']
         self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
         self.mask = yml_conf.get("mask", False)
@@ -189,8 +189,8 @@
         result = {}
         for k, num in zip(input_dict.keys(), bboxes_num):
             bbox = bboxes[idx:idx + num]
-            result[k] = self.parse_det_result(bbox, draw_threshold,
-                                              GLOBAL_VAR['model_config'].labels)
+            result[k] = self.parse_det_result(
+                bbox, draw_threshold, GLOBAL_VAR['model_config'].label_list)
         return result, None, ""
 
     def collate_inputs(self, inputs):
@@ -206,7 +206,7 @@
     def parse_det_result(self, bbox, draw_threshold, label_list):
         result = []
         for line in bbox:
-            if line[1] > draw_threshold:
+            if line[0] > -1 and line[1] > draw_threshold:
                 result.append(f"{label_list[int(line[0])]} {line[1]} "
                               f"{line[2]} {line[3]} {line[4]} {line[5]}")
         return result
diff --git a/deploy/third_engine/demo_onnxruntime/infer_demo.py b/deploy/third_engine/demo_onnxruntime/infer_demo.py
index 09a932d8e403f671923714280396393e37414af5..41f407097828fa099a43831d4200193ba91557be 100644
--- a/deploy/third_engine/demo_onnxruntime/infer_demo.py
+++ b/deploy/third_engine/demo_onnxruntime/infer_demo.py
@@ -55,8 +55,9 @@
         origin_shape = srcimg.shape[:2]
         im_scale_y = newh / float(origin_shape[0])
         im_scale_x = neww / float(origin_shape[1])
-        img_shape = np.array([[float(origin_shape[0]), float(origin_shape[1])]
-                              ]).astype('float32')
+        img_shape = np.array([
+            [float(self.input_shape[0]), float(self.input_shape[1])]
+        ]).astype('float32')
         scale_factor = np.array([[im_scale_y, im_scale_x]]).astype('float32')
 
         if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
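# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch). The new
# `line[0] > -1` guard in parse_det_result matters because the exported
# detector can emit a placeholder row whose class id is -1 when an image has
# no detections. Without the guard, `label_list[int(-1)]` silently indexes
# the LAST label and a spurious detection string is returned. The labels and
# bbox row below are hypothetical stand-ins.
import numpy as np

label_list = ["person", "car"]
draw_threshold = 0.5
bbox = np.array([[-1.0, 1.0, 0.0, 0.0, 0.0, 0.0]])  # placeholder "no det" row

result = []
for line in bbox:
    if line[0] > -1 and line[1] > draw_threshold:  # patched check
        result.append(f"{label_list[int(line[0])]} {line[1]}")
print(result)  # [] -- the old `line[1] > draw_threshold` check alone would
               # have appended 'car 1.0'
# ---------------------------------------------------------------------------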
diff --git a/deploy/third_engine/onnx/infer.py b/deploy/third_engine/onnx/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..322768759b591e60671736c833d0b3875a60942e
--- /dev/null
+++ b/deploy/third_engine/onnx/infer.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import yaml
+import argparse
+import numpy as np
+import glob
+from onnxruntime import InferenceSession
+
+from preprocess import Compose
+
+# Global dictionary
+SUPPORT_MODELS = {
+    'YOLO',
+    'RCNN',
+    'SSD',
+    'Face',
+    'FCOS',
+    'SOLOv2',
+    'TTFNet',
+    'S2ANet',
+    'JDE',
+    'FairMOT',
+    'DeepSORT',
+    'GFL',
+    'PicoDet',
+    'CenterNet',
+    'TOOD',
+    'RetinaNet',
+    'StrongBaseline',
+    'STGCN',
+    'YOLOX',
+}
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument("-c", "--config", type=str, help="infer_cfg.yml")
+parser.add_argument(
+    '--onnx_file', type=str, default="model.onnx", help="onnx model file path")
+parser.add_argument("--image_dir", type=str)
+parser.add_argument("--image_file", type=str)
+
+
+def get_test_images(infer_dir, infer_img):
+    """
+    Get image path list in TEST mode
+    """
+    assert infer_img is not None or infer_dir is not None, \
+        "--image_file or --image_dir should be set"
+    assert infer_img is None or os.path.isfile(infer_img), \
+        "{} is not a file".format(infer_img)
+    assert infer_dir is None or os.path.isdir(infer_dir), \
+        "{} is not a directory".format(infer_dir)
+
+    # infer_img has a higher priority
+    if infer_img and os.path.isfile(infer_img):
+        return [infer_img]
+
+    images = set()
+    infer_dir = os.path.abspath(infer_dir)
+    assert os.path.isdir(infer_dir), \
+        "infer_dir {} is not a directory".format(infer_dir)
+    exts = ['jpg', 'jpeg', 'png', 'bmp']
+    exts += [ext.upper() for ext in exts]
+    for ext in exts:
+        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
+    images = list(images)
+
+    assert len(images) > 0, "no image found in {}".format(infer_dir)
+    print("Found {} inference images in total.".format(len(images)))
+
+    return images
+
+
+class PredictConfig(object):
+    """set config of preprocess, postprocess and visualize
+    Args:
+        model_dir (str): root path of infer_cfg.yml
+    """
+
+    def __init__(self, infer_config):
+        # parsing Yaml config for Preprocess
+        with open(infer_config) as f:
+            yml_conf = yaml.safe_load(f)
+        self.check_model(yml_conf)
+        self.arch = yml_conf['arch']
+        self.preprocess_infos = yml_conf['Preprocess']
+        self.min_subgraph_size = yml_conf['min_subgraph_size']
+        self.label_list = yml_conf['label_list']
+        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
+        self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
+        self.mask = yml_conf.get("mask", False)
+        self.tracker = yml_conf.get("tracker", None)
+        self.nms = yml_conf.get("NMS", None)
+        self.fpn_stride = yml_conf.get("fpn_stride", None)
+        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
+            print(
+                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
+            )
+        self.print_config()
+
+    def check_model(self, yml_conf):
+        """
+        Raises:
+            ValueError: loaded model not in supported model type
+        """
+        for support_model in SUPPORT_MODELS:
+            if support_model in yml_conf['arch']:
+                return True
+        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
+            'arch'], SUPPORT_MODELS))
+
+    def print_config(self):
+        print('----------- Model Configuration -----------')
+        print('%s: %s' % ('Model Arch', self.arch))
+        print('%s: ' % ('Transform Order'))
+        for op_info in self.preprocess_infos:
+            print('--%s: %s' % ('transform op', op_info['type']))
+        print('--------------------------------------------')
+
+
+def predict_image(infer_config, predictor, img_list):
+    # load preprocess transforms
+    transforms = Compose(infer_config.preprocess_infos)
+    # predict image
+    for img_path in img_list:
+        inputs = transforms(img_path)
+        inputs_name = [var.name for var in predictor.get_inputs()]
+        inputs = {k: inputs[k][None, ] for k in inputs_name}
+
+        outputs = predictor.run(output_names=None, input_feed=inputs)
+
+        print("ONNXRuntime predict: ")
+        bboxes = np.array(outputs[0])
+        for bbox in bboxes:
+            if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
+                print(f"{infer_config.label_list[int(bbox[0])]} {bbox[1]} "
+                      f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
+
+
+if __name__ == '__main__':
+    FLAGS = parser.parse_args()
+    # load image list
+    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
+    # load predictor
+    predictor = InferenceSession(FLAGS.onnx_file)
+    # load infer config
+    infer_config = PredictConfig(FLAGS.config)
+
+    predict_image(infer_config, predictor, img_list)
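# ---------------------------------------------------------------------------
# Editor's note (illustrative, not part of the patch): how the new
# deploy/third_engine/onnx/infer.py can be driven. The export paths below are
# hypothetical; the CLI flags match the argparse definitions added above.
#
#   python deploy/third_engine/onnx/infer.py \
#       -c output_inference/picodet_s_320_coco/infer_cfg.yml \
#       --onnx_file output_inference/picodet_s_320_coco/model.onnx \
#       --image_file demo/000000014439.jpg
#
# Programmatic sketch (run from deploy/third_engine/onnx so the local
# imports resolve):
from onnxruntime import InferenceSession

from infer import PredictConfig, predict_image

infer_config = PredictConfig("output_inference/picodet_s_320_coco/infer_cfg.yml")
predictor = InferenceSession("output_inference/picodet_s_320_coco/model.onnx")
predict_image(infer_config, predictor, ["demo/000000014439.jpg"])
# ---------------------------------------------------------------------------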
diff --git a/deploy/third_engine/onnx/preprocess.py b/deploy/third_engine/onnx/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5b28ed2fe053e770339235ee9718be9112612e1
--- /dev/null
+++ b/deploy/third_engine/onnx/preprocess.py
@@ -0,0 +1,416 @@
+import numpy as np
+import cv2
+import copy
+
+
+def decode_image(img_path):
+    with open(img_path, 'rb') as f:
+        im_read = f.read()
+    data = np.frombuffer(im_read, dtype='uint8')
+    im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
+    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+    img_info = {
+        "im_shape": np.array(
+            im.shape[:2], dtype=np.float32),
+        "scale_factor": np.array(
+            [1., 1.], dtype=np.float32)
+    }
+    return im, img_info
+
+
+class Resize(object):
+    """resize image by target_size and max_size
+    Args:
+        target_size (int): the target size of image
+        keep_ratio (bool): whether keep_ratio or not, default true
+        interp (int): method of resize
+    """
+
+    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        self.target_size = target_size
+        self.keep_ratio = keep_ratio
+        self.interp = interp
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        assert len(self.target_size) == 2
+        assert self.target_size[0] > 0 and self.target_size[1] > 0
+        im_channel = im.shape[2]
+        im_scale_y, im_scale_x = self.generate_scale(im)
+        im = cv2.resize(
+            im,
+            None,
+            None,
+            fx=im_scale_x,
+            fy=im_scale_y,
+            interpolation=self.interp)
+        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
+        im_info['scale_factor'] = np.array(
+            [im_scale_y, im_scale_x]).astype('float32')
+        return im, im_info
+
+    def generate_scale(self, im):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+        Returns:
+            im_scale_x: the resize ratio of X
+            im_scale_y: the resize ratio of Y
+        """
+        origin_shape = im.shape[:2]
+        im_c = im.shape[2]
+        if self.keep_ratio:
+            im_size_min = np.min(origin_shape)
+            im_size_max = np.max(origin_shape)
+            target_size_min = np.min(self.target_size)
+            target_size_max = np.max(self.target_size)
+            im_scale = float(target_size_min) / float(im_size_min)
+            if np.round(im_scale * im_size_max) > target_size_max:
+                im_scale = float(target_size_max) / float(im_size_max)
+            im_scale_x = im_scale
+            im_scale_y = im_scale
+        else:
+            resize_h, resize_w = self.target_size
+            im_scale_y = resize_h / float(origin_shape[0])
+            im_scale_x = resize_w / float(origin_shape[1])
+        return im_scale_y, im_scale_x
+
+
+class NormalizeImage(object):
+    """normalize image
+    Args:
+        mean (list): im - mean
+        std (list): im / std
+        is_scale (bool): whether need im / 255
+        is_channel_first (bool): if True: image shape is CHW, else: HWC
+    """
+
+    def __init__(self, mean, std, is_scale=True):
+        self.mean = mean
+        self.std = std
+        self.is_scale = is_scale
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        im = im.astype(np.float32, copy=False)
+        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+        std = np.array(self.std)[np.newaxis, np.newaxis, :]
+
+        if self.is_scale:
+            im = im / 255.0
+        im -= mean
+        im /= std
+        return im, im_info
+
+
+class Permute(object):
+    """permute image
+    Args:
+        to_bgr (bool): whether convert RGB to BGR
+        channel_first (bool): whether convert HWC to CHW
+    """
+
+    def __init__(self, ):
+        super(Permute, self).__init__()
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        im = im.transpose((2, 0, 1)).copy()
+        return im, im_info
+
+
+class PadStride(object):
+    """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
+    Args:
+        stride (bool): model with FPN need image shape % stride == 0
+    """
+
+    def __init__(self, stride=0):
+        self.coarsest_stride = stride
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        coarsest_stride = self.coarsest_stride
+        if coarsest_stride <= 0:
+            return im, im_info
+        im_c, im_h, im_w = im.shape
+        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
+        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
+        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
+        padding_im[:, :im_h, :im_w] = im
+        return padding_im, im_info
+
+
+class LetterBoxResize(object):
+    def __init__(self, target_size):
+        """
+        Resize image to target size, convert normalized xywh to pixel xyxy
+        format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
+        Args:
+            target_size (int|list): image target size.
+        """
+        super(LetterBoxResize, self).__init__()
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        self.target_size = target_size
+
+    def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
+        # letterbox: resize a rectangular image to a padded rectangular
+        shape = img.shape[:2]  # [height, width]
+        ratio_h = float(height) / shape[0]
+        ratio_w = float(width) / shape[1]
+        ratio = min(ratio_h, ratio_w)
+        new_shape = (round(shape[1] * ratio),
+                     round(shape[0] * ratio))  # [width, height]
+        padw = (width - new_shape[0]) / 2
+        padh = (height - new_shape[1]) / 2
+        top, bottom = round(padh - 0.1), round(padh + 0.1)
+        left, right = round(padw - 0.1), round(padw + 0.1)
+
+        img = cv2.resize(
+            img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
+        img = cv2.copyMakeBorder(
+            img, top, bottom, left, right, cv2.BORDER_CONSTANT,
+            value=color)  # padded rectangular
+        return img, ratio, padw, padh
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        assert len(self.target_size) == 2
+        assert self.target_size[0] > 0 and self.target_size[1] > 0
+        height, width = self.target_size
+        h, w = im.shape[:2]
+        im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
+
+        new_shape = [round(h * ratio), round(w * ratio)]
+        im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
+        im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
+        return im, im_info
+
+
+class Pad(object):
+    def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
+        """
+        Pad image to a specified size.
+        Args:
+            size (list[int]): image target size
+            fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
+        """
+        super(Pad, self).__init__()
+        if isinstance(size, int):
+            size = [size, size]
+        self.size = size
+        self.fill_value = fill_value
+
+    def __call__(self, im, im_info):
+        im_h, im_w = im.shape[:2]
+        h, w = self.size
+        if h == im_h and w == im_w:
+            im = im.astype(np.float32)
+            return im, im_info
+
+        canvas = np.ones((h, w, 3), dtype=np.float32)
+        canvas *= np.array(self.fill_value, dtype=np.float32)
+        canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
+        im = canvas
+        return im, im_info
+
+
+def rotate_point(pt, angle_rad):
+    """Rotate a point by an angle.
+
+    Args:
+        pt (list[float]): 2 dimensional point to be rotated
+        angle_rad (float): rotation angle by radian
+
+    Returns:
+        list[float]: Rotated point.
+    """
+    assert len(pt) == 2
+    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+    new_x = pt[0] * cs - pt[1] * sn
+    new_y = pt[0] * sn + pt[1] * cs
+    rotated_pt = [new_x, new_y]
+
+    return rotated_pt
+
+
+def _get_3rd_point(a, b):
+    """To calculate the affine matrix, three pairs of points are required. This
+    function is used to get the 3rd point, given 2D points a & b.
+
+    The 3rd point is defined by rotating vector `a - b` by 90 degrees
+    anticlockwise, using b as the rotation center.
+
+    Args:
+        a (np.ndarray): point(x,y)
+        b (np.ndarray): point(x,y)
+
+    Returns:
+        np.ndarray: The 3rd point.
+    """
+    assert len(a) == 2
+    assert len(b) == 2
+    direction = a - b
+    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
+
+    return third_pt
+
+
+def get_affine_transform(center,
+                         input_size,
+                         rot,
+                         output_size,
+                         shift=(0., 0.),
+                         inv=False):
+    """Get the affine transform matrix, given the center/scale/rot/output_size.
+
+    Args:
+        center (np.ndarray[2, ]): Center of the bounding box (x, y).
+        scale (np.ndarray[2, ]): Scale of the bounding box
+            wrt [width, height].
+        rot (float): Rotation angle (degree).
+        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
+        shift (0-100%): Shift translation ratio wrt the width/height.
+            Default (0., 0.).
+        inv (bool): Option to inverse the affine transform direction.
+            (inv=False: src->dst or inv=True: dst->src)
+
+    Returns:
+        np.ndarray: The transform matrix.
+    """
+    assert len(center) == 2
+    assert len(output_size) == 2
+    assert len(shift) == 2
+    if not isinstance(input_size, (np.ndarray, list)):
+        input_size = np.array([input_size, input_size], dtype=np.float32)
+    scale_tmp = input_size
+
+    shift = np.array(shift)
+    src_w = scale_tmp[0]
+    dst_w = output_size[0]
+    dst_h = output_size[1]
+
+    rot_rad = np.pi * rot / 180
+    src_dir = rotate_point([0., src_w * -0.5], rot_rad)
+    dst_dir = np.array([0., dst_w * -0.5])
+
+    src = np.zeros((3, 2), dtype=np.float32)
+    src[0, :] = center + scale_tmp * shift
+    src[1, :] = center + src_dir + scale_tmp * shift
+    src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+    dst = np.zeros((3, 2), dtype=np.float32)
+    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+    if inv:
+        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+    else:
+        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+    return trans
+
+
+class WarpAffine(object):
+    """Warp affine the image
+    """
+
+    def __init__(self,
+                 keep_res=False,
+                 pad=31,
+                 input_h=512,
+                 input_w=512,
+                 scale=0.4,
+                 shift=0.1):
+        self.keep_res = keep_res
+        self.pad = pad
+        self.input_h = input_h
+        self.input_w = input_w
+        self.scale = scale
+        self.shift = shift
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+
+        h, w = img.shape[:2]
+
+        if self.keep_res:
+            input_h = (h | self.pad) + 1
+            input_w = (w | self.pad) + 1
+            s = np.array([input_w, input_h], dtype=np.float32)
+            c = np.array([w // 2, h // 2], dtype=np.float32)
+
+        else:
+            s = max(h, w) * 1.0
+            input_h, input_w = self.input_h, self.input_w
+            c = np.array([w / 2., h / 2.], dtype=np.float32)
+
+        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
+        img = cv2.resize(img, (w, h))
+        inp = cv2.warpAffine(
+            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
+        return inp, im_info
+
+
+class Compose:
+    def __init__(self, transforms):
+        self.transforms = []
+        for op_info in transforms:
+            new_op_info = op_info.copy()
+            op_type = new_op_info.pop('type')
+            self.transforms.append(eval(op_type)(**new_op_info))
+
+    def __call__(self, img_path):
+        img, im_info = decode_image(img_path)
+        for t in self.transforms:
+            img, im_info = t(img, im_info)
+        inputs = copy.deepcopy(im_info)
+        inputs['image'] = img
+        return inputs
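# ---------------------------------------------------------------------------
# Editor's note (illustrative, not part of the patch): Compose instantiates
# one op per entry of the `Preprocess` list in infer_cfg.yml, keyed by its
# `type`, and decode_image seeds `im_shape`/`scale_factor` so each op can
# update them as it transforms the image. The op parameters and image path
# below are hypothetical examples.
from preprocess import Compose  # run from deploy/third_engine/onnx

preprocess_infos = [
    {'type': 'Resize', 'target_size': [320, 320], 'keep_ratio': False},
    {'type': 'NormalizeImage', 'mean': [0.485, 0.456, 0.406],
     'std': [0.229, 0.224, 0.225], 'is_scale': True},
    {'type': 'Permute'},
]
transforms = Compose(preprocess_infos)
inputs = transforms("demo/000000014439.jpg")
# inputs['image'] is a CHW float32 array; inputs['im_shape'] and
# inputs['scale_factor'] are fed to the model alongside the image tensor.
# ---------------------------------------------------------------------------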
diff --git a/ppdet/modeling/assigners/task_aligned_assigner.py b/ppdet/modeling/assigners/task_aligned_assigner.py
index 949e899192f65082c4cf422ae283e6a11834d3f6..5b3368e06814b4f3793e3f38c15d985cf4e8e6bc 100644
--- a/ppdet/modeling/assigners/task_aligned_assigner.py
+++ b/ppdet/modeling/assigners/task_aligned_assigner.py
@@ -93,7 +93,7 @@
             return assigned_labels, assigned_bboxes, assigned_scores
 
         # compute iou between gt and pred bbox, [B, n, L]
-        ious = iou_similarity(gt_bboxes, pred_bboxes)
+        ious = batch_iou_similarity(gt_bboxes, pred_bboxes)
         # gather pred bboxes class score
         pred_scores = pred_scores.transpose([0, 2, 1])
         batch_ind = paddle.arange(
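# ---------------------------------------------------------------------------
# Editor's note (illustrative, not part of the patch): per the comment in the
# hunk above, the renamed helper computes a batched IoU over xyxy boxes,
# gt [B, n, 4] x pred [B, L, 4] -> [B, n, L]. A NumPy sketch of those
# semantics (an assumption for illustration, not ppdet's implementation):
import numpy as np

def batch_iou_similarity(gt, pred, eps=1e-9):
    # intersection corners via broadcasting: [B, n, 1] against [B, 1, L]
    x1 = np.maximum(gt[..., :, None, 0], pred[..., None, :, 0])
    y1 = np.maximum(gt[..., :, None, 1], pred[..., None, :, 1])
    x2 = np.minimum(gt[..., :, None, 2], pred[..., None, :, 2])
    y2 = np.minimum(gt[..., :, None, 3], pred[..., None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_gt = (gt[..., 2] - gt[..., 0]) * (gt[..., 3] - gt[..., 1])
    area_pr = (pred[..., 2] - pred[..., 0]) * (pred[..., 3] - pred[..., 1])
    union = area_gt[..., :, None] + area_pr[..., None, :] - inter
    return inter / (union + eps)
# ---------------------------------------------------------------------------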