Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
c68b6d5d
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 2 年 前同步成功
通知
285
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c68b6d5d
编写于
12月 29, 2022
作者:
jm_12138
提交者:
GitHub
12月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add gradio app in solov2 (#2162)
* add gradio app (test) * update * update solov2
上级
e3ee127f
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
133 addition
and
174 deletion
+133
-174
modules/image/instance_segmentation/solov2/README.md
modules/image/instance_segmentation/solov2/README.md
+25
-18
modules/image/instance_segmentation/solov2/data_feed.py
modules/image/instance_segmentation/solov2/data_feed.py
+41
-77
modules/image/instance_segmentation/solov2/module.py
modules/image/instance_segmentation/solov2/module.py
+53
-41
modules/image/instance_segmentation/solov2/processor.py
modules/image/instance_segmentation/solov2/processor.py
+6
-15
modules/image/instance_segmentation/solov2/test.py
modules/image/instance_segmentation/solov2/test.py
+8
-23
未找到文件。
modules/image/instance_segmentation/solov2/README.md
浏览文件 @
c68b6d5d
...
...
@@ -141,6 +141,9 @@
print('score', score)
```
-
### Gradio App 支持
从 PaddleHub 2.3.1 开始支持使用链接 http://127.0.0.1:8866/gradio/solov2 在浏览器中访问 solov2 的 Gradio App。
## 五、更新历史
*
1.0.0
...
...
@@ -151,6 +154,10 @@
适配 PaddlePaddle 2.2.0+
*
```shell
$ hub install hand_pose_localization==1.1.0
*
1.2.0
添加 Gradio APP 支持
-
```shell
$ hub install solov2==1.2.0
```
modules/image/instance_segmentation/solov2/data_feed.py
浏览文件 @
c68b6d5d
import
os
import
base64
import
cv2
import
numpy
as
np
from
paddle.inference
import
Config
,
create_predictor
,
PrecisionType
from
PIL
import
Image
,
ImageDraw
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
paddle.inference
import
PrecisionType
from
PIL
import
Image
from
PIL
import
ImageDraw
def
create_inputs
(
im
,
im_info
):
...
...
@@ -19,8 +21,7 @@ def create_inputs(im, im_info):
inputs
[
'image'
]
=
im
origin_shape
=
list
(
im_info
[
'origin_shape'
])
resize_shape
=
list
(
im_info
[
'resize_shape'
])
pad_shape
=
list
(
im_info
[
'pad_shape'
])
if
im_info
[
'pad_shape'
]
is
not
None
else
list
(
im_info
[
'resize_shape'
])
pad_shape
=
list
(
im_info
[
'pad_shape'
])
if
im_info
[
'pad_shape'
]
is
not
None
else
list
(
im_info
[
'resize_shape'
])
scale_x
,
scale_y
=
im_info
[
'scale'
]
scale
=
scale_x
im_info
=
np
.
array
([
resize_shape
+
[
scale
]]).
astype
(
'float32'
)
...
...
@@ -45,38 +46,28 @@ def visualize_box_mask(im, results, labels=None, mask_resolution=14, threshold=0
im (PIL.Image.Image): visualized image
"""
if
not
labels
:
labels
=
[
'background'
,
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic light'
,
'fire'
,
'hydrant'
,
'stop sign'
,
'parking meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports ball'
,
'kite'
,
'baseball bat'
,
'baseball glove'
,
'skateboard'
,
'surfboard'
,
'tennis racket'
,
'bottle'
,
'wine glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted plant'
,
'bed'
,
'dining table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy bear'
,
'hair drier'
,
'toothbrush'
]
labels
=
[
'background'
,
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic light'
,
'fire'
,
'hydrant'
,
'stop sign'
,
'parking meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports ball'
,
'kite'
,
'baseball bat'
,
'baseball glove'
,
'skateboard'
,
'surfboard'
,
'tennis racket'
,
'bottle'
,
'wine glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted plant'
,
'bed'
,
'dining table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy bear'
,
'hair drier'
,
'toothbrush'
]
if
isinstance
(
im
,
str
):
im
=
Image
.
open
(
im
).
convert
(
'RGB'
)
else
:
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
im
=
Image
.
fromarray
(
im
)
if
'masks'
in
results
and
'boxes'
in
results
:
im
=
draw_mask
(
im
,
results
[
'boxes'
],
results
[
'masks'
],
labels
,
resolution
=
mask_resolution
)
im
=
draw_mask
(
im
,
results
[
'boxes'
],
results
[
'masks'
],
labels
,
resolution
=
mask_resolution
)
if
'boxes'
in
results
:
im
=
draw_box
(
im
,
results
[
'boxes'
],
labels
)
if
'segm'
in
results
:
im
=
draw_segm
(
im
,
results
[
'segm'
],
results
[
'label'
],
results
[
'score'
],
labels
,
threshold
=
threshold
)
im
=
draw_segm
(
im
,
results
[
'segm'
],
results
[
'label'
],
results
[
'score'
],
labels
,
threshold
=
threshold
)
return
im
...
...
@@ -165,8 +156,7 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
y0
=
min
(
max
(
ymin
,
0
),
im_h
)
y1
=
min
(
max
(
ymax
+
1
,
0
),
im_h
)
im_mask
=
np
.
zeros
((
im_h
,
im_w
),
dtype
=
np
.
uint8
)
im_mask
[
y0
:
y1
,
x0
:
x1
]
=
resized_mask
[(
y0
-
ymin
):(
y1
-
ymin
),
(
x0
-
xmin
):(
x1
-
xmin
)]
im_mask
[
y0
:
y1
,
x0
:
x1
]
=
resized_mask
[(
y0
-
ymin
):(
y1
-
ymin
),
(
x0
-
xmin
):(
x1
-
xmin
)]
if
clsid
not
in
clsid2color
:
clsid2color
[
clsid
]
=
color_list
[
clsid
]
color_mask
=
clsid2color
[
clsid
]
...
...
@@ -204,28 +194,19 @@ def draw_box(im, np_boxes, labels):
color
=
tuple
(
clsid2color
[
clsid
])
# draw bbox
draw
.
line
(
[(
xmin
,
ymin
),
(
xmin
,
ymax
),
(
xmax
,
ymax
),
(
xmax
,
ymin
),
(
xmin
,
ymin
)],
draw
.
line
([(
xmin
,
ymin
),
(
xmin
,
ymax
),
(
xmax
,
ymax
),
(
xmax
,
ymin
),
(
xmin
,
ymin
)],
width
=
draw_thickness
,
fill
=
color
)
# draw label
text
=
"{} {:.4f}"
.
format
(
labels
[
clsid
],
score
)
tw
,
th
=
draw
.
textsize
(
text
)
draw
.
rectangle
(
[(
xmin
+
1
,
ymin
-
th
),
(
xmin
+
tw
+
1
,
ymin
)],
fill
=
color
)
draw
.
rectangle
([(
xmin
+
1
,
ymin
-
th
),
(
xmin
+
tw
+
1
,
ymin
)],
fill
=
color
)
draw
.
text
((
xmin
+
1
,
ymin
-
th
),
text
,
fill
=
(
255
,
255
,
255
))
return
im
def
draw_segm
(
im
,
np_segms
,
np_label
,
np_score
,
labels
,
threshold
=
0.5
,
alpha
=
0.7
):
def
draw_segm
(
im
,
np_segms
,
np_label
,
np_score
,
labels
,
threshold
=
0.5
,
alpha
=
0.7
):
"""
Draw segmentation on image.
"""
...
...
@@ -254,28 +235,17 @@ def draw_segm(im,
sum_y
=
np
.
sum
(
mask
,
axis
=
1
)
y
=
np
.
where
(
sum_y
>
0.5
)[
0
]
x0
,
x1
,
y0
,
y1
=
x
[
0
],
x
[
-
1
],
y
[
0
],
y
[
-
1
]
cv2
.
rectangle
(
im
,
(
x0
,
y0
),
(
x1
,
y1
),
tuple
(
color_mask
.
astype
(
'int32'
).
tolist
()),
1
)
cv2
.
rectangle
(
im
,
(
x0
,
y0
),
(
x1
,
y1
),
tuple
(
color_mask
.
astype
(
'int32'
).
tolist
()),
1
)
bbox_text
=
'%s %.2f'
%
(
labels
[
clsid
],
score
)
t_size
=
cv2
.
getTextSize
(
bbox_text
,
0
,
0.3
,
thickness
=
1
)[
0
]
cv2
.
rectangle
(
im
,
(
x0
,
y0
),
(
x0
+
t_size
[
0
],
y0
-
t_size
[
1
]
-
3
),
tuple
(
color_mask
.
astype
(
'int32'
).
tolist
()),
-
1
)
cv2
.
putText
(
im
,
bbox_text
,
(
x0
,
y0
-
2
),
cv2
.
FONT_HERSHEY_SIMPLEX
,
0.3
,
(
0
,
0
,
0
),
1
,
lineType
=
cv2
.
LINE_AA
)
cv2
.
rectangle
(
im
,
(
x0
,
y0
),
(
x0
+
t_size
[
0
],
y0
-
t_size
[
1
]
-
3
),
tuple
(
color_mask
.
astype
(
'int32'
).
tolist
()),
-
1
)
cv2
.
putText
(
im
,
bbox_text
,
(
x0
,
y0
-
2
),
cv2
.
FONT_HERSHEY_SIMPLEX
,
0.3
,
(
0
,
0
,
0
),
1
,
lineType
=
cv2
.
LINE_AA
)
return
Image
.
fromarray
(
im
.
astype
(
'uint8'
))
def
load_predictor
(
model_dir
,
run_mode
=
'paddle'
,
batch_size
=
1
,
use_gpu
=
False
,
min_subgraph_size
=
3
):
def
load_predictor
(
model_dir
,
run_mode
=
'paddle'
,
batch_size
=
1
,
use_gpu
=
False
,
min_subgraph_size
=
3
):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
...
...
@@ -286,18 +256,13 @@ def load_predictor(model_dir,
ValueError: predict by TensorRT need use_gpu == True.
"""
if
not
use_gpu
and
not
run_mode
==
'paddle'
:
raise
ValueError
(
"Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
.
format
(
run_mode
,
use_gpu
))
raise
ValueError
(
"Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
.
format
(
run_mode
,
use_gpu
))
if
run_mode
==
'trt_int8'
:
raise
ValueError
(
"TensorRT int8 mode is not supported now, "
"please use trt_fp32 or trt_fp16 instead."
)
precision_map
=
{
'trt_int8'
:
PrecisionType
.
Int8
,
'trt_fp32'
:
PrecisionType
.
Float32
,
'trt_fp16'
:
PrecisionType
.
Half
}
config
=
Config
(
model_dir
+
'.pdmodel'
,
model_dir
+
'.pdiparams'
)
precision_map
=
{
'trt_int8'
:
PrecisionType
.
Int8
,
'trt_fp32'
:
PrecisionType
.
Float32
,
'trt_fp16'
:
PrecisionType
.
Half
}
config
=
Config
(
model_dir
+
'.pdmodel'
,
model_dir
+
'.pdiparams'
)
if
use_gpu
:
# initial GPU memory(M), device ID
config
.
enable_use_gpu
(
100
,
0
)
...
...
@@ -307,8 +272,7 @@ def load_predictor(model_dir,
config
.
disable_gpu
()
if
run_mode
in
precision_map
.
keys
():
config
.
enable_tensorrt_engine
(
workspace_size
=
1
<<
10
,
config
.
enable_tensorrt_engine
(
workspace_size
=
1
<<
10
,
max_batch_size
=
batch_size
,
min_subgraph_size
=
min_subgraph_size
,
precision_mode
=
precision_map
[
run_mode
],
...
...
modules/image/instance_segmentation/solov2/module.py
浏览文件 @
c68b6d5d
...
...
@@ -11,18 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
base64
import
os
import
time
import
base64
from
functools
import
reduce
from
typing
import
Union
import
numpy
as
np
from
paddlehub.module.module
import
moduleinfo
,
serving
import
solov2.processor
as
P
import
solov2.data_feed
as
D
import
solov2.processor
as
P
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.module.module
import
serving
class
Detector
:
...
...
@@ -33,19 +33,18 @@ class Detector:
threshold (float): threshold to reserve the result for output.
"""
def
__init__
(
self
,
min_subgraph_size
:
int
=
60
,
use_gpu
=
False
):
def
__init__
(
self
,
min_subgraph_size
:
int
=
60
,
use_gpu
=
False
):
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
'solov2_r50_fpn_1x'
,
'model'
)
self
.
predictor
=
D
.
load_predictor
(
self
.
default_pretrained_model_path
,
self
.
predictor
=
D
.
load_predictor
(
self
.
default_pretrained_model_path
,
min_subgraph_size
=
min_subgraph_size
,
use_gpu
=
use_gpu
)
self
.
compose
=
[
P
.
Resize
(
max_size
=
1333
),
self
.
compose
=
[
P
.
Resize
(
max_size
=
1333
),
P
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]),
P
.
Permute
(),
P
.
PadStride
(
stride
=
32
)]
P
.
PadStride
(
stride
=
32
)
]
def
transform
(
self
,
im
:
Union
[
str
,
np
.
ndarray
]):
im
,
im_info
=
P
.
preprocess
(
im
,
self
.
compose
)
...
...
@@ -60,17 +59,14 @@ class Detector:
for
box
in
np_boxes
:
print
(
'class_id:{:d}, confidence:{:.4f},'
'left_top:[{:.2f},{:.2f}],'
' right_bottom:[{:.2f},{:.2f}]'
.
format
(
int
(
box
[
0
]),
box
[
1
],
box
[
2
],
box
[
3
],
box
[
4
],
box
[
5
]))
' right_bottom:[{:.2f},{:.2f}]'
.
format
(
int
(
box
[
0
]),
box
[
1
],
box
[
2
],
box
[
3
],
box
[
4
],
box
[
5
]))
results
[
'boxes'
]
=
np_boxes
if
np_masks
is
not
None
:
np_masks
=
np_masks
[
expect_boxes
,
:,
:,
:]
results
[
'masks'
]
=
np_masks
return
results
def
predict
(
self
,
image
:
Union
[
str
,
np
.
ndarray
],
threshold
:
float
=
0.5
):
def
predict
(
self
,
image
:
Union
[
str
,
np
.
ndarray
],
threshold
:
float
=
0.5
):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
...
...
@@ -103,24 +99,21 @@ class Detector:
return
results
@
moduleinfo
(
name
=
"solov2"
,
@
moduleinfo
(
name
=
"solov2"
,
type
=
"CV/instance_segmentation"
,
author
=
"paddlepaddle"
,
author_email
=
""
,
summary
=
"solov2 is a detection model, this module is trained with COCO dataset."
,
version
=
"1.1
.0"
)
version
=
"1.2
.0"
)
class
DetectorSOLOv2
(
Detector
):
"""
Args:
use_gpu (bool): whether use gpu
threshold (float): threshold to reserve the result for output.
"""
def
__init__
(
self
,
use_gpu
:
bool
=
False
):
super
(
DetectorSOLOv2
,
self
).
__init__
(
use_gpu
=
use_gpu
)
def
__init__
(
self
,
use_gpu
:
bool
=
False
):
super
(
DetectorSOLOv2
,
self
).
__init__
(
use_gpu
=
use_gpu
)
def
predict
(
self
,
image
:
Union
[
str
,
np
.
ndarray
],
...
...
@@ -146,12 +139,9 @@ class DetectorSOLOv2(Detector):
self
.
predictor
.
run
()
output_names
=
self
.
predictor
.
get_output_names
()
np_label
=
self
.
predictor
.
get_output_handle
(
output_names
[
1
]).
copy_to_cpu
()
np_score
=
self
.
predictor
.
get_output_handle
(
output_names
[
2
]).
copy_to_cpu
()
np_segms
=
self
.
predictor
.
get_output_handle
(
output_names
[
3
]).
copy_to_cpu
()
np_label
=
self
.
predictor
.
get_output_handle
(
output_names
[
1
]).
copy_to_cpu
()
np_score
=
self
.
predictor
.
get_output_handle
(
output_names
[
2
]).
copy_to_cpu
()
np_segms
=
self
.
predictor
.
get_output_handle
(
output_names
[
3
]).
copy_to_cpu
()
output
=
dict
(
segm
=
np_segms
,
label
=
np_label
,
score
=
np_score
)
if
visualization
:
...
...
@@ -175,3 +165,25 @@ class DetectorSOLOv2(Detector):
final
[
'label'
]
=
base64
.
b64encode
(
results
[
'label'
]).
decode
(
'utf8'
)
final
[
'score'
]
=
base64
.
b64encode
(
results
[
'score'
]).
decode
(
'utf8'
)
return
final
def
create_gradio_app
(
self
):
import
os
import
tempfile
import
gradio
as
gr
from
PIL
import
Image
def
inference
(
img
,
threshold
):
with
tempfile
.
TemporaryDirectory
()
as
tempdir_name
:
self
.
predict
(
image
=
img
,
threshold
=
threshold
,
visualization
=
True
,
save_dir
=
tempdir_name
)
result_names
=
os
.
listdir
(
tempdir_name
)
return
Image
.
open
(
os
.
path
.
join
(
tempdir_name
,
result_names
[
0
]))
interface
=
gr
.
Interface
(
inference
,
inputs
=
[
gr
.
inputs
.
Image
(
type
=
"filepath"
),
gr
.
Slider
(
0.0
,
1.0
,
value
=
0.5
)],
outputs
=
gr
.
Image
(
label
=
'segmentation'
),
title
=
'SOLOv2'
,
allow_flagging
=
'never'
)
return
interface
modules/image/instance_segmentation/solov2/processor.py
浏览文件 @
c68b6d5d
...
...
@@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
PIL
import
Image
import
cv2
import
numpy
as
np
from
PIL
import
Image
def
decode_image
(
im_file
,
im_info
):
...
...
@@ -78,19 +77,12 @@ class Resize(object):
im_channel
=
im
.
shape
[
2
]
im_scale_x
,
im_scale_y
=
self
.
generate_scale
(
im
)
if
self
.
use_cv2
:
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
else
:
resize_w
=
int
(
im_scale_x
*
float
(
im
.
shape
[
1
]))
resize_h
=
int
(
im_scale_y
*
float
(
im
.
shape
[
0
]))
if
self
.
max_size
!=
0
:
raise
TypeError
(
'If you set max_size to cap the maximum size of image,'
raise
TypeError
(
'If you set max_size to cap the maximum size of image,'
'please set use_cv2 to True to resize the image.'
)
im
=
im
.
astype
(
'uint8'
)
im
=
Image
.
fromarray
(
im
)
...
...
@@ -99,8 +91,7 @@ class Resize(object):
# padding im when image_shape fixed by infer_cfg.yml
if
self
.
max_size
!=
0
and
self
.
image_shape
is
not
None
:
padding_im
=
np
.
zeros
(
(
self
.
max_size
,
self
.
max_size
,
im_channel
),
dtype
=
np
.
float32
)
padding_im
=
np
.
zeros
((
self
.
max_size
,
self
.
max_size
,
im_channel
),
dtype
=
np
.
float32
)
im_h
,
im_w
=
im
.
shape
[:
2
]
padding_im
[:
im_h
,
:
im_w
,
:]
=
im
im
=
padding_im
...
...
modules/image/instance_segmentation/solov2/test.py
浏览文件 @
c68b6d5d
...
...
@@ -3,15 +3,16 @@ import shutil
import
unittest
import
cv2
import
requests
import
numpy
as
np
import
paddlehub
as
hub
import
requests
import
paddlehub
as
hub
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
class
TestHubModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
)
->
None
:
img_url
=
'https://ai-studio-static-online.cdn.bcebos.com/7799a8ccc5f6471b9d56fb6eff94f82a08b70ca2c7594d3f99877e366c0a2619'
...
...
@@ -30,10 +31,7 @@ class TestHubModule(unittest.TestCase):
shutil
.
rmtree
(
'solov2_result'
)
def
test_predict1
(
self
):
results
=
self
.
module
.
predict
(
image
=
'tests/test.jpg'
,
visualization
=
False
)
results
=
self
.
module
.
predict
(
image
=
'tests/test.jpg'
,
visualization
=
False
)
segm
=
results
[
'segm'
]
label
=
results
[
'label'
]
score
=
results
[
'score'
]
...
...
@@ -42,10 +40,7 @@ class TestHubModule(unittest.TestCase):
self
.
assertIsInstance
(
score
,
np
.
ndarray
)
def
test_predict2
(
self
):
results
=
self
.
module
.
predict
(
image
=
cv2
.
imread
(
'tests/test.jpg'
),
visualization
=
False
)
results
=
self
.
module
.
predict
(
image
=
cv2
.
imread
(
'tests/test.jpg'
),
visualization
=
False
)
segm
=
results
[
'segm'
]
label
=
results
[
'label'
]
score
=
results
[
'score'
]
...
...
@@ -54,10 +49,7 @@ class TestHubModule(unittest.TestCase):
self
.
assertIsInstance
(
score
,
np
.
ndarray
)
def
test_predict3
(
self
):
results
=
self
.
module
.
predict
(
image
=
cv2
.
imread
(
'tests/test.jpg'
),
visualization
=
True
)
results
=
self
.
module
.
predict
(
image
=
cv2
.
imread
(
'tests/test.jpg'
),
visualization
=
True
)
segm
=
results
[
'segm'
]
label
=
results
[
'label'
]
score
=
results
[
'score'
]
...
...
@@ -67,10 +59,7 @@ class TestHubModule(unittest.TestCase):
def
test_predict4
(
self
):
module
=
hub
.
Module
(
name
=
"solov2"
,
use_gpu
=
True
)
results
=
module
.
predict
(
image
=
cv2
.
imread
(
'tests/test.jpg'
),
visualization
=
True
)
results
=
module
.
predict
(
image
=
cv2
.
imread
(
'tests/test.jpg'
),
visualization
=
True
)
segm
=
results
[
'segm'
]
label
=
results
[
'label'
]
score
=
results
[
'score'
]
...
...
@@ -79,11 +68,7 @@ class TestHubModule(unittest.TestCase):
self
.
assertIsInstance
(
score
,
np
.
ndarray
)
def
test_predict5
(
self
):
self
.
assertRaises
(
FileNotFoundError
,
self
.
module
.
predict
,
image
=
'no.jpg'
)
self
.
assertRaises
(
FileNotFoundError
,
self
.
module
.
predict
,
image
=
'no.jpg'
)
def
test_save_inference_model
(
self
):
self
.
module
.
save_inference_model
(
'./inference/model'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录