未验证 提交 2ce0e07b 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update solov2 (#2015)

* update solov2

* fix typo
上级 71ee4cf6
...@@ -78,7 +78,7 @@ ...@@ -78,7 +78,7 @@
- res (dict): 识别结果,关键字有 'segm', 'label', 'score'对应的取值为: - res (dict): 识别结果,关键字有 'segm', 'label', 'score'对应的取值为:
- segm (np.ndarray): 实例分割结果,取值为0或1。0表示背景,1为实例; - segm (np.ndarray): 实例分割结果,取值为0或1。0表示背景,1为实例;
- label (list): 实例分割结果类别id; - label (list): 实例分割结果类别id;
- score (list):实例分割结果类别得分;s - score (list):实例分割结果类别得分;
## 四、服务部署 ## 四、服务部署
...@@ -147,8 +147,10 @@ ...@@ -147,8 +147,10 @@
初始发布 初始发布
* ```shell * 1.1.0
$ hub install hand_pose_localization==1.0.0
```
适配 PaddlePaddle 2.2.0+
* ```shell
$ hub install hand_pose_localization==1.1.0
```
\ No newline at end of file
...@@ -3,8 +3,8 @@ import base64 ...@@ -3,8 +3,8 @@ import base64
import cv2 import cv2
import numpy as np import numpy as np
from paddle.inference import Config, create_predictor, PrecisionType
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
import paddle.fluid as fluid
def create_inputs(im, im_info): def create_inputs(im, im_info):
...@@ -19,11 +19,14 @@ def create_inputs(im, im_info): ...@@ -19,11 +19,14 @@ def create_inputs(im, im_info):
inputs['image'] = im inputs['image'] = im
origin_shape = list(im_info['origin_shape']) origin_shape = list(im_info['origin_shape'])
resize_shape = list(im_info['resize_shape']) resize_shape = list(im_info['resize_shape'])
pad_shape = list(im_info['pad_shape']) if im_info['pad_shape'] is not None else list(im_info['resize_shape']) pad_shape = list(im_info['pad_shape']) if im_info[
'pad_shape'] is not None else list(im_info['resize_shape'])
scale_x, scale_y = im_info['scale'] scale_x, scale_y = im_info['scale']
scale = scale_x scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32') im_info = np.array([resize_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info inputs['im_info'] = im_info
inputs['scale_factor'] = np.array([scale_x, scale_x]).astype('float32').reshape(-1, 2)
inputs['im_shape'] = np.array(resize_shape).astype('float32').reshape(-1, 2)
return inputs return inputs
...@@ -42,28 +45,38 @@ def visualize_box_mask(im, results, labels=None, mask_resolution=14, threshold=0 ...@@ -42,28 +45,38 @@ def visualize_box_mask(im, results, labels=None, mask_resolution=14, threshold=0
im (PIL.Image.Image): visualized image im (PIL.Image.Image): visualized image
""" """
if not labels: if not labels:
labels = [ labels = ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant', 'stop sign', 'parking meter',
'traffic light', 'fire', 'hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'teddy bear', 'hair drier', 'toothbrush' 'hair drier', 'toothbrush']
]
if isinstance(im, str): if isinstance(im, str):
im = Image.open(im).convert('RGB') im = Image.open(im).convert('RGB')
else: else:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = Image.fromarray(im) im = Image.fromarray(im)
if 'masks' in results and 'boxes' in results: if 'masks' in results and 'boxes' in results:
im = draw_mask(im, results['boxes'], results['masks'], labels, resolution=mask_resolution) im = draw_mask(
im,
results['boxes'],
results['masks'],
labels,
resolution=mask_resolution)
if 'boxes' in results: if 'boxes' in results:
im = draw_box(im, results['boxes'], labels) im = draw_box(im, results['boxes'], labels)
if 'segm' in results: if 'segm' in results:
im = draw_segm(im, results['segm'], results['label'], results['score'], labels, threshold=threshold) im = draw_segm(
im,
results['segm'],
results['label'],
results['score'],
labels,
threshold=threshold)
return im return im
...@@ -152,7 +165,8 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5): ...@@ -152,7 +165,8 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
y0 = min(max(ymin, 0), im_h) y0 = min(max(ymin, 0), im_h)
y1 = min(max(ymax + 1, 0), im_h) y1 = min(max(ymax + 1, 0), im_h)
im_mask = np.zeros((im_h, im_w), dtype=np.uint8) im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (x0 - xmin):(x1 - xmin)] im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
x0 - xmin):(x1 - xmin)]
if clsid not in clsid2color: if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid] clsid2color[clsid] = color_list[clsid]
color_mask = clsid2color[clsid] color_mask = clsid2color[clsid]
...@@ -190,19 +204,28 @@ def draw_box(im, np_boxes, labels): ...@@ -190,19 +204,28 @@ def draw_box(im, np_boxes, labels):
color = tuple(clsid2color[clsid]) color = tuple(clsid2color[clsid])
# draw bbox # draw bbox
draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)], draw.line(
width=draw_thickness, [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
fill=color) (xmin, ymin)],
width=draw_thickness,
fill=color)
# draw label # draw label
text = "{} {:.4f}".format(labels[clsid], score) text = "{} {:.4f}".format(labels[clsid], score)
tw, th = draw.textsize(text) tw, th = draw.textsize(text)
draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im return im
def draw_segm(im, np_segms, np_label, np_score, labels, threshold=0.5, alpha=0.7): def draw_segm(im,
np_segms,
np_label,
np_score,
labels,
threshold=0.5,
alpha=0.7):
""" """
Draw segmentation on image. Draw segmentation on image.
""" """
...@@ -231,17 +254,28 @@ def draw_segm(im, np_segms, np_label, np_score, labels, threshold=0.5, alpha=0.7 ...@@ -231,17 +254,28 @@ def draw_segm(im, np_segms, np_label, np_score, labels, threshold=0.5, alpha=0.7
sum_y = np.sum(mask, axis=1) sum_y = np.sum(mask, axis=1)
y = np.where(sum_y > 0.5)[0] y = np.where(sum_y > 0.5)[0]
x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1] x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
cv2.rectangle(im, (x0, y0), (x1, y1), tuple(color_mask.astype('int32').tolist()), 1) cv2.rectangle(im, (x0, y0), (x1, y1),
tuple(color_mask.astype('int32').tolist()), 1)
bbox_text = '%s %.2f' % (labels[clsid], score) bbox_text = '%s %.2f' % (labels[clsid], score)
t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0] t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3), tuple(color_mask.astype('int32').tolist()), cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3),
-1) tuple(color_mask.astype('int32').tolist()), -1)
cv2.putText(im, bbox_text, (x0, y0 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), 1, lineType=cv2.LINE_AA) cv2.putText(
im,
bbox_text, (x0, y0 - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.3, (0, 0, 0),
1,
lineType=cv2.LINE_AA)
return Image.fromarray(im.astype('uint8')) return Image.fromarray(im.astype('uint8'))
def load_predictor(model_dir, run_mode='fluid', batch_size=1, use_gpu=False, min_subgraph_size=3): def load_predictor(model_dir,
run_mode='paddle',
batch_size=1,
use_gpu=False,
min_subgraph_size=3):
"""set AnalysisConfig, generate AnalysisPredictor """set AnalysisConfig, generate AnalysisPredictor
Args: Args:
model_dir (str): root path of __model__ and __params__ model_dir (str): root path of __model__ and __params__
...@@ -251,17 +285,19 @@ def load_predictor(model_dir, run_mode='fluid', batch_size=1, use_gpu=False, min ...@@ -251,17 +285,19 @@ def load_predictor(model_dir, run_mode='fluid', batch_size=1, use_gpu=False, min
Raises: Raises:
ValueError: predict by TensorRT need use_gpu == True. ValueError: predict by TensorRT need use_gpu == True.
""" """
if not use_gpu and not run_mode == 'fluid': if not use_gpu and not run_mode == 'paddle':
raise ValueError("Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}".format( raise ValueError(
run_mode, use_gpu)) "Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
.format(run_mode, use_gpu))
if run_mode == 'trt_int8': if run_mode == 'trt_int8':
raise ValueError("TensorRT int8 mode is not supported now, " "please use trt_fp32 or trt_fp16 instead.") raise ValueError("TensorRT int8 mode is not supported now, "
"please use trt_fp32 or trt_fp16 instead.")
precision_map = { precision_map = {
'trt_int8': fluid.core.AnalysisConfig.Precision.Int8, 'trt_int8': PrecisionType.Int8,
'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32, 'trt_fp32': PrecisionType.Float32,
'trt_fp16': fluid.core.AnalysisConfig.Precision.Half 'trt_fp16': PrecisionType.Half
} }
config = fluid.core.AnalysisConfig(os.path.join(model_dir, '__model__'), os.path.join(model_dir, '__params__')) config = Config(model_dir+'.pdmodel', model_dir+'.pdiparams')
if use_gpu: if use_gpu:
# initial GPU memory(M), device ID # initial GPU memory(M), device ID
config.enable_use_gpu(100, 0) config.enable_use_gpu(100, 0)
...@@ -285,7 +321,7 @@ def load_predictor(model_dir, run_mode='fluid', batch_size=1, use_gpu=False, min ...@@ -285,7 +321,7 @@ def load_predictor(model_dir, run_mode='fluid', batch_size=1, use_gpu=False, min
config.enable_memory_optim() config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run # disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
predictor = fluid.core.create_paddle_predictor(config) predictor = create_predictor(config)
return predictor return predictor
......
...@@ -11,13 +11,13 @@ ...@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import time import time
import base64 import base64
from functools import reduce from functools import reduce
from typing import Union from typing import Union
import cv2
import numpy as np import numpy as np
from paddlehub.module.module import moduleinfo, serving from paddlehub.module.module import moduleinfo, serving
...@@ -25,7 +25,7 @@ import solov2.processor as P ...@@ -25,7 +25,7 @@ import solov2.processor as P
import solov2.data_feed as D import solov2.data_feed as D
class Detector(object): class Detector:
""" """
Args: Args:
min_subgraph_size (int): number of tensorRT graphs. min_subgraph_size (int): number of tensorRT graphs.
...@@ -33,23 +33,26 @@ class Detector(object): ...@@ -33,23 +33,26 @@ class Detector(object):
threshold (float): threshold to reserve the result for output. threshold (float): threshold to reserve the result for output.
""" """
def __init__(self, min_subgraph_size: int = 60, use_gpu=False, threshold: float = 0.5): def __init__(self,
min_subgraph_size: int = 60,
use_gpu=False):
model_dir = os.path.join(self.directory, 'solov2_r50_fpn_1x') self.default_pretrained_model_path = os.path.join(self.directory, 'solov2_r50_fpn_1x', 'model')
self.predictor = D.load_predictor(model_dir, min_subgraph_size=min_subgraph_size, use_gpu=use_gpu) self.predictor = D.load_predictor(
self.compose = [ self.default_pretrained_model_path,
P.Resize(max_size=1333), min_subgraph_size=min_subgraph_size,
P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), use_gpu=use_gpu)
P.Permute(), self.compose = [P.Resize(max_size=1333),
P.PadStride(stride=32) P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
] P.Permute(),
P.PadStride(stride=32)]
def transform(self, im: Union[str, np.ndarray]): def transform(self, im: Union[str, np.ndarray]):
im, im_info = P.preprocess(im, self.compose) im, im_info = P.preprocess(im, self.compose)
inputs = D.create_inputs(im, im_info) inputs = D.create_inputs(im, im_info)
return inputs, im_info return inputs, im_info
def postprocess(self, np_boxes: np.ndarray, np_masks: np.ndarray, im_info: dict, threshold: float = 0.5): def postprocess(self, np_boxes: np.ndarray, np_masks: np.ndarray, threshold: float = 0.5):
# postprocess output of predictor # postprocess output of predictor
results = {} results = {}
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1) expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
...@@ -57,14 +60,17 @@ class Detector(object): ...@@ -57,14 +60,17 @@ class Detector(object):
for box in np_boxes: for box in np_boxes:
print('class_id:{:d}, confidence:{:.4f},' print('class_id:{:d}, confidence:{:.4f},'
'left_top:[{:.2f},{:.2f}],' 'left_top:[{:.2f},{:.2f}],'
' right_bottom:[{:.2f},{:.2f}]'.format(int(box[0]), box[1], box[2], box[3], box[4], box[5])) ' right_bottom:[{:.2f},{:.2f}]'.format(
int(box[0]), box[1], box[2], box[3], box[4], box[5]))
results['boxes'] = np_boxes results['boxes'] = np_boxes
if np_masks is not None: if np_masks is not None:
np_masks = np_masks[expect_boxes, :, :, :] np_masks = np_masks[expect_boxes, :, :, :]
results['masks'] = np_masks results['masks'] = np_masks
return results return results
def predict(self, image: Union[str, np.ndarray], threshold: float = 0.5): def predict(self,
image: Union[str, np.ndarray],
threshold: float = 0.5):
''' '''
Args: Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2 image (str/np.ndarray): path of image/ np.ndarray read by cv2
...@@ -80,12 +86,12 @@ class Detector(object): ...@@ -80,12 +86,12 @@ class Detector(object):
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
for i in range(len(input_names)): for i in range(len(input_names)):
input_tensor = self.predictor.get_input_tensor(input_names[i]) input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]]) input_tensor.copy_from_cpu(inputs[input_names[i]])
self.predictor.zero_copy_run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_tensor(output_names[0]) boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu() np_boxes = boxes_tensor.copy_to_cpu()
# do not perform postprocess in benchmark mode # do not perform postprocess in benchmark mode
results = [] results = []
...@@ -103,16 +109,18 @@ class Detector(object): ...@@ -103,16 +109,18 @@ class Detector(object):
author="paddlepaddle", author="paddlepaddle",
author_email="", author_email="",
summary="solov2 is a detection model, this module is trained with COCO dataset.", summary="solov2 is a detection model, this module is trained with COCO dataset.",
version="1.0.0") version="1.1.0")
class DetectorSOLOv2(Detector): class DetectorSOLOv2(Detector):
""" """
Args: Args:
use_gpu (bool): whether use gpu use_gpu (bool): whether use gpu
threshold (float): threshold to reserve the result for output. threshold (float): threshold to reserve the result for output.
""" """
def __init__(self,
use_gpu: bool = False):
super(DetectorSOLOv2, self).__init__(
use_gpu=use_gpu)
def __init__(self, use_gpu: bool = False, threshold: float = 0.5):
super(DetectorSOLOv2, self).__init__(use_gpu=use_gpu, threshold=threshold)
def predict(self, def predict(self,
image: Union[str, np.ndarray], image: Union[str, np.ndarray],
...@@ -125,7 +133,7 @@ class DetectorSOLOv2(Detector): ...@@ -125,7 +133,7 @@ class DetectorSOLOv2(Detector):
threshold (float): threshold of predicted box' score threshold (float): threshold of predicted box' score
visualization (bool): Whether to save visualization result. visualization (bool): Whether to save visualization result.
save_dir (str): save path. save_dir (str): save path.
''' '''
inputs, im_info = self.transform(image) inputs, im_info = self.transform(image)
...@@ -133,20 +141,23 @@ class DetectorSOLOv2(Detector): ...@@ -133,20 +141,23 @@ class DetectorSOLOv2(Detector):
input_names = self.predictor.get_input_names() input_names = self.predictor.get_input_names()
for i in range(len(input_names)): for i in range(len(input_names)):
input_tensor = self.predictor.get_input_tensor(input_names[i]) input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[input_names[i]]) input_tensor.copy_from_cpu(inputs[input_names[i]])
self.predictor.zero_copy_run() self.predictor.run()
output_names = self.predictor.get_output_names() output_names = self.predictor.get_output_names()
np_label = self.predictor.get_output_tensor(output_names[0]).copy_to_cpu() np_label = self.predictor.get_output_handle(output_names[
np_score = self.predictor.get_output_tensor(output_names[1]).copy_to_cpu() 1]).copy_to_cpu()
np_segms = self.predictor.get_output_tensor(output_names[2]).copy_to_cpu() np_score = self.predictor.get_output_handle(output_names[
2]).copy_to_cpu()
np_segms = self.predictor.get_output_handle(output_names[
3]).copy_to_cpu()
output = dict(segm=np_segms, label=np_label, score=np_score) output = dict(segm=np_segms, label=np_label, score=np_score)
if visualization: if visualization:
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
image = D.visualize_box_mask(im=image, results=output) image = D.visualize_box_mask(im=image, results=output, threshold=threshold)
name = str(time.time()) + '.png' name = str(time.time()) + '.png'
save_path = os.path.join(save_dir, name) save_path = os.path.join(save_dir, name)
image.save(save_path) image.save(save_path)
...@@ -163,4 +174,4 @@ class DetectorSOLOv2(Detector): ...@@ -163,4 +174,4 @@ class DetectorSOLOv2(Detector):
final['segm'] = base64.b64encode(results['segm']).decode('utf8') final['segm'] = base64.b64encode(results['segm']).decode('utf8')
final['label'] = base64.b64encode(results['label']).decode('utf8') final['label'] = base64.b64encode(results['label']).decode('utf8')
final['score'] = base64.b64encode(results['score']).decode('utf8') final['score'] = base64.b64encode(results['score']).decode('utf8')
return final return final
\ No newline at end of file
...@@ -78,13 +78,20 @@ class Resize(object): ...@@ -78,13 +78,20 @@ class Resize(object):
im_channel = im.shape[2] im_channel = im.shape[2]
im_scale_x, im_scale_y = self.generate_scale(im) im_scale_x, im_scale_y = self.generate_scale(im)
if self.use_cv2: if self.use_cv2:
im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp) im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
else: else:
resize_w = int(im_scale_x * float(im.shape[1])) resize_w = int(im_scale_x * float(im.shape[1]))
resize_h = int(im_scale_y * float(im.shape[0])) resize_h = int(im_scale_y * float(im.shape[0]))
if self.max_size != 0: if self.max_size != 0:
raise TypeError('If you set max_size to cap the maximum size of image,' raise TypeError(
'please set use_cv2 to True to resize the image.') 'If you set max_size to cap the maximum size of image,'
'please set use_cv2 to True to resize the image.')
im = im.astype('uint8') im = im.astype('uint8')
im = Image.fromarray(im) im = Image.fromarray(im)
im = im.resize((int(resize_w), int(resize_h)), self.interp) im = im.resize((int(resize_w), int(resize_h)), self.interp)
...@@ -92,7 +99,8 @@ class Resize(object): ...@@ -92,7 +99,8 @@ class Resize(object):
# padding im when image_shape fixed by infer_cfg.yml # padding im when image_shape fixed by infer_cfg.yml
if self.max_size != 0 and self.image_shape is not None: if self.max_size != 0 and self.image_shape is not None:
padding_im = np.zeros((self.max_size, self.max_size, im_channel), dtype=np.float32) padding_im = np.zeros(
(self.max_size, self.max_size, im_channel), dtype=np.float32)
im_h, im_w = im.shape[:2] im_h, im_w = im.shape[:2]
padding_im[:im_h, :im_w, :] = im padding_im[:im_h, :im_w, :] = im
im = padding_im im = padding_im
...@@ -232,4 +240,4 @@ def preprocess(im, preprocess_ops): ...@@ -232,4 +240,4 @@ def preprocess(im, preprocess_ops):
for operator in preprocess_ops: for operator in preprocess_ops:
im, im_info = operator(im, im_info) im, im_info = operator(im, im_info)
im = np.array((im, )).astype('float32') im = np.array((im, )).astype('float32')
return im, im_info return im, im_info
\ No newline at end of file
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
    """Smoke tests for the ``solov2`` PaddleHub module.

    A sample image is downloaded once for the whole class, the module is
    loaded once, and each test checks that ``predict`` returns a mapping whose
    ``segm``, ``label`` and ``score`` entries are numpy arrays.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Download the test image into ``tests/`` and load the module."""
        img_url = 'https://ai-studio-static-online.cdn.bcebos.com/7799a8ccc5f6471b9d56fb6eff94f82a08b70ca2c7594d3f99877e366c0a2619'
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs('tests', exist_ok=True)
        # Without a timeout an unresponsive server would hang the whole suite.
        response = requests.get(img_url, timeout=60)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if response.status_code != 200:
            raise RuntimeError('Network Error.')
        with open('tests/test.jpg', 'wb') as f:
            f.write(response.content)
        cls.module = hub.Module(name="solov2")

    @classmethod
    def tearDownClass(cls) -> None:
        """Remove every directory the tests may have created.

        ``inference`` is written by ``test_save_inference_model``;
        ``solov2_result`` is presumably the module's default visualization
        output directory (TODO confirm against the module's ``save_dir``
        default). Either may be absent when a test failed or was deselected,
        so missing directories are ignored instead of raising and hiding the
        real test outcome.
        """
        for path in ('tests', 'inference', 'solov2_result'):
            shutil.rmtree(path, ignore_errors=True)

    def _assert_prediction_format(self, results) -> None:
        """Check that ``results`` holds numpy arrays under the expected keys."""
        for key in ('segm', 'label', 'score'):
            self.assertIsInstance(results[key], np.ndarray)

    def test_predict1(self):
        """Predict from an image path, without visualization."""
        results = self.module.predict(
            image='tests/test.jpg',
            visualization=False
        )
        self._assert_prediction_format(results)

    def test_predict2(self):
        """Predict from an array loaded by cv2, without visualization."""
        results = self.module.predict(
            image=cv2.imread('tests/test.jpg'),
            visualization=False
        )
        self._assert_prediction_format(results)

    def test_predict3(self):
        """Predict from an array loaded by cv2, saving the visualization."""
        results = self.module.predict(
            image=cv2.imread('tests/test.jpg'),
            visualization=True
        )
        self._assert_prediction_format(results)

    def test_predict4(self):
        """Predict with a module instance created with ``use_gpu=True``."""
        module = hub.Module(name="solov2", use_gpu=True)
        results = module.predict(
            image=cv2.imread('tests/test.jpg'),
            visualization=True
        )
        self._assert_prediction_format(results)

    def test_predict5(self):
        """A path that does not exist must raise ``FileNotFoundError``."""
        self.assertRaises(
            FileNotFoundError,
            self.module.predict,
            image='no.jpg'
        )

    def test_save_inference_model(self):
        """Exporting the model must produce both inference files."""
        self.module.save_inference_model('./inference/model')

        self.assertTrue(os.path.exists('./inference/model.pdmodel'))
        self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册