Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
2ce0e07b
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2ce0e07b
编写于
10月 14, 2022
作者:
jm_12138
提交者:
GitHub
10月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update solov2 (#2015)
* update solov2 * fix typo
上级
71ee4cf6
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
228 addition
and
75 deletion
+228
-75
modules/image/instance_segmentation/solov2/README.md
modules/image/instance_segmentation/solov2/README.md
+7
-5
modules/image/instance_segmentation/solov2/data_feed.py
modules/image/instance_segmentation/solov2/data_feed.py
+72
-36
modules/image/instance_segmentation/solov2/example.png
modules/image/instance_segmentation/solov2/example.png
+0
-0
modules/image/instance_segmentation/solov2/module.py
modules/image/instance_segmentation/solov2/module.py
+40
-29
modules/image/instance_segmentation/solov2/processor.py
modules/image/instance_segmentation/solov2/processor.py
+13
-5
modules/image/instance_segmentation/solov2/test.py
modules/image/instance_segmentation/solov2/test.py
+96
-0
未找到文件。
modules/image/instance_segmentation/solov2/README.md
浏览文件 @
2ce0e07b
...
...
@@ -78,7 +78,7 @@
-
res (dict): 识别结果,关键字有 'segm', 'label', 'score'对应的取值为:
-
segm (np.ndarray): 实例分割结果,取值为0或1。0表示背景,1为实例;
-
label (list): 实例分割结果类别id;
-
score (list):实例分割结果类别得分;
s
-
score (list):实例分割结果类别得分;
## 四、服务部署
...
...
@@ -147,8 +147,10 @@
初始发布
*
```shell
$ hub install hand_pose_localization==1.0.0
```
*
1.1.0
适配 PaddlePaddle 2.2.0+
*
```shell
$ hub install hand_pose_localization==1.1.0
```
\ No newline at end of file
modules/image/instance_segmentation/solov2/data_feed.py
浏览文件 @
2ce0e07b
...
...
@@ -3,8 +3,8 @@ import base64
import
cv2
import
numpy
as
np
from
paddle.inference
import
Config
,
create_predictor
,
PrecisionType
from
PIL
import
Image
,
ImageDraw
import
paddle.fluid
as
fluid
def
create_inputs
(
im
,
im_info
):
...
...
@@ -19,11 +19,14 @@ def create_inputs(im, im_info):
inputs
[
'image'
]
=
im
origin_shape
=
list
(
im_info
[
'origin_shape'
])
resize_shape
=
list
(
im_info
[
'resize_shape'
])
pad_shape
=
list
(
im_info
[
'pad_shape'
])
if
im_info
[
'pad_shape'
]
is
not
None
else
list
(
im_info
[
'resize_shape'
])
pad_shape
=
list
(
im_info
[
'pad_shape'
])
if
im_info
[
'pad_shape'
]
is
not
None
else
list
(
im_info
[
'resize_shape'
])
scale_x
,
scale_y
=
im_info
[
'scale'
]
scale
=
scale_x
im_info
=
np
.
array
([
resize_shape
+
[
scale
]]).
astype
(
'float32'
)
inputs
[
'im_info'
]
=
im_info
inputs
[
'scale_factor'
]
=
np
.
array
([
scale_x
,
scale_x
]).
astype
(
'float32'
).
reshape
(
-
1
,
2
)
inputs
[
'im_shape'
]
=
np
.
array
(
resize_shape
).
astype
(
'float32'
).
reshape
(
-
1
,
2
)
return
inputs
...
...
@@ -42,28 +45,38 @@ def visualize_box_mask(im, results, labels=None, mask_resolution=14, threshold=0
im (PIL.Image.Image): visualized image
"""
if
not
labels
:
labels
=
[
'background'
,
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic light'
,
'fire'
,
'hydrant'
,
'stop sign'
,
'parking meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports ball'
,
'kite'
,
'baseball bat'
,
'baseball glove'
,
'skateboard'
,
'surfboard'
,
'tennis racket'
,
'bottle'
,
'wine glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted plant'
,
'bed'
,
'dining table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy bear'
,
'hair drier'
,
'toothbrush'
]
labels
=
[
'background'
,
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic light'
,
'fire'
,
'hydrant'
,
'stop sign'
,
'parking meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports ball'
,
'kite'
,
'baseball bat'
,
'baseball glove'
,
'skateboard'
,
'surfboard'
,
'tennis racket'
,
'bottle'
,
'wine glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted plant'
,
'bed'
,
'dining table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy bear'
,
'hair drier'
,
'toothbrush'
]
if
isinstance
(
im
,
str
):
im
=
Image
.
open
(
im
).
convert
(
'RGB'
)
else
:
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
im
=
Image
.
fromarray
(
im
)
if
'masks'
in
results
and
'boxes'
in
results
:
im
=
draw_mask
(
im
,
results
[
'boxes'
],
results
[
'masks'
],
labels
,
resolution
=
mask_resolution
)
im
=
draw_mask
(
im
,
results
[
'boxes'
],
results
[
'masks'
],
labels
,
resolution
=
mask_resolution
)
if
'boxes'
in
results
:
im
=
draw_box
(
im
,
results
[
'boxes'
],
labels
)
if
'segm'
in
results
:
im
=
draw_segm
(
im
,
results
[
'segm'
],
results
[
'label'
],
results
[
'score'
],
labels
,
threshold
=
threshold
)
im
=
draw_segm
(
im
,
results
[
'segm'
],
results
[
'label'
],
results
[
'score'
],
labels
,
threshold
=
threshold
)
return
im
...
...
@@ -152,7 +165,8 @@ def draw_mask(im, np_boxes, np_masks, labels, resolution=14, threshold=0.5):
y0
=
min
(
max
(
ymin
,
0
),
im_h
)
y1
=
min
(
max
(
ymax
+
1
,
0
),
im_h
)
im_mask
=
np
.
zeros
((
im_h
,
im_w
),
dtype
=
np
.
uint8
)
im_mask
[
y0
:
y1
,
x0
:
x1
]
=
resized_mask
[(
y0
-
ymin
):(
y1
-
ymin
),
(
x0
-
xmin
):(
x1
-
xmin
)]
im_mask
[
y0
:
y1
,
x0
:
x1
]
=
resized_mask
[(
y0
-
ymin
):(
y1
-
ymin
),
(
x0
-
xmin
):(
x1
-
xmin
)]
if
clsid
not
in
clsid2color
:
clsid2color
[
clsid
]
=
color_list
[
clsid
]
color_mask
=
clsid2color
[
clsid
]
...
...
@@ -190,19 +204,28 @@ def draw_box(im, np_boxes, labels):
color
=
tuple
(
clsid2color
[
clsid
])
# draw bbox
draw
.
line
([(
xmin
,
ymin
),
(
xmin
,
ymax
),
(
xmax
,
ymax
),
(
xmax
,
ymin
),
(
xmin
,
ymin
)],
width
=
draw_thickness
,
fill
=
color
)
draw
.
line
(
[(
xmin
,
ymin
),
(
xmin
,
ymax
),
(
xmax
,
ymax
),
(
xmax
,
ymin
),
(
xmin
,
ymin
)],
width
=
draw_thickness
,
fill
=
color
)
# draw label
text
=
"{} {:.4f}"
.
format
(
labels
[
clsid
],
score
)
tw
,
th
=
draw
.
textsize
(
text
)
draw
.
rectangle
([(
xmin
+
1
,
ymin
-
th
),
(
xmin
+
tw
+
1
,
ymin
)],
fill
=
color
)
draw
.
rectangle
(
[(
xmin
+
1
,
ymin
-
th
),
(
xmin
+
tw
+
1
,
ymin
)],
fill
=
color
)
draw
.
text
((
xmin
+
1
,
ymin
-
th
),
text
,
fill
=
(
255
,
255
,
255
))
return
im
def
draw_segm
(
im
,
np_segms
,
np_label
,
np_score
,
labels
,
threshold
=
0.5
,
alpha
=
0.7
):
def
draw_segm
(
im
,
np_segms
,
np_label
,
np_score
,
labels
,
threshold
=
0.5
,
alpha
=
0.7
):
"""
Draw segmentation on image.
"""
...
...
@@ -231,17 +254,28 @@ def draw_segm(im, np_segms, np_label, np_score, labels, threshold=0.5, alpha=0.7
sum_y
=
np
.
sum
(
mask
,
axis
=
1
)
y
=
np
.
where
(
sum_y
>
0.5
)[
0
]
x0
,
x1
,
y0
,
y1
=
x
[
0
],
x
[
-
1
],
y
[
0
],
y
[
-
1
]
cv2
.
rectangle
(
im
,
(
x0
,
y0
),
(
x1
,
y1
),
tuple
(
color_mask
.
astype
(
'int32'
).
tolist
()),
1
)
cv2
.
rectangle
(
im
,
(
x0
,
y0
),
(
x1
,
y1
),
tuple
(
color_mask
.
astype
(
'int32'
).
tolist
()),
1
)
bbox_text
=
'%s %.2f'
%
(
labels
[
clsid
],
score
)
t_size
=
cv2
.
getTextSize
(
bbox_text
,
0
,
0.3
,
thickness
=
1
)[
0
]
cv2
.
rectangle
(
im
,
(
x0
,
y0
),
(
x0
+
t_size
[
0
],
y0
-
t_size
[
1
]
-
3
),
tuple
(
color_mask
.
astype
(
'int32'
).
tolist
()),
-
1
)
cv2
.
putText
(
im
,
bbox_text
,
(
x0
,
y0
-
2
),
cv2
.
FONT_HERSHEY_SIMPLEX
,
0.3
,
(
0
,
0
,
0
),
1
,
lineType
=
cv2
.
LINE_AA
)
cv2
.
rectangle
(
im
,
(
x0
,
y0
),
(
x0
+
t_size
[
0
],
y0
-
t_size
[
1
]
-
3
),
tuple
(
color_mask
.
astype
(
'int32'
).
tolist
()),
-
1
)
cv2
.
putText
(
im
,
bbox_text
,
(
x0
,
y0
-
2
),
cv2
.
FONT_HERSHEY_SIMPLEX
,
0.3
,
(
0
,
0
,
0
),
1
,
lineType
=
cv2
.
LINE_AA
)
return
Image
.
fromarray
(
im
.
astype
(
'uint8'
))
def
load_predictor
(
model_dir
,
run_mode
=
'fluid'
,
batch_size
=
1
,
use_gpu
=
False
,
min_subgraph_size
=
3
):
def
load_predictor
(
model_dir
,
run_mode
=
'paddle'
,
batch_size
=
1
,
use_gpu
=
False
,
min_subgraph_size
=
3
):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
...
...
@@ -251,17 +285,19 @@ def load_predictor(model_dir, run_mode='fluid', batch_size=1, use_gpu=False, min
Raises:
ValueError: predict by TensorRT need use_gpu == True.
"""
if
not
use_gpu
and
not
run_mode
==
'fluid'
:
raise
ValueError
(
"Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
.
format
(
run_mode
,
use_gpu
))
if
not
use_gpu
and
not
run_mode
==
'paddle'
:
raise
ValueError
(
"Predict by TensorRT mode: {}, expect use_gpu==True, but use_gpu == {}"
.
format
(
run_mode
,
use_gpu
))
if
run_mode
==
'trt_int8'
:
raise
ValueError
(
"TensorRT int8 mode is not supported now, "
"please use trt_fp32 or trt_fp16 instead."
)
raise
ValueError
(
"TensorRT int8 mode is not supported now, "
"please use trt_fp32 or trt_fp16 instead."
)
precision_map
=
{
'trt_int8'
:
fluid
.
core
.
AnalysisConfig
.
Precision
.
Int8
,
'trt_fp32'
:
fluid
.
core
.
AnalysisConfig
.
Precision
.
Float32
,
'trt_fp16'
:
fluid
.
core
.
AnalysisConfig
.
Precision
.
Half
'trt_int8'
:
PrecisionType
.
Int8
,
'trt_fp32'
:
PrecisionType
.
Float32
,
'trt_fp16'
:
PrecisionType
.
Half
}
config
=
fluid
.
core
.
AnalysisConfig
(
os
.
path
.
join
(
model_dir
,
'__model__'
),
os
.
path
.
join
(
model_dir
,
'__params__'
)
)
config
=
Config
(
model_dir
+
'.pdmodel'
,
model_dir
+
'.pdiparams'
)
if
use_gpu
:
# initial GPU memory(M), device ID
config
.
enable_use_gpu
(
100
,
0
)
...
...
@@ -285,7 +321,7 @@ def load_predictor(model_dir, run_mode='fluid', batch_size=1, use_gpu=False, min
config
.
enable_memory_optim
()
# disable feed, fetch OP, needed by zero_copy_run
config
.
switch_use_feed_fetch_ops
(
False
)
predictor
=
fluid
.
core
.
create_paddl
e_predictor
(
config
)
predictor
=
creat
e_predictor
(
config
)
return
predictor
...
...
modules/image/instance_segmentation/solov2/example.png
已删除
100644 → 0
浏览文件 @
71ee4cf6
448.7 KB
modules/image/instance_segmentation/solov2/module.py
浏览文件 @
2ce0e07b
...
...
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
time
import
base64
from
functools
import
reduce
from
typing
import
Union
import
cv2
import
numpy
as
np
from
paddlehub.module.module
import
moduleinfo
,
serving
...
...
@@ -25,7 +25,7 @@ import solov2.processor as P
import
solov2.data_feed
as
D
class
Detector
(
object
)
:
class
Detector
:
"""
Args:
min_subgraph_size (int): number of tensorRT graphs.
...
...
@@ -33,23 +33,26 @@ class Detector(object):
threshold (float): threshold to reserve the result for output.
"""
def
__init__
(
self
,
min_subgraph_size
:
int
=
60
,
use_gpu
=
False
,
threshold
:
float
=
0.5
):
def
__init__
(
self
,
min_subgraph_size
:
int
=
60
,
use_gpu
=
False
):
model_dir
=
os
.
path
.
join
(
self
.
directory
,
'solov2_r50_fpn_1x'
)
self
.
predictor
=
D
.
load_predictor
(
model_dir
,
min_subgraph_size
=
min_subgraph_size
,
use_gpu
=
use_gpu
)
self
.
compose
=
[
P
.
Resize
(
max_size
=
1333
),
P
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]),
P
.
Permute
(),
P
.
PadStride
(
stride
=
32
)
]
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
'solov2_r50_fpn_1x'
,
'model'
)
self
.
predictor
=
D
.
load_predictor
(
self
.
default_pretrained_model_path
,
min_subgraph_size
=
min_subgraph_size
,
use_gpu
=
use_gpu
)
self
.
compose
=
[
P
.
Resize
(
max_size
=
1333
),
P
.
Normalize
(
mean
=
[
0.485
,
0.456
,
0.406
],
std
=
[
0.229
,
0.224
,
0.225
]),
P
.
Permute
(),
P
.
PadStride
(
stride
=
32
)]
def
transform
(
self
,
im
:
Union
[
str
,
np
.
ndarray
]):
im
,
im_info
=
P
.
preprocess
(
im
,
self
.
compose
)
inputs
=
D
.
create_inputs
(
im
,
im_info
)
return
inputs
,
im_info
def
postprocess
(
self
,
np_boxes
:
np
.
ndarray
,
np_masks
:
np
.
ndarray
,
im_info
:
dict
,
threshold
:
float
=
0.5
):
def
postprocess
(
self
,
np_boxes
:
np
.
ndarray
,
np_masks
:
np
.
ndarray
,
threshold
:
float
=
0.5
):
# postprocess output of predictor
results
=
{}
expect_boxes
=
(
np_boxes
[:,
1
]
>
threshold
)
&
(
np_boxes
[:,
0
]
>
-
1
)
...
...
@@ -57,14 +60,17 @@ class Detector(object):
for
box
in
np_boxes
:
print
(
'class_id:{:d}, confidence:{:.4f},'
'left_top:[{:.2f},{:.2f}],'
' right_bottom:[{:.2f},{:.2f}]'
.
format
(
int
(
box
[
0
]),
box
[
1
],
box
[
2
],
box
[
3
],
box
[
4
],
box
[
5
]))
' right_bottom:[{:.2f},{:.2f}]'
.
format
(
int
(
box
[
0
]),
box
[
1
],
box
[
2
],
box
[
3
],
box
[
4
],
box
[
5
]))
results
[
'boxes'
]
=
np_boxes
if
np_masks
is
not
None
:
np_masks
=
np_masks
[
expect_boxes
,
:,
:,
:]
results
[
'masks'
]
=
np_masks
return
results
def
predict
(
self
,
image
:
Union
[
str
,
np
.
ndarray
],
threshold
:
float
=
0.5
):
def
predict
(
self
,
image
:
Union
[
str
,
np
.
ndarray
],
threshold
:
float
=
0.5
):
'''
Args:
image (str/np.ndarray): path of image/ np.ndarray read by cv2
...
...
@@ -80,12 +86,12 @@ class Detector(object):
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_
tensor
(
input_names
[
i
])
input_tensor
=
self
.
predictor
.
get_input_
handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
input_names
[
i
]])
self
.
predictor
.
zero_copy_
run
()
self
.
predictor
.
run
()
output_names
=
self
.
predictor
.
get_output_names
()
boxes_tensor
=
self
.
predictor
.
get_output_
tensor
(
output_names
[
0
])
boxes_tensor
=
self
.
predictor
.
get_output_
handle
(
output_names
[
0
])
np_boxes
=
boxes_tensor
.
copy_to_cpu
()
# do not perform postprocess in benchmark mode
results
=
[]
...
...
@@ -103,16 +109,18 @@ class Detector(object):
author
=
"paddlepaddle"
,
author_email
=
""
,
summary
=
"solov2 is a detection model, this module is trained with COCO dataset."
,
version
=
"1.
0
.0"
)
version
=
"1.
1
.0"
)
class
DetectorSOLOv2
(
Detector
):
"""
Args:
use_gpu (bool): whether use gpu
threshold (float): threshold to reserve the result for output.
"""
def
__init__
(
self
,
use_gpu
:
bool
=
False
):
super
(
DetectorSOLOv2
,
self
).
__init__
(
use_gpu
=
use_gpu
)
def
__init__
(
self
,
use_gpu
:
bool
=
False
,
threshold
:
float
=
0.5
):
super
(
DetectorSOLOv2
,
self
).
__init__
(
use_gpu
=
use_gpu
,
threshold
=
threshold
)
def
predict
(
self
,
image
:
Union
[
str
,
np
.
ndarray
],
...
...
@@ -125,7 +133,7 @@ class DetectorSOLOv2(Detector):
threshold (float): threshold of predicted box' score
visualization (bool): Whether to save visualization result.
save_dir (str): save path.
'''
inputs
,
im_info
=
self
.
transform
(
image
)
...
...
@@ -133,20 +141,23 @@ class DetectorSOLOv2(Detector):
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
self
.
predictor
.
get_input_
tensor
(
input_names
[
i
])
input_tensor
=
self
.
predictor
.
get_input_
handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
input_names
[
i
]])
self
.
predictor
.
zero_copy_
run
()
self
.
predictor
.
run
()
output_names
=
self
.
predictor
.
get_output_names
()
np_label
=
self
.
predictor
.
get_output_tensor
(
output_names
[
0
]).
copy_to_cpu
()
np_score
=
self
.
predictor
.
get_output_tensor
(
output_names
[
1
]).
copy_to_cpu
()
np_segms
=
self
.
predictor
.
get_output_tensor
(
output_names
[
2
]).
copy_to_cpu
()
np_label
=
self
.
predictor
.
get_output_handle
(
output_names
[
1
]).
copy_to_cpu
()
np_score
=
self
.
predictor
.
get_output_handle
(
output_names
[
2
]).
copy_to_cpu
()
np_segms
=
self
.
predictor
.
get_output_handle
(
output_names
[
3
]).
copy_to_cpu
()
output
=
dict
(
segm
=
np_segms
,
label
=
np_label
,
score
=
np_score
)
if
visualization
:
if
not
os
.
path
.
exists
(
save_dir
):
os
.
makedirs
(
save_dir
)
image
=
D
.
visualize_box_mask
(
im
=
image
,
results
=
output
)
image
=
D
.
visualize_box_mask
(
im
=
image
,
results
=
output
,
threshold
=
threshold
)
name
=
str
(
time
.
time
())
+
'.png'
save_path
=
os
.
path
.
join
(
save_dir
,
name
)
image
.
save
(
save_path
)
...
...
@@ -163,4 +174,4 @@ class DetectorSOLOv2(Detector):
final
[
'segm'
]
=
base64
.
b64encode
(
results
[
'segm'
]).
decode
(
'utf8'
)
final
[
'label'
]
=
base64
.
b64encode
(
results
[
'label'
]).
decode
(
'utf8'
)
final
[
'score'
]
=
base64
.
b64encode
(
results
[
'score'
]).
decode
(
'utf8'
)
return
final
return
final
\ No newline at end of file
modules/image/instance_segmentation/solov2/processor.py
浏览文件 @
2ce0e07b
...
...
@@ -78,13 +78,20 @@ class Resize(object):
im_channel
=
im
.
shape
[
2
]
im_scale_x
,
im_scale_y
=
self
.
generate_scale
(
im
)
if
self
.
use_cv2
:
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
else
:
resize_w
=
int
(
im_scale_x
*
float
(
im
.
shape
[
1
]))
resize_h
=
int
(
im_scale_y
*
float
(
im
.
shape
[
0
]))
if
self
.
max_size
!=
0
:
raise
TypeError
(
'If you set max_size to cap the maximum size of image,'
'please set use_cv2 to True to resize the image.'
)
raise
TypeError
(
'If you set max_size to cap the maximum size of image,'
'please set use_cv2 to True to resize the image.'
)
im
=
im
.
astype
(
'uint8'
)
im
=
Image
.
fromarray
(
im
)
im
=
im
.
resize
((
int
(
resize_w
),
int
(
resize_h
)),
self
.
interp
)
...
...
@@ -92,7 +99,8 @@ class Resize(object):
# padding im when image_shape fixed by infer_cfg.yml
if
self
.
max_size
!=
0
and
self
.
image_shape
is
not
None
:
padding_im
=
np
.
zeros
((
self
.
max_size
,
self
.
max_size
,
im_channel
),
dtype
=
np
.
float32
)
padding_im
=
np
.
zeros
(
(
self
.
max_size
,
self
.
max_size
,
im_channel
),
dtype
=
np
.
float32
)
im_h
,
im_w
=
im
.
shape
[:
2
]
padding_im
[:
im_h
,
:
im_w
,
:]
=
im
im
=
padding_im
...
...
@@ -232,4 +240,4 @@ def preprocess(im, preprocess_ops):
for
operator
in
preprocess_ops
:
im
,
im_info
=
operator
(
im
,
im_info
)
im
=
np
.
array
((
im
,
)).
astype
(
'float32'
)
return
im
,
im_info
return
im
,
im_info
\ No newline at end of file
modules/image/instance_segmentation/solov2/test.py
0 → 100644
浏览文件 @
2ce0e07b
import
os
import
shutil
import
unittest
import
cv2
import
requests
import
numpy
as
np
import
paddlehub
as
hub
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
class
TestHubModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
)
->
None
:
img_url
=
'https://ai-studio-static-online.cdn.bcebos.com/7799a8ccc5f6471b9d56fb6eff94f82a08b70ca2c7594d3f99877e366c0a2619'
if
not
os
.
path
.
exists
(
'tests'
):
os
.
makedirs
(
'tests'
)
response
=
requests
.
get
(
img_url
)
assert
response
.
status_code
==
200
,
'Network Error.'
with
open
(
'tests/test.jpg'
,
'wb'
)
as
f
:
f
.
write
(
response
.
content
)
cls
.
module
=
hub
.
Module
(
name
=
"solov2"
)
@
classmethod
def
tearDownClass
(
cls
)
->
None
:
shutil
.
rmtree
(
'tests'
)
shutil
.
rmtree
(
'inference'
)
shutil
.
rmtree
(
'solov2_result'
)
def
test_predict1
(
self
):
results
=
self
.
module
.
predict
(
image
=
'tests/test.jpg'
,
visualization
=
False
)
segm
=
results
[
'segm'
]
label
=
results
[
'label'
]
score
=
results
[
'score'
]
self
.
assertIsInstance
(
segm
,
np
.
ndarray
)
self
.
assertIsInstance
(
label
,
np
.
ndarray
)
self
.
assertIsInstance
(
score
,
np
.
ndarray
)
def
test_predict2
(
self
):
results
=
self
.
module
.
predict
(
image
=
cv2
.
imread
(
'tests/test.jpg'
),
visualization
=
False
)
segm
=
results
[
'segm'
]
label
=
results
[
'label'
]
score
=
results
[
'score'
]
self
.
assertIsInstance
(
segm
,
np
.
ndarray
)
self
.
assertIsInstance
(
label
,
np
.
ndarray
)
self
.
assertIsInstance
(
score
,
np
.
ndarray
)
def
test_predict3
(
self
):
results
=
self
.
module
.
predict
(
image
=
cv2
.
imread
(
'tests/test.jpg'
),
visualization
=
True
)
segm
=
results
[
'segm'
]
label
=
results
[
'label'
]
score
=
results
[
'score'
]
self
.
assertIsInstance
(
segm
,
np
.
ndarray
)
self
.
assertIsInstance
(
label
,
np
.
ndarray
)
self
.
assertIsInstance
(
score
,
np
.
ndarray
)
def
test_predict4
(
self
):
module
=
hub
.
Module
(
name
=
"solov2"
,
use_gpu
=
True
)
results
=
module
.
predict
(
image
=
cv2
.
imread
(
'tests/test.jpg'
),
visualization
=
True
)
segm
=
results
[
'segm'
]
label
=
results
[
'label'
]
score
=
results
[
'score'
]
self
.
assertIsInstance
(
segm
,
np
.
ndarray
)
self
.
assertIsInstance
(
label
,
np
.
ndarray
)
self
.
assertIsInstance
(
score
,
np
.
ndarray
)
def
test_predict5
(
self
):
self
.
assertRaises
(
FileNotFoundError
,
self
.
module
.
predict
,
image
=
'no.jpg'
)
def
test_save_inference_model
(
self
):
self
.
module
.
save_inference_model
(
'./inference/model'
)
self
.
assertTrue
(
os
.
path
.
exists
(
'./inference/model.pdmodel'
))
self
.
assertTrue
(
os
.
path
.
exists
(
'./inference/model.pdiparams'
))
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录