Unverified commit c68b6d5d, authored by jm_12138, committed by GitHub

add gradio app in solov2 (#2162)

* add gradio app (test)

* update

* update solov2
Parent e3ee127f
@@ -20,7 +20,7 @@
- ### Model Introduction

- solov2 is a fast instance segmentation model based on "SOLOv2: Dynamic, Faster and Stronger". Building on SOLOv1, it improves mask quality and runtime efficiency and performs well on instance segmentation tasks. Compared with semantic segmentation, instance segmentation must also distinguish the individual instances of the same object class in an image.
## II. Installation
@@ -28,7 +28,7 @@
- paddlepaddle >= 2.0.0
- paddlehub >= 2.0.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2. Installation

- ```shell
@@ -45,7 +45,7 @@
- ```shell
  $ hub run solov2 --input_path "/PATH/TO/IMAGE"
  ```
- This invokes the hub module from the command line; for more details, see [PaddleHub Command Line Usage](../../../../docs/docs_ch/tutorial/cmd_usage.rst)

- ### 2. Prediction Code Example
@@ -53,12 +53,12 @@
- ```python
  import cv2
  import paddlehub as hub

  img = cv2.imread('/PATH/TO/IMAGE')
  model = hub.Module(name='solov2', use_gpu=False)
  output = model.predict(image=img, visualization=False)
  ```
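  The keys of the returned `output` dict follow `DetectorSOLOv2.predict` in the module code below (`segm`, `label`, `score`); a minimal follow-up sketch:

  ```python
  # keys grounded in DetectorSOLOv2.predict: dict(segm=..., label=..., score=...)
  masks = output['segm']    # per-instance segmentation masks
  labels = output['label']  # per-instance class ids
  scores = output['score']  # per-instance confidence scores
  print(masks.shape, labels, scores)
  ```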
- ### 3. API

- ```python
@@ -67,7 +67,7 @@
              visualization: bool = False,
              save_dir: str = 'solov2_result'):
  ```
- Prediction API for instance segmentation.

- **Parameters**

  - image (Union[str, np.ndarray]): image path or image data; ndarray.shape is [H, W, C], in BGR format;
@@ -105,42 +105,45 @@
  import json
  import cv2
  import base64
  import numpy as np

  def cv2_to_base64(image):
      data = cv2.imencode('.jpg', image)[1]
      return base64.b64encode(data.tostring()).decode('utf8')

  def base64_to_cv2(b64str):
      data = base64.b64decode(b64str.encode('utf8'))
      data = np.fromstring(data, np.uint8)
      data = cv2.imdecode(data, cv2.IMREAD_COLOR)
      return data
  # Send an HTTP request
  org_im = cv2.imread('/PATH/TO/IMAGE')
  h, w, c = org_im.shape
  data = {'images':[cv2_to_base64(org_im)]}
  headers = {"Content-type": "application/json"}
  url = "http://127.0.0.1:8866/predict/solov2"
  r = requests.post(url=url, headers=headers, data=json.dumps(data))

  seg = base64.b64decode(r.json()["results"]['segm'].encode('utf8'))
  seg = np.fromstring(seg, dtype=np.int32).reshape((-1, h, w))
  label = base64.b64decode(r.json()["results"]['label'].encode('utf8'))
  label = np.fromstring(label, dtype=np.int64)
  score = base64.b64decode(r.json()["results"]['score'].encode('utf8'))
  score = np.fromstring(score, dtype=np.float32)
  print('seg', seg)
  print('label', label)
  print('score', score)
  ```
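  A hedged follow-up sketch for the decoded arrays above, assuming at least one instance clears the (arbitrary) 0.5 score cut-off:

  ```python
  # overlay the first instance mask whose score exceeds 0.5 onto the original image
  keep = score > 0.5                  # boolean selector over instances
  first_mask = seg[keep][0] > 0       # (h, w) binary mask of the first kept instance
  vis = org_im.copy()
  vis[first_mask] = (0, 255, 0)       # paint the mask green (BGR)
  cv2.imwrite('solov2_overlay.jpg', vis)
  ```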
- ### Gradio App Support

  Starting with PaddleHub 2.3.1, the Gradio App for solov2 can be opened in a browser at http://127.0.0.1:8866/gradio/solov2.
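  A minimal way to reach it (sketch; assumes the module is installed and PaddleHub Serving is running on the default port 8866):

  ```shell
  $ hub install solov2==1.2.0
  $ hub serving start -m solov2
  # then open http://127.0.0.1:8866/gradio/solov2 in a browser
  ```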
## V. Release Note

* 1.0.0

@@ -151,6 +154,10 @@

  Adapted to PaddlePaddle 2.2.0+
* 1.2.0

  Add Gradio App support

- ```shell
  $ hub install solov2==1.2.0
  ```
import os
import base64

import cv2
import numpy as np
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.inference import PrecisionType
from PIL import Image
from PIL import ImageDraw
def create_inputs(im, im_info):

@@ -19,8 +21,7 @@ def create_inputs(im, im_info):
    inputs['image'] = im
    origin_shape = list(im_info['origin_shape'])
    resize_shape = list(im_info['resize_shape'])
    pad_shape = list(im_info['pad_shape']) if im_info['pad_shape'] is not None else list(im_info['resize_shape'])
    scale_x, scale_y = im_info['scale']
    scale = scale_x
    im_info = np.array([resize_shape + [scale]]).astype('float32')
@@ -45,38 +46,28 @@ def visualize_box_mask(im, results, labels=None, mask_resolution=14, threshold=0
        im (PIL.Image.Image): visualized image
    """
    if not labels:
        # COCO class names (80 classes), preceded by 'background'
        labels = [
            'background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
            'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
            'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
            'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
            'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
            'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
            'teddy bear', 'hair drier', 'toothbrush'
        ]
    if isinstance(im, str):
        im = Image.open(im).convert('RGB')
    else:
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        im = Image.fromarray(im)
    if 'masks' in results and 'boxes' in results:
        im = draw_mask(im, results['boxes'], results['masks'], labels, resolution=mask_resolution)
    if 'boxes' in results:
        im = draw_box(im, results['boxes'], labels)
    if 'segm' in results:
        im = draw_segm(im, results['segm'], results['label'], results['score'], labels, threshold=threshold)
    return im
@@ -165,8 +156,7 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
        y0 = min(max(ymin, 0), im_h)
        y1 = min(max(ymax + 1, 0), im_h)
        im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
        im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (x0 - xmin):(x1 - xmin)]
        if clsid not in clsid2color:
            clsid2color[clsid] = color_list[clsid]
        color_mask = clsid2color[clsid]
@@ -204,28 +194,19 @@ def draw_box(im, np_boxes, labels):
        color = tuple(clsid2color[clsid])

        # draw bbox
        draw.line([(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
                  width=draw_thickness,
                  fill=color)

        # draw label
        text = "{} {:.4f}".format(labels[clsid], score)
        tw, th = draw.textsize(text)
        draw.rectangle([(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
        draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
    return im
def draw_segm(im, np_segms, np_label, np_score, labels, threshold=0.5, alpha=0.7):
    """
    Draw segmentation on image.
    """
@@ -254,28 +235,17 @@ def draw_segm(im,
        sum_y = np.sum(mask, axis=1)
        y = np.where(sum_y > 0.5)[0]
        x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
        cv2.rectangle(im, (x0, y0), (x1, y1), tuple(color_mask.astype('int32').tolist()), 1)
        bbox_text = '%s %.2f' % (labels[clsid], score)
        t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
        cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3), tuple(color_mask.astype('int32').tolist()),
                      -1)
        cv2.putText(im, bbox_text, (x0, y0 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), 1, lineType=cv2.LINE_AA)
    return Image.fromarray(im.astype('uint8'))
def load_predictor(model_dir, run_mode='paddle', batch_size=1, use_gpu=False, min_subgraph_size=3):
    """set AnalysisConfig, generate AnalysisPredictor
    Args:
        model_dir (str): root path of __model__ and __params__
@@ -286,18 +256,13 @@ def load_predictor(model_dir,
        ValueError: predict by TensorRT need use_gpu == True.
    """
    if not use_gpu and not run_mode == 'paddle':
        raise ValueError("Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}".format(
            run_mode, use_gpu))
    if run_mode == 'trt_int8':
        raise ValueError("TensorRT int8 mode is not supported now, "
                         "please use trt_fp32 or trt_fp16 instead.")
    precision_map = {'trt_int8': PrecisionType.Int8, 'trt_fp32': PrecisionType.Float32, 'trt_fp16': PrecisionType.Half}
    config = Config(model_dir + '.pdmodel', model_dir + '.pdiparams')
    if use_gpu:
        # initial GPU memory(M), device ID
        config.enable_use_gpu(100, 0)

@@ -307,13 +272,12 @@
        config.disable_gpu()
    if run_mode in precision_map.keys():
        config.enable_tensorrt_engine(workspace_size=1 << 10,
                                      max_batch_size=batch_size,
                                      min_subgraph_size=min_subgraph_size,
                                      precision_mode=precision_map[run_mode],
                                      use_static=False,
                                      use_calib_mode=False)
    # disable print log when predict
    config.disable_glog_info()
......
@@ -11,18 +11,18 @@

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import base64
from functools import reduce
from typing import Union

import numpy as np

import solov2.data_feed as D
import solov2.processor as P
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import serving
class Detector:

@@ -33,19 +33,18 @@ class Detector:
        threshold (float): threshold to reserve the result for output.
    """
    def __init__(self, min_subgraph_size: int = 60, use_gpu=False):
        self.default_pretrained_model_path = os.path.join(self.directory, 'solov2_r50_fpn_1x', 'model')
        self.predictor = D.load_predictor(self.default_pretrained_model_path,
                                          min_subgraph_size=min_subgraph_size,
                                          use_gpu=use_gpu)
        self.compose = [
            P.Resize(max_size=1333),
            P.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            P.Permute(),
            P.PadStride(stride=32)
        ]
    def transform(self, im: Union[str, np.ndarray]):
        im, im_info = P.preprocess(im, self.compose)

@@ -60,17 +59,14 @@ class Detector:
        for box in np_boxes:
            print('class_id:{:d}, confidence:{:.4f},'
                  'left_top:[{:.2f},{:.2f}],'
                  ' right_bottom:[{:.2f},{:.2f}]'.format(int(box[0]), box[1], box[2], box[3], box[4], box[5]))
        results['boxes'] = np_boxes
        if np_masks is not None:
            np_masks = np_masks[expect_boxes, :, :, :]
            results['masks'] = np_masks
        return results
    def predict(self, image: Union[str, np.ndarray], threshold: float = 0.5):
        '''
        Args:
            image (str/np.ndarray): path of image/ np.ndarray read by cv2

@@ -103,24 +99,21 @@ class Detector:
        return results
@moduleinfo(name="solov2",
            type="CV/instance_segmentation",
            author="paddlepaddle",
            author_email="",
            summary="solov2 is a detection model, this module is trained with COCO dataset.",
            version="1.2.0")
class DetectorSOLOv2(Detector):
""" """
Args: Args:
use_gpu (bool): whether use gpu use_gpu (bool): whether use gpu
threshold (float): threshold to reserve the result for output. threshold (float): threshold to reserve the result for output.
""" """
    def __init__(self, use_gpu: bool = False):
        super(DetectorSOLOv2, self).__init__(use_gpu=use_gpu)
    def predict(self,
                image: Union[str, np.ndarray],

@@ -133,7 +126,7 @@ class DetectorSOLOv2(Detector):
            threshold (float): threshold of predicted box' score
            visualization (bool): Whether to save visualization result.
            save_dir (str): save path.
        '''
        inputs, im_info = self.transform(image)
@@ -146,14 +139,11 @@ class DetectorSOLOv2(Detector):

        self.predictor.run()
        output_names = self.predictor.get_output_names()
        np_label = self.predictor.get_output_handle(output_names[1]).copy_to_cpu()
        np_score = self.predictor.get_output_handle(output_names[2]).copy_to_cpu()
        np_segms = self.predictor.get_output_handle(output_names[3]).copy_to_cpu()
        output = dict(segm=np_segms, label=np_label, score=np_score)

        if visualization:
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
@@ -174,4 +164,26 @@ class DetectorSOLOv2(Detector):

        final['segm'] = base64.b64encode(results['segm']).decode('utf8')
        final['label'] = base64.b64encode(results['label']).decode('utf8')
        final['score'] = base64.b64encode(results['score']).decode('utf8')
        return final
    def create_gradio_app(self):
        import os
        import tempfile

        import gradio as gr
        from PIL import Image

        def inference(img, threshold):
            with tempfile.TemporaryDirectory() as tempdir_name:
                self.predict(image=img, threshold=threshold, visualization=True, save_dir=tempdir_name)
                result_names = os.listdir(tempdir_name)
                return Image.open(os.path.join(tempdir_name, result_names[0]))

        interface = gr.Interface(inference,
                                 inputs=[gr.inputs.Image(type="filepath"),
                                         gr.Slider(0.0, 1.0, value=0.5)],
                                 outputs=gr.Image(label='segmentation'),
                                 title='SOLOv2',
                                 allow_flagging='never')
        return interface
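A hedged usage sketch for the method above: launching the returned interface directly with Gradio (assumes gradio is installed and the module has been installed via `hub install solov2`; the port is an arbitrary choice).

```python
import paddlehub as hub

module = hub.Module(name='solov2', use_gpu=False)
app = module.create_gradio_app()                        # the gr.Interface built above
app.launch(server_name='127.0.0.1', server_port=7860)   # arbitrary local port
```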
@@ -11,10 +11,9 @@

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
from PIL import Image
def decode_image(im_file, im_info):

@@ -78,20 +77,13 @@ class Resize(object):
        im_channel = im.shape[2]
        im_scale_x, im_scale_y = self.generate_scale(im)
        if self.use_cv2:
            im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=self.interp)
        else:
            resize_w = int(im_scale_x * float(im.shape[1]))
            resize_h = int(im_scale_y * float(im.shape[0]))
            if self.max_size != 0:
                raise TypeError('If you set max_size to cap the maximum size of image,'
                                'please set use_cv2 to True to resize the image.')
            im = im.astype('uint8')
            im = Image.fromarray(im)
            im = im.resize((int(resize_w), int(resize_h)), self.interp)
@@ -99,8 +91,7 @@ class Resize(object):

        # padding im when image_shape fixed by infer_cfg.yml
        if self.max_size != 0 and self.image_shape is not None:
            padding_im = np.zeros((self.max_size, self.max_size, im_channel), dtype=np.float32)
            im_h, im_w = im.shape[:2]
            padding_im[:im_h, :im_w, :] = im
            im = padding_im
@@ -240,4 +231,4 @@ def preprocess(im, preprocess_ops):

    for operator in preprocess_ops:
        im, im_info = operator(im, im_info)
    im = np.array((im, )).astype('float32')
    return im, im_info
@@ -3,15 +3,16 @@ import shutil

import unittest

import cv2
import numpy as np
import requests

import paddlehub as hub

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        img_url = 'https://ai-studio-static-online.cdn.bcebos.com/7799a8ccc5f6471b9d56fb6eff94f82a08b70ca2c7594d3f99877e366c0a2619'

@@ -30,10 +31,7 @@ class TestHubModule(unittest.TestCase):
        shutil.rmtree('solov2_result')

    def test_predict1(self):
        results = self.module.predict(image='tests/test.jpg', visualization=False)
        segm = results['segm']
        label = results['label']
        score = results['score']

@@ -42,10 +40,7 @@ class TestHubModule(unittest.TestCase):
        self.assertIsInstance(score, np.ndarray)
    def test_predict2(self):
        results = self.module.predict(image=cv2.imread('tests/test.jpg'), visualization=False)
        segm = results['segm']
        label = results['label']
        score = results['score']

@@ -54,10 +49,7 @@ class TestHubModule(unittest.TestCase):
        self.assertIsInstance(score, np.ndarray)
    def test_predict3(self):
        results = self.module.predict(image=cv2.imread('tests/test.jpg'), visualization=True)
        segm = results['segm']
        label = results['label']
        score = results['score']

@@ -67,10 +59,7 @@ class TestHubModule(unittest.TestCase):
    def test_predict4(self):
        module = hub.Module(name="solov2", use_gpu=True)
        results = module.predict(image=cv2.imread('tests/test.jpg'), visualization=True)
        segm = results['segm']
        label = results['label']
        score = results['score']

@@ -79,11 +68,7 @@ class TestHubModule(unittest.TestCase):
        self.assertIsInstance(score, np.ndarray)
    def test_predict5(self):
        self.assertRaises(FileNotFoundError, self.module.predict, image='no.jpg')

    def test_save_inference_model(self):
        self.module.save_inference_model('./inference/model')
......