未验证 提交 87fb873c 编写于 作者: H haoyuying 提交者: GitHub

Solov2

上级 ac26a0e0
## 模型概述
solov2是基于'SOLOv2: Dynamic, Faster and Stronger'实现的快速实例分割的模型。该模型基于SOLOV1, 并且针对mask的检测效果和运行效率进行改进,在实例分割任务中表现优秀。相对语义分割,实例分割需要标注出图上同一物体的不同个体。solov2实例分割效果如下:
<div align="center">
<img src="example.png" width = "642" height = "400" />
</div>
## API
```python
def predict(self,
image: Union[str, np.ndarray],
threshold: float = 0.5,
visualization: bool = False,
save_dir: str = 'solov2_result'):
```
预测API,实例分割。
**参数**
* image (Union\[str, np.ndarray\]): 图片路径或者图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
* threshold (float): 检测模型输出结果中,预测得分低于该阈值的框将被滤除,默认值为0.5;
* visualization (bool): 是否将可视化图片保存;
* save_dir (str): 保存图片到路径, 默认为"solov2_result"。
**返回**
* res (dict): 识别结果,关键字有 'segm', 'label', 'score'对应的取值为:
* segm (np.ndarray): 实例分割结果,取值为0或1。0表示背景,1为实例;
* label (list): 实例分割结果类别id;
* score (list):实例分割结果类别得分;
## 代码示例
```python
import cv2
import paddlehub as hub
img = cv2.imread('/PATH/TO/IMAGE')
model = hub.Module(name='solov2', use_gpu=False)
output = model.predict(image=img,visualization=False)
```
## 服务部署
PaddleHub Serving可以部署一个实例分割的在线服务。
## 第一步:启动PaddleHub Serving
运行启动命令:
```shell
$ hub serving start -m solov2
```
默认端口号为8866。
**NOTE:** 如使用GPU预测,则需要在启动服务之前,设置CUDA_VISIBLE_DEVICES环境变量,否则不用设置。
## 第二步:发送预测请求
配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
```python
import requests
import json
import cv2
import base64
import numpy as np
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
# 发送HTTP请求
org_im = cv2.imread('/PATH/TO/IMAGE')
h, w, c = org_im.shape
data = {'images':[cv2_to_base64(org_im)]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/solov2"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
seg = base64.b64decode(r.json()["results"]['segm'].encode('utf8'))
seg = np.fromstring(seg, dtype=np.int32).reshape((-1, h, w))
label = base64.b64decode(r.json()["results"]['label'].encode('utf8'))
label = np.fromstring(label, dtype=np.int64)
score = base64.b64decode(r.json()["results"]['score'].encode('utf8'))
score = np.fromstring(score, dtype=np.float32)
print('seg', seg)
print('label', label)
print('score', score)
```
### 查看代码
https://github.com/PaddlePaddle/PaddleDetection
### 依赖
paddlepaddle >= 2.0.0
paddlehub >= 2.0.0
import os
import base64
import cv2
import numpy as np
from PIL import Image, ImageDraw
import paddle.fluid as fluid
def create_inputs(im, im_info):
"""generate input for different model type
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
inputs (dict): input of model
"""
inputs = {}
inputs['image'] = im
origin_shape = list(im_info['origin_shape'])
resize_shape = list(im_info['resize_shape'])
pad_shape = list(im_info['pad_shape']) if im_info['pad_shape'] is not None else list(im_info['resize_shape'])
scale_x, scale_y = im_info['scale']
scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info
return inputs
def visualize_box_mask(im, results, labels=None, mask_resolution=14, threshold=0.5):
"""
Args:
im (str/np.ndarray): path of image/np.ndarray read by cv2
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
labels (list): labels:['class1', ..., 'classn']
mask_resolution (int): shape of a mask is:[mask_resolution, mask_resolution]
threshold (float): Threshold of score.
Returns:
im (PIL.Image.Image): visualized image
"""
if not labels:
labels = ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant', 'stop sign', 'parking meter',
'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush']
if isinstance(im, str):
im = Image.open(im).convert('RGB')
else:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = Image.fromarray(im)
if 'masks' in results and 'boxes' in results:
im = draw_mask(
im,
results['boxes'],
results['masks'],
labels,
resolution=mask_resolution)
if 'boxes' in results:
im = draw_box(im, results['boxes'], labels)
if 'segm' in results:
im = draw_segm(
im,
results['segm'],
results['label'],
results['score'],
labels,
threshold=threshold)
return im
def get_color_map_list(num_classes):
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
"""
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
def expand_boxes(boxes, scale=0.0):
"""
Args:
boxes (np.ndarray): shape:[N,4], N:number of box,
matix element:[x_min, y_min, x_max, y_max]
scale (float): scale of boxes
Returns:
boxes_exp (np.ndarray): expanded boxes
"""
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half *= scale
h_half *= scale
boxes_exp = np.zeros(boxes.shape)
boxes_exp[:, 0] = x_c - w_half
boxes_exp[:, 2] = x_c + w_half
boxes_exp[:, 1] = y_c - h_half
boxes_exp[:, 3] = y_c + h_half
return boxes_exp
def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
np_masks (np.ndarray): shape:[N, class_num, resolution, resolution]
labels (list): labels:['class1', ..., 'classn']
resolution (int): shape of a mask is:[resolution, resolution]
threshold (float): threshold of mask
Returns:
im (PIL.Image.Image): visualized image
"""
color_list = get_color_map_list(len(labels))
scale = (resolution + 2.0) / resolution
im_w, im_h = im.size
w_ratio = 0.4
alpha = 0.7
im = np.array(im).astype('float32')
rects = np_boxes[:, 2:]
expand_rects = expand_boxes(rects, scale)
expand_rects = expand_rects.astype(np.int32)
clsid_scores = np_boxes[:, 0:2]
padded_mask = np.zeros((resolution + 2, resolution + 2), dtype=np.float32)
clsid2color = {}
for idx in range(len(np_boxes)):
clsid, score = clsid_scores[idx].tolist()
clsid = int(clsid)
xmin, ymin, xmax, ymax = expand_rects[idx].tolist()
w = xmax - xmin + 1
h = ymax - ymin + 1
w = np.maximum(w, 1)
h = np.maximum(h, 1)
padded_mask[1:-1, 1:-1] = np_masks[idx, int(clsid), :, :]
resized_mask = cv2.resize(padded_mask, (w, h))
resized_mask = np.array(resized_mask > threshold, dtype=np.uint8)
x0 = min(max(xmin, 0), im_w)
x1 = min(max(xmax + 1, 0), im_w)
y0 = min(max(ymin, 0), im_h)
y1 = min(max(ymax + 1, 0), im_h)
im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (x0 - xmin):(x1 - xmin)]
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color_mask = clsid2color[clsid]
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(im_mask)
color_mask = np.array(color_mask)
im[idx[0], idx[1], :] *= 1.0 - alpha
im[idx[0], idx[1], :] += alpha * color_mask
return Image.fromarray(im.astype('uint8'))
def draw_box(im, np_boxes, labels):
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
Returns:
im (PIL.Image.Image): visualized image
"""
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
clsid2color = {}
color_list = get_color_map_list(len(labels))
for dt in np_boxes:
clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
xmin, ymin, xmax, ymax = bbox
w = xmax - xmin
h = ymax - ymin
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color = tuple(clsid2color[clsid])
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=draw_thickness,
fill=color)
# draw label
text = "{} {:.4f}".format(labels[clsid], score)
tw, th = draw.textsize(text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im
def draw_segm(im,
np_segms,
np_label,
np_score,
labels,
threshold=0.5,
alpha=0.7):
"""
Draw segmentation on image.
"""
mask_color_id = 0
w_ratio = .4
color_list = get_color_map_list(len(labels))
im = np.array(im).astype('float32')
clsid2color = {}
np_segms = np_segms.astype(np.uint8)
for i in range(np_segms.shape[0]):
mask, score, clsid = np_segms[i], np_score[i], np_label[i] + 1
if score < threshold:
continue
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color_mask = clsid2color[clsid]
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask)
color_mask = np.array(color_mask)
im[idx[0], idx[1], :] *= 1.0 - alpha
im[idx[0], idx[1], :] += alpha * color_mask
sum_x = np.sum(mask, axis=0)
x = np.where(sum_x > 0.5)[0]
sum_y = np.sum(mask, axis=1)
y = np.where(sum_y > 0.5)[0]
x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
cv2.rectangle(im, (x0, y0), (x1, y1),
tuple(color_mask.astype('int32').tolist()), 1)
bbox_text = '%s %.2f' % (labels[clsid], score)
t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3),
tuple(color_mask.astype('int32').tolist()), -1)
cv2.putText(
im,
bbox_text, (x0, y0 - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.3, (0, 0, 0),
1,
lineType=cv2.LINE_AA)
return Image.fromarray(im.astype('uint8'))
def load_predictor(model_dir,
run_mode='fluid',
batch_size=1,
use_gpu=False,
min_subgraph_size=3):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
use_gpu (bool): whether use gpu
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
ValueError: predict by TensorRT need use_gpu == True.
"""
if not use_gpu and not run_mode == 'fluid':
raise ValueError(
"Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
.format(run_mode, use_gpu))
if run_mode == 'trt_int8':
raise ValueError("TensorRT int8 mode is not supported now, "
"please use trt_fp32 or trt_fp16 instead.")
precision_map = {
'trt_int8': fluid.core.AnalysisConfig.Precision.Int8,
'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
}
config = fluid.core.AnalysisConfig(
os.path.join(model_dir, '__model__'),
os.path.join(model_dir, '__params__'))
if use_gpu:
# initial GPU memory(M), device ID
config.enable_use_gpu(100, 0)
# optimize graph and fuse op
config.switch_ir_optim(True)
else:
config.disable_gpu()
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=1 << 10,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
use_static=False,
use_calib_mode=False)
# disable print log when predict
config.disable_glog_info()
# enable shared memory
config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False)
predictor = fluid.core.create_paddle_predictor(config)
return predictor
def cv2_to_base64(image: np.ndarray):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
def base64_to_cv2(b64str: str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import base64
from functools import reduce
from typing import Union
import cv2
import numpy as np
from paddlehub.module.module import moduleinfo, serving
import solov2.processor as P
import solov2.data_feed as D
class Detector(object):
"""
Args:
min_subgraph_size (int): number of tensorRT graphs.
use_gpu (bool): whether use gpu
threshold (float): threshold to reserve the result for output.
"""
def __init__(self,
min_subgraph_size: int = 60,
use_gpu=False,
threshold: float = 0.5):
model_dir = os.path.join(self.directory, 'solov2_r50_fpn_1x')
self.predictor = D.load_predictor(
model_dir,
min_subgraph_size=min_subgraph_size,
use_gpu=use_gpu)
self.compose = [P.Resize(max_size=1333),
P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
P.Permute(),
P.PadStride(stride=32)]
def transform(self, im: Union[str, np.ndarray]):
im, im_info = P.preprocess(im, self.compose)
inputs = D.create_inputs(im, im_info)
return inputs, im_info
def postprocess(self, np_boxes: np.ndarray, np_masks: np.ndarray, im_info: dict, threshold: float = 0.5):
# postprocess output of predictor
results = {}
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
for box in np_boxes:
print('class_id:{:d}, confidence:{:.4f},'
'left_top:[{:.2f},{:.2f}],'
' right_bottom:[{:.2f},{:.2f}]'.format(
int(box[0]), box[1], box[2], box[3], box[4], box[5]))
results['boxes'] = np_boxes
if np_masks is not None:
np_masks = np_masks[expect_boxes, :, :, :]
results['masks'] = np_masks
return results
def predict(self,
image: Union[str, np.ndarray],
threshold: float = 0.5):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, class_num, mask_resolution, mask_resolution]
'''
inputs, im_info = self.transform(image)
np_boxes, np_masks = None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_tensor(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
self.predictor.zero_copy_run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
# do not perform postprocess in benchmark mode
results = []
if reduce(lambda x, y: x * y, np_boxes.shape) < 6:
print('[WARNNING] No object detected.')
results = {'boxes': np.array([])}
else:
results = self.postprocess(np_boxes, np_masks, im_info, threshold=threshold)
return results
@moduleinfo(
name="solov2",
type="CV/instance_segmentation",
author="paddlepaddle",
author_email="",
summary="solov2 is a detection model, this module is trained with COCO dataset.",
version="1.0.0")
class DetectorSOLOv2(Detector):
"""
Args:
use_gpu (bool): whether use gpu
threshold (float): threshold to reserve the result for output.
"""
def __init__(self,
use_gpu: bool = False,
threshold: float = 0.5):
super(DetectorSOLOv2, self).__init__(
use_gpu=use_gpu,
threshold=threshold)
def predict(self,
image: Union[str, np.ndarray],
threshold: float = 0.5,
visualization: bool = False,
save_dir: str = 'solov2_result'):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
threshold (float): threshold of predicted box' score
visualization (bool): Whether to save visualization result.
save_dir (str): save path.
'''
inputs, im_info = self.transform(image)
np_label, np_score, np_segms = None, None, None
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_tensor(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]])
self.predictor.zero_copy_run()
output_names = self.predictor.get_output_names()
np_label = self.predictor.get_output_tensor(output_names[
0]).copy_to_cpu()
np_score = self.predictor.get_output_tensor(output_names[
1]).copy_to_cpu()
np_segms = self.predictor.get_output_tensor(output_names[
2]).copy_to_cpu()
output = dict(segm=np_segms, label=np_label, score=np_score)
if visualization:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
image = D.visualize_box_mask(im=image, results=output)
name = str(time.time()) + '.png'
save_path = os.path.join(save_dir, name)
image.save(save_path)
return output
@serving
def serving_method(self, images: list, **kwargs):
"""
Run as a service.
"""
images_decode = D.base64_to_cv2(images[0])
results = self.predict(image=images_decode, **kwargs)
final = {}
final['segm'] = base64.b64encode(results['segm']).decode('utf8')
final['label'] = base64.b64encode(results['label']).decode('utf8')
final['score'] = base64.b64encode(results['score']).decode('utf8')
return final
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from PIL import Image
import cv2
import numpy as np
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str/np.ndarray): path of image/ np.ndarray read by cv2
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_info['origin_shape'] = im.shape[:2]
im_info['resize_shape'] = im.shape[:2]
else:
im = im_file
#im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_info['origin_shape'] = im.shape[:2]
im_info['resize_shape'] = im.shape[:2]
return im, im_info
class Resize(object):
"""resize image by target_size and max_size
Args:
arch (str): model type
target_size (int): the target size of image
max_size (int): the max size of image
use_cv2 (bool): whether us cv2
image_shape (list): input shape of model
interp (int): method of resize
"""
def __init__(self,
target_size=800,
max_size=1333,
use_cv2=True,
image_shape=None,
interp=cv2.INTER_LINEAR,
resize_box=False):
self.target_size = target_size
self.max_size = max_size
self.image_shape = image_shape
self.use_cv2 = use_cv2
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im_channel = im.shape[2]
im_scale_x, im_scale_y = self.generate_scale(im)
if self.use_cv2:
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
else:
resize_w = int(im_scale_x * float(im.shape[1]))
resize_h = int(im_scale_y * float(im.shape[0]))
if self.max_size != 0:
raise TypeError(
'If you set max_size to cap the maximum size of image,'
'please set use_cv2 to True to resize the image.')
im = im.astype('uint8')
im = Image.fromarray(im)
im = im.resize((int(resize_w), int(resize_h)), self.interp)
im = np.array(im)
# padding im when image_shape fixed by infer_cfg.yml
if self.max_size != 0 and self.image_shape is not None:
padding_im = np.zeros(
(self.max_size, self.max_size, im_channel), dtype=np.float32)
im_h, im_w = im.shape[:2]
padding_im[:im_h, :im_w, :] = im
im = padding_im
im_info['scale'] = [im_scale_x, im_scale_y]
im_info['resize_shape'] = im.shape[:2]
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.max_size != 0:
im_size_min = np.min(origin_shape[0:2])
im_size_max = np.max(origin_shape[0:2])
im_scale = float(self.target_size) / float(im_size_min)
if np.round(im_scale * im_size_max) > self.max_size:
im_scale = float(self.max_size) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
im_scale_x = float(self.target_size) / float(origin_shape[1])
im_scale_y = float(self.target_size) / float(origin_shape[0])
return im_scale_x, im_scale_y
class Normalize(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
"""
def __init__(self, mean, std, is_scale=True, is_channel_first=False):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.is_channel_first = is_channel_first
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_channel_first:
mean = np.array(self.mean)[:, np.newaxis, np.newaxis]
std = np.array(self.std)[:, np.newaxis, np.newaxis]
else:
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, to_bgr=False, channel_first=True):
self.to_bgr = to_bgr
self.channel_first = channel_first
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if self.channel_first:
im = im.transpose((2, 0, 1)).copy()
if self.to_bgr:
im = im[[2, 1, 0], :, :]
return im, im_info
class PadStride(object):
""" padding image for model with FPN
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride == 0:
return im
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
im_info['pad_shape'] = padding_im.shape[1:]
return padding_im, im_info
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
'scale': [1., 1.],
'origin_shape': None,
'resize_shape': None,
'pad_shape': None,
}
im, im_info = decode_image(im, im_info)
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
im = np.array((im, )).astype('float32')
return im, im_info
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册