diff --git a/modules/image/instance_segmentation/solov2/README.md b/modules/image/instance_segmentation/solov2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..60c8d811758d4830fd381e45a6b1476dacd892e3
--- /dev/null
+++ b/modules/image/instance_segmentation/solov2/README.md
@@ -0,0 +1,119 @@
+## Module Overview
+
+solov2 is a fast instance segmentation model implemented after the paper 'SOLOv2: Dynamic, Faster and Stronger'. Building on SOLOv1, it improves both mask prediction quality and runtime efficiency, and performs strongly on instance segmentation tasks. Unlike semantic segmentation, instance segmentation must separate the different individuals of the same object class in an image. A sample solov2 result:
+
+
+
+![](example.png)
+
+
+
+## API
+
+```python
+def predict(self,
+ image: Union[str, np.ndarray],
+ threshold: float = 0.5,
+ visualization: bool = False,
+ save_dir: str = 'solov2_result'):
+```
+
+Prediction API for instance segmentation.
+
+**Parameters**
+
+* image (Union\[str, np.ndarray\]): image path or image data; ndarray.shape is \[H, W, C\] in BGR format;
+* threshold (float): boxes whose predicted score is below this threshold are filtered out; default 0.5;
+* visualization (bool): whether to save the visualized image;
+* save_dir (str): directory for saving images, default "solov2_result".
+
+**Returns**
+
+* res (dict): prediction result with keys 'segm', 'label' and 'score':
+ * segm (np.ndarray): instance segmentation masks with values 0 or 1, where 0 is background and 1 is the instance;
+ * label (list): class ids of the predicted instances;
+ * score (list): confidence scores of the predicted instances.
+
+
+## Code Example
+
+```python
+import cv2
+import paddlehub as hub
+
+img = cv2.imread('/PATH/TO/IMAGE')
+model = hub.Module(name='solov2', use_gpu=False)
+output = model.predict(image=img, visualization=False)
+```
+
+## Serving Deployment
+
+PaddleHub Serving can deploy an online instance segmentation service.
+
+## Step 1: Start PaddleHub Serving
+
+Run the start command:
+
+```shell
+$ hub serving start -m solov2
+```
+
+The default port is 8866.
+
+**NOTE:** To predict on GPU, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it does not need to be set.
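+
+For example, to serve on the first GPU (a minimal sketch; the device id depends on your machine):
+
+```shell
+$ export CUDA_VISIBLE_DEVICES=0
+$ hub serving start -m solov2
+```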
+
+## Step 2: Send a prediction request
+
+With the service configured, the few lines of code below send a prediction request and fetch the result:
+
+```python
+import requests
+import json
+import cv2
+import base64
+
+import numpy as np
+
+def cv2_to_base64(image):
+ data = cv2.imencode('.jpg', image)[1]
+ return base64.b64encode(data.tobytes()).decode('utf8')
+
+def base64_to_cv2(b64str):
+ data = base64.b64decode(b64str.encode('utf8'))
+ data = np.frombuffer(data, np.uint8)
+ data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+ return data
+
+# send the HTTP request
+
+org_im = cv2.imread('/PATH/TO/IMAGE')
+h, w, c = org_im.shape
+data = {'images':[cv2_to_base64(org_im)]}
+headers = {"Content-type": "application/json"}
+url = "http://127.0.0.1:8866/predict/solov2"
+r = requests.post(url=url, headers=headers, data=json.dumps(data))
+
+seg = base64.b64decode(r.json()["results"]['segm'].encode('utf8'))
+seg = np.frombuffer(seg, dtype=np.int32).reshape((-1, h, w))
+
+label = base64.b64decode(r.json()["results"]['label'].encode('utf8'))
+label = np.frombuffer(label, dtype=np.int64)
+
+score = base64.b64decode(r.json()["results"]['score'].encode('utf8'))
+score = np.frombuffer(score, dtype=np.float32)
+
+print('seg', seg)
+print('label', label)
+print('score', score)
+```
+
+### Source Code
+
+https://github.com/PaddlePaddle/PaddleDetection
+
+
+### Dependencies
+
+paddlepaddle >= 2.0.0
+
+paddlehub >= 2.0.0
+
diff --git a/modules/image/instance_segmentation/solov2/data_feed.py b/modules/image/instance_segmentation/solov2/data_feed.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e2a28b5dd171cb427428376de85c293d0590304
--- /dev/null
+++ b/modules/image/instance_segmentation/solov2/data_feed.py
@@ -0,0 +1,335 @@
+import os
+import base64
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw
+import paddle.fluid as fluid
+
+
+def create_inputs(im, im_info):
+ """generate input for different model type
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ inputs (dict): input of model
+ """
+ inputs = {}
+ inputs['image'] = im
+ resize_shape = list(im_info['resize_shape'])
+ scale_x, scale_y = im_info['scale']
+ # the im_info tensor holds [resized_h, resized_w, scale]
+ inputs['im_info'] = np.array([resize_shape + [scale_x]]).astype('float32')
+ return inputs
+
+
+def visualize_box_mask(im, results, labels=None, mask_resolution=14, threshold=0.5):
+ """
+ Args:
+ im (str/np.ndarray): path of image/np.ndarray read by cv2
+ results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
+ matrix element:[class, score, x_min, y_min, x_max, y_max]
+ MaskRCNN's results include 'masks': np.ndarray:
+ shape:[N, class_num, mask_resolution, mask_resolution]
+ labels (list): labels:['class1', ..., 'classn']
+ mask_resolution (int): shape of a mask is:[mask_resolution, mask_resolution]
+ threshold (float): Threshold of score.
+ Returns:
+ im (PIL.Image.Image): visualized image
+ """
+ if not labels:
+ labels = ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter',
+ 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+ 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
+ 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
+ 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
+ 'hair drier', 'toothbrush']
+ if isinstance(im, str):
+ im = Image.open(im).convert('RGB')
+ else:
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+ im = Image.fromarray(im)
+ if 'masks' in results and 'boxes' in results:
+ im = draw_mask(
+ im,
+ results['boxes'],
+ results['masks'],
+ labels,
+ resolution=mask_resolution)
+ if 'boxes' in results:
+ im = draw_box(im, results['boxes'], labels)
+ if 'segm' in results:
+ im = draw_segm(
+ im,
+ results['segm'],
+ results['label'],
+ results['score'],
+ labels,
+ threshold=threshold)
+ return im
+
+
+def get_color_map_list(num_classes):
+ """
+ Args:
+ num_classes (int): number of class
+ Returns:
+ color_map (list): RGB color list
+ """
+ color_map = num_classes * [0, 0, 0]
+ for i in range(0, num_classes):
+ j = 0
+ lab = i
+ while lab:
+ color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+ color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+ color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+ j += 1
+ lab >>= 3
+ color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
+ return color_map
+
+
+def expand_boxes(boxes, scale=0.0):
+ """
+ Args:
+ boxes (np.ndarray): shape:[N,4], N:number of box,
+ matrix element:[x_min, y_min, x_max, y_max]
+ scale (float): scale of boxes
+ Returns:
+ boxes_exp (np.ndarray): expanded boxes
+ """
+ w_half = (boxes[:, 2] - boxes[:, 0]) * .5
+ h_half = (boxes[:, 3] - boxes[:, 1]) * .5
+ x_c = (boxes[:, 2] + boxes[:, 0]) * .5
+ y_c = (boxes[:, 3] + boxes[:, 1]) * .5
+ w_half *= scale
+ h_half *= scale
+ boxes_exp = np.zeros(boxes.shape)
+ boxes_exp[:, 0] = x_c - w_half
+ boxes_exp[:, 2] = x_c + w_half
+ boxes_exp[:, 1] = y_c - h_half
+ boxes_exp[:, 3] = y_c + h_half
+ return boxes_exp
+
+
+def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
+ """
+ Args:
+ im (PIL.Image.Image): PIL image
+ np_boxes (np.ndarray): shape:[N,6], N: number of box,
+ matrix element:[class, score, x_min, y_min, x_max, y_max]
+ np_masks (np.ndarray): shape:[N, class_num, resolution, resolution]
+ labels (list): labels:['class1', ..., 'classn']
+ resolution (int): shape of a mask is:[resolution, resolution]
+ threshold (float): threshold of mask
+ Returns:
+ im (PIL.Image.Image): visualized image
+ """
+ color_list = get_color_map_list(len(labels))
+ scale = (resolution + 2.0) / resolution
+ im_w, im_h = im.size
+ w_ratio = 0.4
+ alpha = 0.7
+ im = np.array(im).astype('float32')
+ rects = np_boxes[:, 2:]
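+ # the mask is padded by one pixel on every side, and the box is expanded
+ # by the same (resolution + 2) / resolution ratio, so the resized mask
+ # keeps a clean zero border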
+ expand_rects = expand_boxes(rects, scale)
+ expand_rects = expand_rects.astype(np.int32)
+ clsid_scores = np_boxes[:, 0:2]
+ padded_mask = np.zeros((resolution + 2, resolution + 2), dtype=np.float32)
+ clsid2color = {}
+ for idx in range(len(np_boxes)):
+ clsid, score = clsid_scores[idx].tolist()
+ clsid = int(clsid)
+ xmin, ymin, xmax, ymax = expand_rects[idx].tolist()
+ w = xmax - xmin + 1
+ h = ymax - ymin + 1
+ w = np.maximum(w, 1)
+ h = np.maximum(h, 1)
+ padded_mask[1:-1, 1:-1] = np_masks[idx, int(clsid), :, :]
+ resized_mask = cv2.resize(padded_mask, (w, h))
+ resized_mask = np.array(resized_mask > threshold, dtype=np.uint8)
+ x0 = min(max(xmin, 0), im_w)
+ x1 = min(max(xmax + 1, 0), im_w)
+ y0 = min(max(ymin, 0), im_h)
+ y1 = min(max(ymax + 1, 0), im_h)
+ im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
+ im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (x0 - xmin):(x1 - xmin)]
+ if clsid not in clsid2color:
+ clsid2color[clsid] = color_list[clsid]
+ color_mask = clsid2color[clsid]
+ for c in range(3):
+ color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+ idx = np.nonzero(im_mask)
+ color_mask = np.array(color_mask)
+ im[idx[0], idx[1], :] *= 1.0 - alpha
+ im[idx[0], idx[1], :] += alpha * color_mask
+ return Image.fromarray(im.astype('uint8'))
+
+
+def draw_box(im, np_boxes, labels):
+ """
+ Args:
+ im (PIL.Image.Image): PIL image
+ np_boxes (np.ndarray): shape:[N,6], N: number of box,
+ matrix element:[class, score, x_min, y_min, x_max, y_max]
+ labels (list): labels:['class1', ..., 'classn']
+ Returns:
+ im (PIL.Image.Image): visualized image
+ """
+ draw_thickness = min(im.size) // 320
+ draw = ImageDraw.Draw(im)
+ clsid2color = {}
+ color_list = get_color_map_list(len(labels))
+
+ for dt in np_boxes:
+ clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
+ xmin, ymin, xmax, ymax = bbox
+ w = xmax - xmin
+ h = ymax - ymin
+ if clsid not in clsid2color:
+ clsid2color[clsid] = color_list[clsid]
+ color = tuple(clsid2color[clsid])
+
+ # draw bbox
+ draw.line(
+ [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
+ (xmin, ymin)],
+ width=draw_thickness,
+ fill=color)
+
+ # draw label
+ text = "{} {:.4f}".format(labels[clsid], score)
+ tw, th = draw.textsize(text)
+ draw.rectangle(
+ [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
+ draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
+ return im
+
+
+def draw_segm(im,
+ np_segms,
+ np_label,
+ np_score,
+ labels,
+ threshold=0.5,
+ alpha=0.7):
+ """
+ Draw segmentation on image.
+ """
+ mask_color_id = 0
+ w_ratio = .4
+ color_list = get_color_map_list(len(labels))
+ im = np.array(im).astype('float32')
+ clsid2color = {}
+ np_segms = np_segms.astype(np.uint8)
+
+ for i in range(np_segms.shape[0]):
+ mask, score, clsid = np_segms[i], np_score[i], np_label[i] + 1
+ if score < threshold:
+ continue
+ if clsid not in clsid2color:
+ clsid2color[clsid] = color_list[clsid]
+ color_mask = clsid2color[clsid]
+ for c in range(3):
+ color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+ idx = np.nonzero(mask)
+ color_mask = np.array(color_mask)
+ im[idx[0], idx[1], :] *= 1.0 - alpha
+ im[idx[0], idx[1], :] += alpha * color_mask
+ sum_x = np.sum(mask, axis=0)
+ x = np.where(sum_x > 0.5)[0]
+ sum_y = np.sum(mask, axis=1)
+ y = np.where(sum_y > 0.5)[0]
+ x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
+ cv2.rectangle(im, (x0, y0), (x1, y1),
+ tuple(color_mask.astype('int32').tolist()), 1)
+ bbox_text = '%s %.2f' % (labels[clsid], score)
+ t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
+ cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3),
+ tuple(color_mask.astype('int32').tolist()), -1)
+ cv2.putText(
+ im,
+ bbox_text, (x0, y0 - 2),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.3, (0, 0, 0),
+ 1,
+ lineType=cv2.LINE_AA)
+
+ return Image.fromarray(im.astype('uint8'))
+
+
+def load_predictor(model_dir,
+ run_mode='fluid',
+ batch_size=1,
+ use_gpu=False,
+ min_subgraph_size=3):
+ """set AnalysisConfig, generate AnalysisPredictor
+ Args:
+ model_dir (str): root path of __model__ and __params__
+ use_gpu (bool): whether use gpu
+ Returns:
+ predictor (PaddlePredictor): AnalysisPredictor
+ Raises:
+ ValueError: predict by TensorRT need use_gpu == True.
+ """
+ if not use_gpu and not run_mode == 'fluid':
+ raise ValueError(
+ "Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
+ .format(run_mode, use_gpu))
+ if run_mode == 'trt_int8':
+ raise ValueError("TensorRT int8 mode is not supported now, "
+ "please use trt_fp32 or trt_fp16 instead.")
+ precision_map = {
+ 'trt_int8': fluid.core.AnalysisConfig.Precision.Int8,
+ 'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
+ 'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
+ }
+ config = fluid.core.AnalysisConfig(
+ os.path.join(model_dir, '__model__'),
+ os.path.join(model_dir, '__params__'))
+ if use_gpu:
+ # initial GPU memory(M), device ID
+ config.enable_use_gpu(100, 0)
+ # optimize graph and fuse op
+ config.switch_ir_optim(True)
+ else:
+ config.disable_gpu()
+
+ if run_mode in precision_map.keys():
+ config.enable_tensorrt_engine(
+ workspace_size=1 << 10,
+ max_batch_size=batch_size,
+ min_subgraph_size=min_subgraph_size,
+ precision_mode=precision_map[run_mode],
+ use_static=False,
+ use_calib_mode=False)
+
+ # disable print log when predict
+ config.disable_glog_info()
+ # enable shared memory
+ config.enable_memory_optim()
+ # disable feed, fetch OP, needed by zero_copy_run
+ config.switch_use_feed_fetch_ops(False)
+ predictor = fluid.core.create_paddle_predictor(config)
+ return predictor
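+
+# Minimal usage sketch (hypothetical model directory; it assumes an exported
+# inference model saved as __model__ and __params__ under that path):
+#
+#     predictor = load_predictor('solov2_r50_fpn_1x', use_gpu=False)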
+
+
+def cv2_to_base64(image: np.ndarray):
+ data = cv2.imencode('.jpg', image)[1]
+ return base64.b64encode(data.tobytes()).decode('utf8')
+
+
+def base64_to_cv2(b64str: str):
+ data = base64.b64decode(b64str.encode('utf8'))
+ data = np.frombuffer(data, np.uint8)
+ data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+ return data
diff --git a/modules/image/instance_segmentation/solov2/example.png b/modules/image/instance_segmentation/solov2/example.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ece0a2df78452484829ce4e0fafd02341c5a3e5
Binary files /dev/null and b/modules/image/instance_segmentation/solov2/example.png differ
diff --git a/modules/image/instance_segmentation/solov2/module.py b/modules/image/instance_segmentation/solov2/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..006ea94f9f350fa69987ef8f4d4661c7c9f50a5a
--- /dev/null
+++ b/modules/image/instance_segmentation/solov2/module.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+import base64
+from functools import reduce
+from typing import Union
+
+import cv2
+import numpy as np
+from paddlehub.module.module import moduleinfo, serving
+
+import solov2.processor as P
+import solov2.data_feed as D
+
+
+class Detector(object):
+ """
+ Args:
+ min_subgraph_size (int): minimum number of ops in a subgraph for it to be converted to a TensorRT engine.
+ use_gpu (bool): whether use gpu
+ threshold (float): threshold to reserve the result for output.
+ """
+
+ def __init__(self,
+ min_subgraph_size: int = 60,
+ use_gpu: bool = False,
+ threshold: float = 0.5):
+
+ model_dir = os.path.join(self.directory, 'solov2_r50_fpn_1x')
+ self.predictor = D.load_predictor(
+ model_dir,
+ min_subgraph_size=min_subgraph_size,
+ use_gpu=use_gpu)
+ self.compose = [P.Resize(max_size=1333),
+ P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ P.Permute(),
+ P.PadStride(stride=32)]
+
+ def transform(self, im: Union[str, np.ndarray]):
+ im, im_info = P.preprocess(im, self.compose)
+ inputs = D.create_inputs(im, im_info)
+ return inputs, im_info
+
+ def postprocess(self, np_boxes: np.ndarray, np_masks: np.ndarray, im_info: dict, threshold: float = 0.5):
+ # postprocess output of predictor
+ results = {}
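+ # keep detections whose score exceeds the threshold and whose class id
+ # is non-negative (a negative class id marks an empty detection)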
+ expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
+ np_boxes = np_boxes[expect_boxes, :]
+ for box in np_boxes:
+ print('class_id:{:d}, confidence:{:.4f}, '
+ 'left_top:[{:.2f},{:.2f}], '
+ 'right_bottom:[{:.2f},{:.2f}]'.format(
+ int(box[0]), box[1], box[2], box[3], box[4], box[5]))
+ results['boxes'] = np_boxes
+ if np_masks is not None:
+ np_masks = np_masks[expect_boxes, :, :, :]
+ results['masks'] = np_masks
+ return results
+
+ def predict(self,
+ image: Union[str, np.ndarray],
+ threshold: float = 0.5):
+ '''
+ Args:
+ image (str/np.ndarray): path of image/ np.ndarray read by cv2
+ threshold (float): score threshold for predicted boxes
+ Returns:
+ results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
+ matrix element:[class, score, x_min, y_min, x_max, y_max]
+ MaskRCNN's results include 'masks': np.ndarray:
+ shape:[N, class_num, mask_resolution, mask_resolution]
+ '''
+ inputs, im_info = self.transform(image)
+ np_boxes, np_masks = None, None
+
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_tensor(input_names[i])
+ input_tensor.copy_from_cpu(inputs[input_names[i]])
+
+ self.predictor.zero_copy_run()
+ output_names = self.predictor.get_output_names()
+ boxes_tensor = self.predictor.get_output_tensor(output_names[0])
+ np_boxes = boxes_tensor.copy_to_cpu()
+ # an output with fewer than 6 elements cannot hold a single
+ # [class, score, x_min, y_min, x_max, y_max] record
+ if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
+ print('[WARNING] No object detected.')
+ results = {'boxes': np.array([])}
+ else:
+ results = self.postprocess(np_boxes, np_masks, im_info, threshold=threshold)
+ return results
+
+
+@moduleinfo(
+ name="solov2",
+ type="CV/instance_segmentation",
+ author="paddlepaddle",
+ author_email="",
+ summary="solov2 is a detection model, this module is trained with COCO dataset.",
+ version="1.0.0")
+class DetectorSOLOv2(Detector):
+ """
+ Args:
+ use_gpu (bool): whether use gpu
+ threshold (float): threshold to reserve the result for output.
+ """
+ def __init__(self,
+ use_gpu: bool = False,
+ threshold: float = 0.5):
+ super(DetectorSOLOv2, self).__init__(
+ use_gpu=use_gpu,
+ threshold=threshold)
+
+ def predict(self,
+ image: Union[str, np.ndarray],
+ threshold: float = 0.5,
+ visualization: bool = False,
+ save_dir: str = 'solov2_result'):
+ '''
+ Args:
+ image (str/np.ndarray): path of image/ np.ndarray read by cv2
+ threshold (float): score threshold for predicted boxes
+ visualization (bool): whether to save the visualization result.
+ save_dir (str): save path.
+ Returns:
+ output (dict): 'segm' (np.ndarray), 'label' (np.ndarray) and 'score' (np.ndarray) of the predicted instances.
+ '''
+
+ inputs, im_info = self.transform(image)
+ np_label, np_score, np_segms = None, None, None
+
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_tensor(input_names[i])
+ input_tensor.copy_from_cpu(inputs[input_names[i]])
+
+ self.predictor.zero_copy_run()
+ output_names = self.predictor.get_output_names()
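+ # the three output tensors are, in order: class labels, confidence
+ # scores and instance masks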
+ np_label = self.predictor.get_output_tensor(output_names[0]).copy_to_cpu()
+ np_score = self.predictor.get_output_tensor(output_names[1]).copy_to_cpu()
+ np_segms = self.predictor.get_output_tensor(output_names[2]).copy_to_cpu()
+ output = dict(segm=np_segms, label=np_label, score=np_score)
+
+ if visualization:
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+ image = D.visualize_box_mask(im=image, results=output)
+ name = str(time.time()) + '.png'
+ save_path = os.path.join(save_dir, name)
+ image.save(save_path)
+ return output
+
+ @serving
+ def serving_method(self, images: list, **kwargs):
+ """
+ Run as a service.
+ """
+ images_decode = D.base64_to_cv2(images[0])
+ results = self.predict(image=images_decode, **kwargs)
+ final = {}
+ final['segm'] = base64.b64encode(results['segm']).decode('utf8')
+ final['label'] = base64.b64encode(results['label']).decode('utf8')
+ final['score'] = base64.b64encode(results['score']).decode('utf8')
+ return final
diff --git a/modules/image/instance_segmentation/solov2/processor.py b/modules/image/instance_segmentation/solov2/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6a92d9f567fae260dc068eca072a70f21d844da
--- /dev/null
+++ b/modules/image/instance_segmentation/solov2/processor.py
@@ -0,0 +1,243 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from PIL import Image
+import cv2
+import numpy as np
+
+
+def decode_image(im_file, im_info):
+ """read rgb image
+ Args:
+ im_file (str/np.ndarray): path of image/ np.ndarray read by cv2
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ if isinstance(im_file, str):
+ with open(im_file, 'rb') as f:
+ im_read = f.read()
+ data = np.frombuffer(im_read, dtype='uint8')
+ im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+ im_info['origin_shape'] = im.shape[:2]
+ im_info['resize_shape'] = im.shape[:2]
+ else:
+ im = im_file
+ #im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+ im_info['origin_shape'] = im.shape[:2]
+ im_info['resize_shape'] = im.shape[:2]
+ return im, im_info
+
+
+class Resize(object):
+ """resize image by target_size and max_size
+ Args:
+ target_size (int): the target size of image
+ max_size (int): the max size of image
+ use_cv2 (bool): whether to use cv2 for resizing
+ image_shape (list): input shape of model
+ interp (int): method of resize
+ """
+
+ def __init__(self,
+ target_size=800,
+ max_size=1333,
+ use_cv2=True,
+ image_shape=None,
+ interp=cv2.INTER_LINEAR,
+ resize_box=False):
+ self.target_size = target_size
+ self.max_size = max_size
+ self.image_shape = image_shape
+ self.use_cv2 = use_cv2
+ self.interp = interp
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im_channel = im.shape[2]
+ im_scale_x, im_scale_y = self.generate_scale(im)
+ if self.use_cv2:
+ im = cv2.resize(
+ im,
+ None,
+ None,
+ fx=im_scale_x,
+ fy=im_scale_y,
+ interpolation=self.interp)
+ else:
+ resize_w = int(im_scale_x * float(im.shape[1]))
+ resize_h = int(im_scale_y * float(im.shape[0]))
+ if self.max_size != 0:
+ raise TypeError(
+ 'If you set max_size to cap the maximum size of the image, '
+ 'please set use_cv2 to True to resize the image.')
+ im = im.astype('uint8')
+ im = Image.fromarray(im)
+ im = im.resize((int(resize_w), int(resize_h)), self.interp)
+ im = np.array(im)
+
+ # padding im when image_shape fixed by infer_cfg.yml
+ if self.max_size != 0 and self.image_shape is not None:
+ padding_im = np.zeros(
+ (self.max_size, self.max_size, im_channel), dtype=np.float32)
+ im_h, im_w = im.shape[:2]
+ padding_im[:im_h, :im_w, :] = im
+ im = padding_im
+
+ im_info['scale'] = [im_scale_x, im_scale_y]
+ im_info['resize_shape'] = im.shape[:2]
+ return im, im_info
+
+ def generate_scale(self, im):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ Returns:
+ im_scale_x: the resize ratio of X
+ im_scale_y: the resize ratio of Y
+ """
+ origin_shape = im.shape[:2]
+ if self.max_size != 0:
+ im_size_min = np.min(origin_shape[0:2])
+ im_size_max = np.max(origin_shape[0:2])
+ im_scale = float(self.target_size) / float(im_size_min)
+ if np.round(im_scale * im_size_max) > self.max_size:
+ im_scale = float(self.max_size) / float(im_size_max)
+ im_scale_x = im_scale
+ im_scale_y = im_scale
+ else:
+ im_scale_x = float(self.target_size) / float(origin_shape[1])
+ im_scale_y = float(self.target_size) / float(origin_shape[0])
+ return im_scale_x, im_scale_y
+
+
+class Normalize(object):
+ """normalize image
+ Args:
+ mean (list): im - mean
+ std (list): im / std
+ is_scale (bool): whether need im / 255
+ is_channel_first (bool): if True: image shape is CHW, else: HWC
+ """
+
+ def __init__(self, mean, std, is_scale=True, is_channel_first=False):
+ self.mean = mean
+ self.std = std
+ self.is_scale = is_scale
+ self.is_channel_first = is_channel_first
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im.astype(np.float32, copy=False)
+ if self.is_channel_first:
+ mean = np.array(self.mean)[:, np.newaxis, np.newaxis]
+ std = np.array(self.std)[:, np.newaxis, np.newaxis]
+ else:
+ mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+ std = np.array(self.std)[np.newaxis, np.newaxis, :]
+ if self.is_scale:
+ im = im / 255.0
+ im -= mean
+ im /= std
+ return im, im_info
+
+
+class Permute(object):
+ """permute image
+ Args:
+ to_bgr (bool): whether convert RGB to BGR
+ channel_first (bool): whether convert HWC to CHW
+ """
+
+ def __init__(self, to_bgr=False, channel_first=True):
+ self.to_bgr = to_bgr
+ self.channel_first = channel_first
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ if self.channel_first:
+ im = im.transpose((2, 0, 1)).copy()
+ if self.to_bgr:
+ im = im[[2, 1, 0], :, :]
+ return im, im_info
+
+
+class PadStride(object):
+ """ padding image for model with FPN
+ Args:
+ stride (int): models with FPN require the image height and width to be multiples of stride
+ """
+
+ def __init__(self, stride=0):
+ self.coarsest_stride = stride
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ coarsest_stride = self.coarsest_stride
+ if coarsest_stride == 0:
+ return im, im_info
+ im_c, im_h, im_w = im.shape
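+ # round height and width up to the nearest multiples of the stride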
+ pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
+ pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
+ padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
+ padding_im[:, :im_h, :im_w] = im
+ im_info['pad_shape'] = padding_im.shape[1:]
+ return padding_im, im_info
+
+
+def preprocess(im, preprocess_ops):
+ # process image by preprocess_ops
+ im_info = {
+ 'scale': [1., 1.],
+ 'origin_shape': None,
+ 'resize_shape': None,
+ 'pad_shape': None,
+ }
+ im, im_info = decode_image(im, im_info)
+ for operator in preprocess_ops:
+ im, im_info = operator(im, im_info)
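+ # prepend a batch dimension: [C, H, W] -> [1, C, H, W]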
+ im = np.array((im, )).astype('float32')
+ return im, im_info