diff --git a/modules/image/instance_segmentation/solov2/README.md b/modules/image/instance_segmentation/solov2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..60c8d811758d4830fd381e45a6b1476dacd892e3
--- /dev/null
+++ b/modules/image/instance_segmentation/solov2/README.md
@@ -0,0 +1,119 @@
+## Model Overview
+
+solov2 is a fast instance segmentation model based on "SOLOv2: Dynamic, Faster and Stronger". It builds on SOLOv1 and improves both mask quality and runtime efficiency, performing well on instance segmentation tasks. Unlike semantic segmentation, instance segmentation has to separate different individuals of the same object class in an image. An example of solov2 instance segmentation output is shown below:
+
+![solov2 instance segmentation example](example.png)
+
+
+## API
+
+```python
+def predict(self,
+            image: Union[str, np.ndarray],
+            threshold: float = 0.5,
+            visualization: bool = False,
+            save_dir: str = 'solov2_result'):
+```
+
+Prediction API for instance segmentation.
+
+**Parameters**
+
+* image (Union\[str, np.ndarray\]): image path or image data; ndarray.shape is \[H, W, C\] in BGR format;
+* threshold (float): detections whose predicted score falls below this threshold are filtered out, default 0.5;
+* visualization (bool): whether to save the visualized image;
+* save_dir (str): directory the image is saved to, default "solov2_result".
+
+**Returns**
+
+* res (dict): recognition results whose keys 'segm', 'label' and 'score' hold:
+    * segm (np.ndarray): instance segmentation masks with values 0 or 1, where 0 is background and 1 is instance;
+    * label (list): class ids of the segmented instances;
+    * score (list): class scores of the segmented instances.
+
+
+## Code Example
+
+```python
+import cv2
+import paddlehub as hub
+
+img = cv2.imread('/PATH/TO/IMAGE')
+model = hub.Module(name='solov2', use_gpu=False)
+output = model.predict(image=img, visualization=False)
+```
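+
+The returned fields can be consumed directly. A minimal sketch of filtering and measuring the predicted instances (assuming at least one instance was detected and the shapes described in the Returns section above):
+
+```python
+import numpy as np
+
+# Keep only instances whose score passes the confidence threshold.
+keep = np.asarray(output['score']) > 0.5
+print('instances kept:', int(keep.sum()))
+
+# Each mask is an H x W array of 0/1; its sum is the covered area in pixels.
+for mask, label in zip(np.asarray(output['segm'])[keep], np.asarray(output['label'])[keep]):
+    print('label id:', int(label), 'area:', int(mask.sum()))
+```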
+
+## Service Deployment
+
+PaddleHub Serving can deploy an online instance segmentation service.
+
+## Step 1: Start the PaddleHub Serving service
+
+Run the start command:
+
+```shell
+$ hub serving start -m solov2
+```
+
+The default port is 8866.
+
+**NOTE:** To predict on a GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
+
+## Step 2: Send a prediction request
+
+With the server up, the few lines of code below send a prediction request and fetch the result:
+
+```python
+import requests
+import json
+import cv2
+import base64
+
+import numpy as np
+
+def cv2_to_base64(image):
+    data = cv2.imencode('.jpg', image)[1]
+    return base64.b64encode(data.tobytes()).decode('utf8')
+
+def base64_to_cv2(b64str):
+    data = base64.b64decode(b64str.encode('utf8'))
+    data = np.frombuffer(data, np.uint8)
+    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+    return data
+
+# Send the HTTP request.
+org_im = cv2.imread('/PATH/TO/IMAGE')
+h, w, c = org_im.shape
+data = {'images': [cv2_to_base64(org_im)]}
+headers = {"Content-type": "application/json"}
+url = "http://127.0.0.1:8866/predict/solov2"
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+
+seg = base64.b64decode(r.json()["results"]['segm'].encode('utf8'))
+seg = np.frombuffer(seg, dtype=np.int32).reshape((-1, h, w))
+
+label = base64.b64decode(r.json()["results"]['label'].encode('utf8'))
+label = np.frombuffer(label, dtype=np.int64)
+
+score = base64.b64decode(r.json()["results"]['score'].encode('utf8'))
+score = np.frombuffer(score, dtype=np.float32)
+
+print('seg', seg)
+print('label', label)
+print('score', score)
+```
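+
+The decoded arrays can be overlaid on the original image. A minimal sketch, reusing `org_im`, `seg` and `score` from the snippet above and assuming the dtypes shown there match what the service returned:
+
+```python
+overlay = org_im.copy()
+for mask, s in zip(seg, score):
+    if s < 0.5:
+        continue
+    # Tint the pixels of each kept instance mask red (BGR order).
+    overlay[mask > 0] = (0.5 * overlay[mask > 0] + 0.5 * np.array([0, 0, 255])).astype(np.uint8)
+cv2.imwrite('solov2_serving_result.png', overlay)
+```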
+
+### Source Code
+
+https://github.com/PaddlePaddle/PaddleDetection
+
+
+### Dependencies
+
+paddlepaddle >= 2.0.0
+
+paddlehub >= 2.0.0
+
diff --git a/modules/image/instance_segmentation/solov2/data_feed.py b/modules/image/instance_segmentation/solov2/data_feed.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e2a28b5dd171cb427428376de85c293d0590304
--- /dev/null
+++ b/modules/image/instance_segmentation/solov2/data_feed.py
@@ -0,0 +1,335 @@
+import os
+import base64
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw
+import paddle.fluid as fluid
+
+
+def create_inputs(im, im_info):
+    """Generate input for different model types.
+    Args:
+        im (np.ndarray): image (np.ndarray)
+        im_info (dict): info of image
+    Returns:
+        inputs (dict): input of model
+    """
+    inputs = {}
+    inputs['image'] = im
+    origin_shape = list(im_info['origin_shape'])
+    resize_shape = list(im_info['resize_shape'])
+    pad_shape = list(im_info['pad_shape']) if im_info['pad_shape'] is not None else list(im_info['resize_shape'])
+    scale_x, scale_y = im_info['scale']
+    scale = scale_x
+    im_info = np.array([resize_shape + [scale]]).astype('float32')
+    inputs['im_info'] = im_info
+    return inputs
+
+
+def visualize_box_mask(im, results, labels=None, mask_resolution=14, threshold=0.5):
+    """
+    Args:
+        im (str/np.ndarray): path of image / np.ndarray read by cv2
+        results (dict): includes 'boxes': np.ndarray: shape:[N,6], N: number of boxes,
+                        matrix element: [class, score, x_min, y_min, x_max, y_max]
+                        MaskRCNN's results include 'masks': np.ndarray:
+                        shape: [N, class_num, mask_resolution, mask_resolution]
+        labels (list): labels: ['class1', ..., 'classn']
+        mask_resolution (int): shape of a mask is [mask_resolution, mask_resolution]
+        threshold (float): threshold of score
+    Returns:
+        im (PIL.Image.Image): visualized image
+    """
+    if not labels:
+        # Default to the 80 COCO categories plus background.
+        labels = ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+                  'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter',
+                  'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+                  'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+                  'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
+                  'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+                  'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
+                  'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+                  'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+                  'hair drier', 'toothbrush']
+    if isinstance(im, str):
+        im = Image.open(im).convert('RGB')
+    else:
+        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        im = Image.fromarray(im)
+    if 'masks' in results and 'boxes' in results:
+        im = draw_mask(
+            im,
+            results['boxes'],
+            results['masks'],
+            labels,
+            resolution=mask_resolution)
+    if 'boxes' in results:
+        im = draw_box(im, results['boxes'], labels)
+    if 'segm' in results:
+        im = draw_segm(
+            im,
+            results['segm'],
+            results['label'],
+            results['score'],
+            labels,
+            threshold=threshold)
+    return im
+
+
+def get_color_map_list(num_classes):
+    """
+    Args:
+        num_classes (int): number of classes
+    Returns:
+        color_map (list): RGB color list
+    """
+    color_map = num_classes * [0, 0, 0]
+    for i in range(0, num_classes):
+        j = 0
+        lab = i
+        while lab:
+            color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+            color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+            color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+            j += 1
+            lab >>= 3
+    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
+    return color_map
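+
+# Note: get_color_map_list implements the PASCAL-VOC-style palette:
+# successive 3-bit groups of the class id are folded into progressively
+# lower bits of R, G and B, so e.g. id 1 -> [128, 0, 0],
+# id 2 -> [0, 128, 0], id 8 -> [64, 0, 0].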
+
+
+def expand_boxes(boxes, scale=0.0):
+    """
+    Args:
+        boxes (np.ndarray): shape:[N,4], N: number of boxes,
+                            matrix element: [x_min, y_min, x_max, y_max]
+        scale (float): scale of boxes
+    Returns:
+        boxes_exp (np.ndarray): expanded boxes
+    """
+    w_half = (boxes[:, 2] - boxes[:, 0]) * .5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * .5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * .5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * .5
+    w_half *= scale
+    h_half *= scale
+    boxes_exp = np.zeros(boxes.shape)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+    return boxes_exp
+
+
+def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
+    """
+    Args:
+        im (PIL.Image.Image): PIL image
+        np_boxes (np.ndarray): shape:[N,6], N: number of boxes,
+                               matrix element: [class, score, x_min, y_min, x_max, y_max]
+        np_masks (np.ndarray): shape:[N, class_num, resolution, resolution]
+        labels (list): labels: ['class1', ..., 'classn']
+        resolution (int): shape of a mask is [resolution, resolution]
+        threshold (float): threshold of mask
+    Returns:
+        im (PIL.Image.Image): visualized image
+    """
+    color_list = get_color_map_list(len(labels))
+    scale = (resolution + 2.0) / resolution
+    im_w, im_h = im.size
+    w_ratio = 0.4
+    alpha = 0.7
+    im = np.array(im).astype('float32')
+    rects = np_boxes[:, 2:]
+    expand_rects = expand_boxes(rects, scale)
+    expand_rects = expand_rects.astype(np.int32)
+    clsid_scores = np_boxes[:, 0:2]
+    padded_mask = np.zeros((resolution + 2, resolution + 2), dtype=np.float32)
+    clsid2color = {}
+    for idx in range(len(np_boxes)):
+        clsid, score = clsid_scores[idx].tolist()
+        clsid = int(clsid)
+        xmin, ymin, xmax, ymax = expand_rects[idx].tolist()
+        w = xmax - xmin + 1
+        h = ymax - ymin + 1
+        w = np.maximum(w, 1)
+        h = np.maximum(h, 1)
+        padded_mask[1:-1, 1:-1] = np_masks[idx, int(clsid), :, :]
+        resized_mask = cv2.resize(padded_mask, (w, h))
+        resized_mask = np.array(resized_mask > threshold, dtype=np.uint8)
+        x0 = min(max(xmin, 0), im_w)
+        x1 = min(max(xmax + 1, 0), im_w)
+        y0 = min(max(ymin, 0), im_h)
+        y1 = min(max(ymax + 1, 0), im_h)
+        im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
+        im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (x0 - xmin):(x1 - xmin)]
+        if clsid not in clsid2color:
+            clsid2color[clsid] = color_list[clsid]
+        color_mask = clsid2color[clsid]
+        for c in range(3):
+            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+        idx = np.nonzero(im_mask)
+        color_mask = np.array(color_mask)
+        im[idx[0], idx[1], :] *= 1.0 - alpha
+        im[idx[0], idx[1], :] += alpha * color_mask
+    return Image.fromarray(im.astype('uint8'))
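+
+# draw_mask follows the Mask R-CNN paste-back recipe: each low-resolution
+# mask is zero-padded by one pixel, resized to its (slightly expanded) box,
+# binarized at `threshold`, pasted into a full-size canvas with clipping,
+# and alpha-blended into the image.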
+ """ + mask_color_id = 0 + w_ratio = .4 + color_list = get_color_map_list(len(labels)) + im = np.array(im).astype('float32') + clsid2color = {} + np_segms = np_segms.astype(np.uint8) + + for i in range(np_segms.shape[0]): + mask, score, clsid = np_segms[i], np_score[i], np_label[i] + 1 + if score < threshold: + continue + if clsid not in clsid2color: + clsid2color[clsid] = color_list[clsid] + color_mask = clsid2color[clsid] + for c in range(3): + color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255 + idx = np.nonzero(mask) + color_mask = np.array(color_mask) + im[idx[0], idx[1], :] *= 1.0 - alpha + im[idx[0], idx[1], :] += alpha * color_mask + sum_x = np.sum(mask, axis=0) + x = np.where(sum_x > 0.5)[0] + sum_y = np.sum(mask, axis=1) + y = np.where(sum_y > 0.5)[0] + x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1] + cv2.rectangle(im, (x0, y0), (x1, y1), + tuple(color_mask.astype('int32').tolist()), 1) + bbox_text = '%s %.2f' % (labels[clsid], score) + t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0] + cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3), + tuple(color_mask.astype('int32').tolist()), -1) + cv2.putText( + im, + bbox_text, (x0, y0 - 2), + cv2.FONT_HERSHEY_SIMPLEX, + 0.3, (0, 0, 0), + 1, + lineType=cv2.LINE_AA) + + return Image.fromarray(im.astype('uint8')) + + +def load_predictor(model_dir, + run_mode='fluid', + batch_size=1, + use_gpu=False, + min_subgraph_size=3): + """set AnalysisConfig, generate AnalysisPredictor + Args: + model_dir (str): root path of __model__ and __params__ + use_gpu (bool): whether use gpu + Returns: + predictor (PaddlePredictor): AnalysisPredictor + Raises: + ValueError: predict by TensorRT need use_gpu == True. + """ + if not use_gpu and not run_mode == 'fluid': + raise ValueError( + "Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}" + .format(run_mode, use_gpu)) + if run_mode == 'trt_int8': + raise ValueError("TensorRT int8 mode is not supported now, " + "please use trt_fp32 or trt_fp16 instead.") + precision_map = { + 'trt_int8': fluid.core.AnalysisConfig.Precision.Int8, + 'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32, + 'trt_fp16': fluid.core.AnalysisConfig.Precision.Half + } + config = fluid.core.AnalysisConfig( + os.path.join(model_dir, '__model__'), + os.path.join(model_dir, '__params__')) + if use_gpu: + # initial GPU memory(M), device ID + config.enable_use_gpu(100, 0) + # optimize graph and fuse op + config.switch_ir_optim(True) + else: + config.disable_gpu() + + if run_mode in precision_map.keys(): + config.enable_tensorrt_engine( + workspace_size=1 << 10, + max_batch_size=batch_size, + min_subgraph_size=min_subgraph_size, + precision_mode=precision_map[run_mode], + use_static=False, + use_calib_mode=False) + + # disable print log when predict + config.disable_glog_info() + # enable shared memory + config.enable_memory_optim() + # disable feed, fetch OP, needed by zero_copy_run + config.switch_use_feed_fetch_ops(False) + predictor = fluid.core.create_paddle_predictor(config) + return predictor + + +def cv2_to_base64(image: np.ndarray): + data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(data.tostring()).decode('utf8') + + +def base64_to_cv2(b64str: str): + data = base64.b64decode(b64str.encode('utf8')) + data = np.fromstring(data, np.uint8) + data = cv2.imdecode(data, cv2.IMREAD_COLOR) + return data diff --git a/modules/image/instance_segmentation/solov2/example.png b/modules/image/instance_segmentation/solov2/example.png new file mode 100644 index 
diff --git a/modules/image/instance_segmentation/solov2/example.png b/modules/image/instance_segmentation/solov2/example.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ece0a2df78452484829ce4e0fafd02341c5a3e5
Binary files /dev/null and b/modules/image/instance_segmentation/solov2/example.png differ
diff --git a/modules/image/instance_segmentation/solov2/module.py b/modules/image/instance_segmentation/solov2/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..006ea94f9f350fa69987ef8f4d4661c7c9f50a5a
--- /dev/null
+++ b/modules/image/instance_segmentation/solov2/module.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+import base64
+from functools import reduce
+from typing import Union
+
+import cv2
+import numpy as np
+from paddlehub.module.module import moduleinfo, serving
+
+import solov2.processor as P
+import solov2.data_feed as D
+
+
+class Detector(object):
+    """
+    Args:
+        min_subgraph_size (int): minimum subgraph size for the TensorRT engine.
+        use_gpu (bool): whether to use gpu
+        threshold (float): threshold to reserve the result for output.
+    """
+
+    def __init__(self,
+                 min_subgraph_size: int = 60,
+                 use_gpu=False,
+                 threshold: float = 0.5):
+
+        model_dir = os.path.join(self.directory, 'solov2_r50_fpn_1x')
+        self.predictor = D.load_predictor(
+            model_dir,
+            min_subgraph_size=min_subgraph_size,
+            use_gpu=use_gpu)
+        self.compose = [P.Resize(max_size=1333),
+                        P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+                        P.Permute(),
+                        P.PadStride(stride=32)]
+
+    def transform(self, im: Union[str, np.ndarray]):
+        im, im_info = P.preprocess(im, self.compose)
+        inputs = D.create_inputs(im, im_info)
+        return inputs, im_info
+
+    def postprocess(self, np_boxes: np.ndarray, np_masks: np.ndarray, im_info: dict, threshold: float = 0.5):
+        # postprocess output of predictor
+        results = {}
+        expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
+        np_boxes = np_boxes[expect_boxes, :]
+        for box in np_boxes:
+            print('class_id:{:d}, confidence:{:.4f},'
+                  'left_top:[{:.2f},{:.2f}],'
+                  ' right_bottom:[{:.2f},{:.2f}]'.format(
+                      int(box[0]), box[1], box[2], box[3], box[4], box[5]))
+        results['boxes'] = np_boxes
+        if np_masks is not None:
+            np_masks = np_masks[expect_boxes, :, :, :]
+            results['masks'] = np_masks
+        return results
+
+    def predict(self,
+                image: Union[str, np.ndarray],
+                threshold: float = 0.5):
+        '''
+        Args:
+            image (str/np.ndarray): path of image / np.ndarray read by cv2
+            threshold (float): threshold of the predicted box's score
+        Returns:
+            results (dict): includes 'boxes': np.ndarray: shape:[N,6], N: number of boxes,
+                            matrix element: [class, score, x_min, y_min, x_max, y_max]
+                            MaskRCNN's results include 'masks': np.ndarray:
+                            shape: [N, class_num, mask_resolution, mask_resolution]
+        '''
+        inputs, im_info = self.transform(image)
+        np_boxes, np_masks = None, None
+
+        input_names = self.predictor.get_input_names()
+        for i in range(len(input_names)):
+            input_tensor = self.predictor.get_input_tensor(input_names[i])
+            input_tensor.copy_from_cpu(inputs[input_names[i]])
+
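+        # With feed/fetch ops disabled in load_predictor
+        # (switch_use_feed_fetch_ops(False)), inference runs in zero-copy
+        # mode: inputs were copied into the input tensors above, and outputs
+        # are read back from the output tensors below.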
+        self.predictor.zero_copy_run()
+        output_names = self.predictor.get_output_names()
+        boxes_tensor = self.predictor.get_output_tensor(output_names[0])
+        np_boxes = boxes_tensor.copy_to_cpu()
+        results = {}
+        if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
+            print('[WARNING] No object detected.')
+            results = {'boxes': np.array([])}
+        else:
+            results = self.postprocess(np_boxes, np_masks, im_info, threshold=threshold)
+        return results
+
+
+@moduleinfo(
+    name="solov2",
+    type="CV/instance_segmentation",
+    author="paddlepaddle",
+    author_email="",
+    summary="solov2 is a detection model, this module is trained with COCO dataset.",
+    version="1.0.0")
+class DetectorSOLOv2(Detector):
+    """
+    Args:
+        use_gpu (bool): whether to use gpu
+        threshold (float): threshold to reserve the result for output.
+    """
+    def __init__(self,
+                 use_gpu: bool = False,
+                 threshold: float = 0.5):
+        super(DetectorSOLOv2, self).__init__(
+            use_gpu=use_gpu,
+            threshold=threshold)
+
+    def predict(self,
+                image: Union[str, np.ndarray],
+                threshold: float = 0.5,
+                visualization: bool = False,
+                save_dir: str = 'solov2_result'):
+        '''
+        Args:
+            image (str/np.ndarray): path of image / np.ndarray read by cv2
+            threshold (float): threshold of the predicted box's score
+            visualization (bool): whether to save the visualization result.
+            save_dir (str): save path.
+        '''
+
+        inputs, im_info = self.transform(image)
+        np_label, np_score, np_segms = None, None, None
+
+        input_names = self.predictor.get_input_names()
+        for i in range(len(input_names)):
+            input_tensor = self.predictor.get_input_tensor(input_names[i])
+            input_tensor.copy_from_cpu(inputs[input_names[i]])
+
+        self.predictor.zero_copy_run()
+        output_names = self.predictor.get_output_names()
+        np_label = self.predictor.get_output_tensor(output_names[
+            0]).copy_to_cpu()
+        np_score = self.predictor.get_output_tensor(output_names[
+            1]).copy_to_cpu()
+        np_segms = self.predictor.get_output_tensor(output_names[
+            2]).copy_to_cpu()
+        output = dict(segm=np_segms, label=np_label, score=np_score)
+
+        if visualization:
+            if not os.path.exists(save_dir):
+                os.makedirs(save_dir)
+            image = D.visualize_box_mask(im=image, results=output)
+            name = str(time.time()) + '.png'
+            save_path = os.path.join(save_dir, name)
+            image.save(save_path)
+        return output
+
+    @serving
+    def serving_method(self, images: list, **kwargs):
+        """
+        Run as a service.
+        """
+        # Only the first image in the request batch is processed.
+        images_decode = D.base64_to_cv2(images[0])
+        results = self.predict(image=images_decode, **kwargs)
+        final = {}
+        final['segm'] = base64.b64encode(results['segm']).decode('utf8')
+        final['label'] = base64.b64encode(results['label']).decode('utf8')
+        final['score'] = base64.b64encode(results['score']).decode('utf8')
+        return final
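+
+
+# Note: serving_method base64-encodes the raw ndarray buffers, so a client
+# has to know the dtypes to decode them. A sketch, with the dtypes used in
+# this module's README:
+#   segm  = np.frombuffer(base64.b64decode(final['segm']), dtype=np.int32)
+#   label = np.frombuffer(base64.b64decode(final['label']), dtype=np.int64)
+#   score = np.frombuffer(base64.b64decode(final['score']), dtype=np.float32)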
diff --git a/modules/image/instance_segmentation/solov2/processor.py b/modules/image/instance_segmentation/solov2/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6a92d9f567fae260dc068eca072a70f21d844da
--- /dev/null
+++ b/modules/image/instance_segmentation/solov2/processor.py
@@ -0,0 +1,243 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from PIL import Image
+import cv2
+import numpy as np
+
+
+def decode_image(im_file, im_info):
+    """Read an RGB image.
+    Args:
+        im_file (str/np.ndarray): path of image / np.ndarray read by cv2
+        im_info (dict): info of image
+    Returns:
+        im (np.ndarray): processed image (np.ndarray)
+        im_info (dict): info of processed image
+    """
+    if isinstance(im_file, str):
+        with open(im_file, 'rb') as f:
+            im_read = f.read()
+        data = np.frombuffer(im_read, dtype='uint8')
+        im = cv2.imdecode(data, 1)  # decoded in BGR mode, but RGB is needed
+        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        im_info['origin_shape'] = im.shape[:2]
+        im_info['resize_shape'] = im.shape[:2]
+    else:
+        im = im_file
+        # im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        im_info['origin_shape'] = im.shape[:2]
+        im_info['resize_shape'] = im.shape[:2]
+    return im, im_info
+
+
+class Resize(object):
+    """Resize image by target_size and max_size.
+    Args:
+        arch (str): model type
+        target_size (int): the target size of image
+        max_size (int): the max size of image
+        use_cv2 (bool): whether to use cv2
+        image_shape (list): input shape of model
+        interp (int): method of resize
+    """
+
+    def __init__(self,
+                 target_size=800,
+                 max_size=1333,
+                 use_cv2=True,
+                 image_shape=None,
+                 interp=cv2.INTER_LINEAR,
+                 resize_box=False):
+        self.target_size = target_size
+        self.max_size = max_size
+        self.image_shape = image_shape
+        self.use_cv2 = use_cv2
+        self.interp = interp
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        im_channel = im.shape[2]
+        im_scale_x, im_scale_y = self.generate_scale(im)
+        if self.use_cv2:
+            im = cv2.resize(
+                im,
+                None,
+                None,
+                fx=im_scale_x,
+                fy=im_scale_y,
+                interpolation=self.interp)
+        else:
+            resize_w = int(im_scale_x * float(im.shape[1]))
+            resize_h = int(im_scale_y * float(im.shape[0]))
+            if self.max_size != 0:
+                raise TypeError(
+                    'If you set max_size to cap the maximum size of image, '
+                    'please set use_cv2 to True to resize the image.')
+            im = im.astype('uint8')
+            im = Image.fromarray(im)
+            im = im.resize((int(resize_w), int(resize_h)), self.interp)
+            im = np.array(im)
+
+        # pad im when image_shape is fixed by infer_cfg.yml
+        if self.max_size != 0 and self.image_shape is not None:
+            padding_im = np.zeros(
+                (self.max_size, self.max_size, im_channel), dtype=np.float32)
+            im_h, im_w = im.shape[:2]
+            padding_im[:im_h, :im_w, :] = im
+            im = padding_im
+
+        im_info['scale'] = [im_scale_x, im_scale_y]
+        im_info['resize_shape'] = im.shape[:2]
+        return im, im_info
+
+    def generate_scale(self, im):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+        Returns:
+            im_scale_x: the resize ratio of X
+            im_scale_y: the resize ratio of Y
+        """
+        origin_shape = im.shape[:2]
+        im_c = im.shape[2]
+        if self.max_size != 0:
+            im_size_min = np.min(origin_shape[0:2])
+            im_size_max = np.max(origin_shape[0:2])
+            im_scale = float(self.target_size) / float(im_size_min)
+            if np.round(im_scale * im_size_max) > self.max_size:
+                im_scale = float(self.max_size) / float(im_size_max)
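+            # Keep-aspect-ratio scaling: e.g. a 600x900 image with
+            # target_size=800 and max_size=1333 gives im_scale = 800/600,
+            # and since 800/600 * 900 = 1200 <= 1333 the cap is not hit;
+            # both axes share the same scale.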
+            im_scale_x = im_scale
+            im_scale_y = im_scale
+        else:
+            im_scale_x = float(self.target_size) / float(origin_shape[1])
+            im_scale_y = float(self.target_size) / float(origin_shape[0])
+        return im_scale_x, im_scale_y
+
+
+class Normalize(object):
+    """Normalize the image.
+    Args:
+        mean (list): im - mean
+        std (list): im / std
+        is_scale (bool): whether to scale im to [0, 1] first (im / 255)
+        is_channel_first (bool): if True the image shape is CHW, else HWC
+    """
+
+    def __init__(self, mean, std, is_scale=True, is_channel_first=False):
+        self.mean = mean
+        self.std = std
+        self.is_scale = is_scale
+        self.is_channel_first = is_channel_first
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        im = im.astype(np.float32, copy=False)
+        if self.is_channel_first:
+            mean = np.array(self.mean)[:, np.newaxis, np.newaxis]
+            std = np.array(self.std)[:, np.newaxis, np.newaxis]
+        else:
+            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+            std = np.array(self.std)[np.newaxis, np.newaxis, :]
+        if self.is_scale:
+            im = im / 255.0
+        im -= mean
+        im /= std
+        return im, im_info
+
+
+class Permute(object):
+    """Permute the image.
+    Args:
+        to_bgr (bool): whether to convert RGB to BGR
+        channel_first (bool): whether to convert HWC to CHW
+    """
+
+    def __init__(self, to_bgr=False, channel_first=True):
+        self.to_bgr = to_bgr
+        self.channel_first = channel_first
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        if self.channel_first:
+            im = im.transpose((2, 0, 1)).copy()
+        if self.to_bgr:
+            im = im[[2, 1, 0], :, :]
+        return im, im_info
+
+
+class PadStride(object):
+    """Pad the image for models with FPN.
+    Args:
+        stride (int): models with FPN need image shape % stride == 0
+    """
+
+    def __init__(self, stride=0):
+        self.coarsest_stride = stride
+
+    def __call__(self, im, im_info):
+        """
+        Args:
+            im (np.ndarray): image (np.ndarray)
+            im_info (dict): info of image
+        Returns:
+            im (np.ndarray): processed image (np.ndarray)
+            im_info (dict): info of processed image
+        """
+        coarsest_stride = self.coarsest_stride
+        if coarsest_stride == 0:
+            return im, im_info
+        im_c, im_h, im_w = im.shape
+        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
+        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
+        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
+        padding_im[:, :im_h, :im_w] = im
+        im_info['pad_shape'] = padding_im.shape[1:]
+        return padding_im, im_info
+
+
+def preprocess(im, preprocess_ops):
+    # Run the image through each preprocess op in order.
+    im_info = {
+        'scale': [1., 1.],
+        'origin_shape': None,
+        'resize_shape': None,
+        'pad_shape': None,
+    }
+    im, im_info = decode_image(im, im_info)
+    for operator in preprocess_ops:
+        im, im_info = operator(im, im_info)
+    im = np.array((im, )).astype('float32')
+    return im, im_info
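+
+
+if __name__ == '__main__':
+    # Minimal smoke test of the preprocessing pipeline (a sketch: assumes a
+    # local test image path; the op chain mirrors the one built in module.py).
+    ops = [
+        Resize(max_size=1333),
+        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        Permute(),
+        PadStride(stride=32),
+    ]
+    im, im_info = preprocess('/PATH/TO/IMAGE', ops)
+    # After preprocessing, im has shape [1, 3, H, W] with H and W padded up
+    # to multiples of 32; im_info records the scale and shapes used.
+    print(im.shape, im_info['scale'], im_info['pad_shape'])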