Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
9b82f2fb
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9b82f2fb
编写于
5月 07, 2020
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add detection module - venus
上级
59499f1b
变更
22
展开全部
隐藏空白更改
内联
并排
Showing
22 changed file
with
3354 addition
and
0 deletion
+3354
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/README.md
...object_detection/faster_rcnn_resnet50_fpn_venus/README.md
+71
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/__init__.py
...ject_detection/faster_rcnn_resnet50_fpn_venus/__init__.py
+0
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/bbox_assigner.py
...detection/faster_rcnn_resnet50_fpn_venus/bbox_assigner.py
+20
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/bbox_head.py
...ect_detection/faster_rcnn_resnet50_fpn_venus/bbox_head.py
+270
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/data_feed.py
...ect_detection/faster_rcnn_resnet50_fpn_venus/data_feed.py
+118
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/fpn.py
...ge/object_detection/faster_rcnn_resnet50_fpn_venus/fpn.py
+296
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/module.py
...object_detection/faster_rcnn_resnet50_fpn_venus/module.py
+250
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/name_adapter.py
..._detection/faster_rcnn_resnet50_fpn_venus/name_adapter.py
+61
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/nonlocal_helper.py
...tection/faster_rcnn_resnet50_fpn_venus/nonlocal_helper.py
+154
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/processor.py
...ect_detection/faster_rcnn_resnet50_fpn_venus/processor.py
+176
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/resnet.py
...object_detection/faster_rcnn_resnet50_fpn_venus/resnet.py
+447
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/roi_extractor.py
...detection/faster_rcnn_resnet50_fpn_venus/roi_extractor.py
+76
-0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/rpn_head.py
...ject_detection/faster_rcnn_resnet50_fpn_venus/rpn_head.py
+533
-0
hub_module/modules/image/object_detection/yolov3_darknet53_venus/README.md
...s/image/object_detection/yolov3_darknet53_venus/README.md
+51
-0
hub_module/modules/image/object_detection/yolov3_darknet53_venus/__init__.py
...image/object_detection/yolov3_darknet53_venus/__init__.py
+0
-0
hub_module/modules/image/object_detection/yolov3_darknet53_venus/darknet.py
.../image/object_detection/yolov3_darknet53_venus/darknet.py
+168
-0
hub_module/modules/image/object_detection/yolov3_darknet53_venus/data_feed.py
...mage/object_detection/yolov3_darknet53_venus/data_feed.py
+71
-0
hub_module/modules/image/object_detection/yolov3_darknet53_venus/module.py
...s/image/object_detection/yolov3_darknet53_venus/module.py
+125
-0
hub_module/modules/image/object_detection/yolov3_darknet53_venus/processor.py
...mage/object_detection/yolov3_darknet53_venus/processor.py
+180
-0
hub_module/modules/image/object_detection/yolov3_darknet53_venus/yolo_head.py
...mage/object_detection/yolov3_darknet53_venus/yolo_head.py
+273
-0
hub_module/scripts/configs/faster_rcnn_resnet50_fpn_venus.yml
...module/scripts/configs/faster_rcnn_resnet50_fpn_venus.yml
+7
-0
hub_module/scripts/configs/yolov3_darknet53_venus.yml
hub_module/scripts/configs/yolov3_darknet53_venus.yml
+7
-0
未找到文件。
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/README.md
0 → 100644
浏览文件 @
9b82f2fb
## 命令行预测
```
shell
$
hub run faster_rcnn_resnet50_fpn_venus
--input_path
"/PATH/TO/IMAGE"
```
## API
```
python
def
context
(
num_classes
=
81
,
trainable
=
True
,
pretrained
=
True
,
phase
=
'train'
)
```
提取特征,用于迁移学习。
**参数**
*
num
\_
classes (int): 类别数;
*
trainable(bool): 参数是否可训练;
*
pretrained (bool): 是否加载预训练模型;
*
phase (str): 可选值为 'train'/'predict','trian' 用于训练,'predict' 用于预测。
**返回**
*
inputs (dict): 模型的输入,相应的取值为:
当 phase 为 'train'时,包含:
*
image (Variable): 图像变量
*
im
\_
size (Variable): 图像的尺寸
*
im
\_
info (Variable): 图像缩放信息
*
gt
\_
class (Variable): 检测框类别
*
gt
\_
box (Variable): 检测框坐标
*
is
\_
crowd (Variable): 单个框内是否包含多个物体
当 phase 为 'predict'时,包含:
*
image (Variable): 图像变量
*
im
\_
size (Variable): 图像的尺寸
*
im
\_
info (Variable): 图像缩放信息
*
outputs (dict): 模型的输出,相应的取值为:
当 phase 为 'train'时,包含:
*
head_features (Variable): 所提取的特征
*
rpn
\_
cls
\_
loss (Variable): 检测框分类损失
*
rpn
\_
reg
\_
loss (Variable): 检测框回归损失
*
generate
\_
proposal
\_
labels (Variable): 图像信息
当 phase 为 'predict'时,包含:
*
head_features (Variable): 所提取的特征
*
rois (Variable): 提取的roi
*
bbox
\_
out (Variable): 预测结果
*
context
\_
prog (Program): 用于迁移学习的 Program。
```
python
def
save_inference_model
(
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
)
```
将模型保存到指定路径。
**参数**
*
dirname: 存在模型的目录名称
*
model
\_
filename: 模型文件名称,默认为
\_\_
model
\_\_
*
params
\_
filename: 参数文件名称,默认为
\_\_
params
\_\_
(仅当
`combined`
为True时生效)
*
combined: 是否将参数保存到统一的一个文件中
### 依赖
paddlepaddle >= 1.6.2
paddlehub >= 1.6.0
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/__init__.py
0 → 100644
浏览文件 @
9b82f2fb
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/bbox_assigner.py
0 → 100644
浏览文件 @
9b82f2fb
class
BBoxAssigner
(
object
):
# __op__ = fluid.layers.generate_proposal_labels
def
__init__
(
self
,
batch_size_per_im
=
512
,
fg_fraction
=
.
25
,
fg_thresh
=
.
5
,
bg_thresh_hi
=
.
5
,
bg_thresh_lo
=
0.
,
bbox_reg_weights
=
[
0.1
,
0.1
,
0.2
,
0.2
],
class_nums
=
81
,
shuffle_before_sample
=
True
):
super
(
BBoxAssigner
,
self
).
__init__
()
self
.
batch_size_per_im
=
batch_size_per_im
self
.
fg_fraction
=
fg_fraction
self
.
fg_thresh
=
fg_thresh
self
.
bg_thresh_hi
=
bg_thresh_hi
self
.
bg_thresh_lo
=
bg_thresh_lo
self
.
bbox_reg_weights
=
bbox_reg_weights
self
.
class_nums
=
class_nums
self
.
use_random
=
shuffle_before_sample
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/bbox_head.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
OrderedDict
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.initializer
import
Normal
,
Xavier
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.initializer
import
MSRA
class
MultiClassNMS
(
object
):
# __op__ = fluid.layers.multiclass_nms
def
__init__
(
self
,
score_threshold
=
.
05
,
nms_top_k
=-
1
,
keep_top_k
=
100
,
nms_threshold
=
.
5
,
normalized
=
False
,
nms_eta
=
1.0
,
background_label
=
0
):
super
(
MultiClassNMS
,
self
).
__init__
()
self
.
score_threshold
=
score_threshold
self
.
nms_top_k
=
nms_top_k
self
.
keep_top_k
=
keep_top_k
self
.
nms_threshold
=
nms_threshold
self
.
normalized
=
normalized
self
.
nms_eta
=
nms_eta
self
.
background_label
=
background_label
class
SmoothL1Loss
(
object
):
'''
Smooth L1 loss
Args:
sigma (float): hyper param in smooth l1 loss
'''
def
__init__
(
self
,
sigma
=
1.0
):
super
(
SmoothL1Loss
,
self
).
__init__
()
self
.
sigma
=
sigma
def
__call__
(
self
,
x
,
y
,
inside_weight
=
None
,
outside_weight
=
None
):
return
fluid
.
layers
.
smooth_l1
(
x
,
y
,
inside_weight
=
inside_weight
,
outside_weight
=
outside_weight
,
sigma
=
self
.
sigma
)
class
BoxCoder
(
object
):
def
__init__
(
self
,
prior_box_var
=
[
0.1
,
0.1
,
0.2
,
0.2
],
code_type
=
'decode_center_size'
,
box_normalized
=
False
,
axis
=
1
):
super
(
BoxCoder
,
self
).
__init__
()
self
.
prior_box_var
=
prior_box_var
self
.
code_type
=
code_type
self
.
box_normalized
=
box_normalized
self
.
axis
=
axis
class
TwoFCHead
(
object
):
"""
RCNN head with two Fully Connected layers
Args:
mlp_dim (int): num of filters for the fc layers
"""
def
__init__
(
self
,
mlp_dim
=
1024
):
super
(
TwoFCHead
,
self
).
__init__
()
self
.
mlp_dim
=
mlp_dim
def
__call__
(
self
,
roi_feat
):
fan
=
roi_feat
.
shape
[
1
]
*
roi_feat
.
shape
[
2
]
*
roi_feat
.
shape
[
3
]
fc6
=
fluid
.
layers
.
fc
(
input
=
roi_feat
,
size
=
self
.
mlp_dim
,
act
=
'relu'
,
name
=
'fc6'
,
param_attr
=
ParamAttr
(
name
=
'fc6_w'
,
initializer
=
Xavier
(
fan_out
=
fan
)),
bias_attr
=
ParamAttr
(
name
=
'fc6_b'
,
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
)))
head_feat
=
fluid
.
layers
.
fc
(
input
=
fc6
,
size
=
self
.
mlp_dim
,
act
=
'relu'
,
name
=
'fc7'
,
param_attr
=
ParamAttr
(
name
=
'fc7_w'
,
initializer
=
Xavier
()),
bias_attr
=
ParamAttr
(
name
=
'fc7_b'
,
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
)))
return
head_feat
class
BBoxHead
(
object
):
"""
RCNN bbox head
Args:
head (object): the head module instance, e.g., `ResNetC5`, `TwoFCHead`
box_coder (object): `BoxCoder` instance
nms (object): `MultiClassNMS` instance
num_classes: number of output classes
"""
__inject__
=
[
'head'
,
'box_coder'
,
'nms'
,
'bbox_loss'
]
__shared__
=
[
'num_classes'
]
def
__init__
(
self
,
head
,
box_coder
=
BoxCoder
(),
nms
=
MultiClassNMS
(),
bbox_loss
=
SmoothL1Loss
(),
num_classes
=
81
):
super
(
BBoxHead
,
self
).
__init__
()
self
.
head
=
head
self
.
num_classes
=
num_classes
self
.
box_coder
=
box_coder
self
.
nms
=
nms
self
.
bbox_loss
=
bbox_loss
self
.
head_feat
=
None
def
get_head_feat
(
self
,
input
=
None
):
"""
Get the bbox head feature map.
"""
if
input
is
not
None
:
feat
=
self
.
head
(
input
)
if
isinstance
(
feat
,
OrderedDict
):
feat
=
list
(
feat
.
values
())[
0
]
self
.
head_feat
=
feat
return
self
.
head_feat
def
_get_output
(
self
,
roi_feat
):
"""
Get bbox head output.
Args:
roi_feat (Variable): RoI feature from RoIExtractor.
Returns:
cls_score(Variable): Output of rpn head with shape of
[N, num_anchors, H, W].
bbox_pred(Variable): Output of rpn head with shape of
[N, num_anchors * 4, H, W].
"""
head_feat
=
self
.
get_head_feat
(
roi_feat
)
# when ResNetC5 output a single feature map
if
not
isinstance
(
self
.
head
,
TwoFCHead
):
head_feat
=
fluid
.
layers
.
pool2d
(
head_feat
,
pool_type
=
'avg'
,
global_pooling
=
True
)
cls_score
=
fluid
.
layers
.
fc
(
input
=
head_feat
,
size
=
self
.
num_classes
,
act
=
None
,
name
=
'cls_score'
,
param_attr
=
ParamAttr
(
name
=
'cls_score_w'
,
initializer
=
Normal
(
loc
=
0.0
,
scale
=
0.01
)),
bias_attr
=
ParamAttr
(
name
=
'cls_score_b'
,
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
)))
bbox_pred
=
fluid
.
layers
.
fc
(
input
=
head_feat
,
size
=
4
*
self
.
num_classes
,
act
=
None
,
name
=
'bbox_pred'
,
param_attr
=
ParamAttr
(
name
=
'bbox_pred_w'
,
initializer
=
Normal
(
loc
=
0.0
,
scale
=
0.001
)),
bias_attr
=
ParamAttr
(
name
=
'bbox_pred_b'
,
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
)))
return
cls_score
,
bbox_pred
def
get_loss
(
self
,
roi_feat
,
labels_int32
,
bbox_targets
,
bbox_inside_weights
,
bbox_outside_weights
):
"""
Get bbox_head loss.
Args:
roi_feat (Variable): RoI feature from RoIExtractor.
labels_int32(Variable): Class label of a RoI with shape [P, 1].
P is the number of RoI.
bbox_targets(Variable): Box label of a RoI with shape
[P, 4 * class_nums].
bbox_inside_weights(Variable): Indicates whether a box should
contribute to loss. Same shape as bbox_targets.
bbox_outside_weights(Variable): Indicates whether a box should
contribute to loss. Same shape as bbox_targets.
Return:
Type: Dict
loss_cls(Variable): bbox_head loss.
loss_bbox(Variable): bbox_head loss.
"""
cls_score
,
bbox_pred
=
self
.
_get_output
(
roi_feat
)
labels_int64
=
fluid
.
layers
.
cast
(
x
=
labels_int32
,
dtype
=
'int64'
)
labels_int64
.
stop_gradient
=
True
loss_cls
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logits
=
cls_score
,
label
=
labels_int64
,
numeric_stable_mode
=
True
)
loss_cls
=
fluid
.
layers
.
reduce_mean
(
loss_cls
)
loss_bbox
=
self
.
bbox_loss
(
x
=
bbox_pred
,
y
=
bbox_targets
,
inside_weight
=
bbox_inside_weights
,
outside_weight
=
bbox_outside_weights
)
loss_bbox
=
fluid
.
layers
.
reduce_mean
(
loss_bbox
)
return
{
'loss_cls'
:
loss_cls
,
'loss_bbox'
:
loss_bbox
}
def
get_prediction
(
self
,
roi_feat
,
rois
,
im_info
,
im_shape
,
return_box_score
=
False
):
"""
Get prediction bounding box in test stage.
Args:
roi_feat (Variable): RoI feature from RoIExtractor.
rois (Variable): Output of generate_proposals in rpn head.
im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the
number of input images, each element consists of im_height,
im_width, im_scale.
im_shape (Variable): Actual shape of original image with shape
[B, 3]. B is the number of images, each element consists of
original_height, original_width, 1
Returns:
pred_result(Variable): Prediction result with shape [N, 6]. Each
row has 6 values: [label, confidence, xmin, ymin, xmax, ymax].
N is the total number of prediction.
"""
cls_score
,
bbox_pred
=
self
.
_get_output
(
roi_feat
)
im_scale
=
fluid
.
layers
.
slice
(
im_info
,
[
1
],
starts
=
[
2
],
ends
=
[
3
])
im_scale
=
fluid
.
layers
.
sequence_expand
(
im_scale
,
rois
)
boxes
=
rois
/
im_scale
cls_prob
=
fluid
.
layers
.
softmax
(
cls_score
,
use_cudnn
=
False
)
bbox_pred
=
fluid
.
layers
.
reshape
(
bbox_pred
,
(
-
1
,
self
.
num_classes
,
4
))
# self.box_coder
decoded_box
=
fluid
.
layers
.
box_coder
(
prior_box
=
boxes
,
target_box
=
bbox_pred
,
prior_box_var
=
self
.
box_coder
.
prior_box_var
,
code_type
=
self
.
box_coder
.
code_type
,
box_normalized
=
self
.
box_coder
.
box_normalized
,
axis
=
self
.
box_coder
.
axis
)
cliped_box
=
fluid
.
layers
.
box_clip
(
input
=
decoded_box
,
im_info
=
im_shape
)
if
return_box_score
:
return
{
'bbox'
:
cliped_box
,
'score'
:
cls_prob
}
# self.nms
pred_result
=
fluid
.
layers
.
multiclass_nms
(
bboxes
=
cliped_box
,
scores
=
cls_prob
,
score_threshold
=
self
.
nms
.
score_threshold
,
nms_top_k
=
self
.
nms
.
nms_top_k
,
keep_top_k
=
self
.
nms
.
keep_top_k
,
nms_threshold
=
self
.
nms
.
nms_threshold
,
normalized
=
self
.
nms
.
normalized
,
nms_eta
=
self
.
nms
.
nms_eta
,
background_label
=
self
.
nms
.
background_label
)
return
pred_result
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/data_feed.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
os
from
collections
import
OrderedDict
import
cv2
import
numpy
as
np
from
PIL
import
Image
,
ImageEnhance
from
paddle
import
fluid
__all__
=
[
'test_reader'
]
def
test_reader
(
paths
=
None
,
images
=
None
):
"""
data generator
Args:
paths (list[str]): paths to images.
images (list(numpy.ndarray)): data of images, shape of each is [H, W, C]
Yield:
res (dict): key contains 'image', 'im_info', 'im_shape', the corresponding values is:
image (numpy.ndarray): the image to be fed into network
im_info (numpy.ndarray): the info about the preprocessed.
im_shape (numpy.ndarray): the shape of image.
"""
img_list
=
list
()
if
paths
:
for
img_path
in
paths
:
assert
os
.
path
.
isfile
(
img_path
),
"The {} isn't a valid file path."
.
format
(
img_path
)
img
=
cv2
.
imread
(
img_path
).
astype
(
'float32'
)
img_list
.
append
(
img
)
if
images
is
not
None
:
for
img
in
images
:
img_list
.
append
(
img
)
for
im
in
img_list
:
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
im
=
im
.
astype
(
np
.
float32
,
copy
=
False
)
mean
=
[
0.485
,
0.456
,
0.406
]
std
=
[
0.229
,
0.224
,
0.225
]
mean
=
np
.
array
(
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
im
=
im
/
255.0
im
-=
mean
im
/=
std
target_size
=
800
max_size
=
1333
shape
=
im
.
shape
# im_shape holds the original shape of image.
im_shape
=
np
.
array
([
shape
[
0
],
shape
[
1
],
1.0
]).
astype
(
'float32'
)
im_size_min
=
np
.
min
(
shape
[
0
:
2
])
im_size_max
=
np
.
max
(
shape
[
0
:
2
])
im_scale
=
float
(
target_size
)
/
float
(
im_size_min
)
if
np
.
round
(
im_scale
*
im_size_max
)
>
max_size
:
im_scale
=
float
(
max_size
)
/
float
(
im_size_max
)
resize_w
=
np
.
round
(
im_scale
*
float
(
shape
[
1
]))
resize_h
=
np
.
round
(
im_scale
*
float
(
shape
[
0
]))
# im_info holds the resize info of image.
im_info
=
np
.
array
([
resize_h
,
resize_w
,
im_scale
]).
astype
(
'float32'
)
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale
,
fy
=
im_scale
,
interpolation
=
cv2
.
INTER_LINEAR
)
# HWC --> CHW
im
=
np
.
swapaxes
(
im
,
1
,
2
)
im
=
np
.
swapaxes
(
im
,
1
,
0
)
yield
{
'image'
:
im
,
'im_info'
:
im_info
,
'im_shape'
:
im_shape
}
def
padding_minibatch
(
batch_data
,
coarsest_stride
=
0
,
use_padded_im_info
=
True
):
max_shape_org
=
np
.
array
(
[
data
[
'image'
].
shape
for
data
in
batch_data
]).
max
(
axis
=
0
)
if
coarsest_stride
>
0
:
max_shape
=
np
.
zeros
((
3
)).
astype
(
'int32'
)
max_shape
[
1
]
=
int
(
np
.
ceil
(
max_shape_org
[
1
]
/
coarsest_stride
)
*
coarsest_stride
)
max_shape
[
2
]
=
int
(
np
.
ceil
(
max_shape_org
[
2
]
/
coarsest_stride
)
*
coarsest_stride
)
else
:
max_shape
=
max_shape_org
.
astype
(
'int32'
)
padding_image
=
list
()
padding_info
=
list
()
padding_shape
=
list
()
for
data
in
batch_data
:
im_c
,
im_h
,
im_w
=
data
[
'image'
].
shape
# image
padding_im
=
np
.
zeros
((
im_c
,
max_shape
[
1
],
max_shape
[
2
]),
dtype
=
np
.
float32
)
padding_im
[:,
0
:
im_h
,
0
:
im_w
]
=
data
[
'image'
]
padding_image
.
append
(
padding_im
)
# im_info
data
[
'im_info'
][
0
]
=
max_shape
[
1
]
if
use_padded_im_info
else
max_shape_org
[
1
]
data
[
'im_info'
][
1
]
=
max_shape
[
2
]
if
use_padded_im_info
else
max_shape_org
[
2
]
padding_info
.
append
(
data
[
'im_info'
])
padding_shape
.
append
(
data
[
'im_shape'
])
padding_image
=
np
.
array
(
padding_image
).
astype
(
'float32'
)
padding_info
=
np
.
array
(
padding_info
).
astype
(
'float32'
)
padding_shape
=
np
.
array
(
padding_shape
).
astype
(
'float32'
)
return
padding_image
,
padding_info
,
padding_shape
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/fpn.py
0 → 100644
浏览文件 @
9b82f2fb
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
copy
from
collections
import
OrderedDict
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.initializer
import
Xavier
from
paddle.fluid.regularizer
import
L2Decay
__all__
=
[
'ConvNorm'
,
'FPN'
]
def
ConvNorm
(
input
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
norm_decay
=
0.
,
norm_type
=
'affine_channel'
,
norm_groups
=
32
,
dilation
=
1
,
lr_scale
=
1
,
freeze_norm
=
False
,
act
=
None
,
norm_name
=
None
,
initializer
=
None
,
name
=
None
):
fan
=
num_filters
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
((
filter_size
-
1
)
//
2
)
*
dilation
,
dilation
=
dilation
,
groups
=
groups
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
"_weights"
,
initializer
=
initializer
,
learning_rate
=
lr_scale
),
bias_attr
=
False
,
name
=
name
+
'.conv2d.output.1'
)
norm_lr
=
0.
if
freeze_norm
else
1.
pattr
=
ParamAttr
(
name
=
norm_name
+
'_scale'
,
learning_rate
=
norm_lr
*
lr_scale
,
regularizer
=
L2Decay
(
norm_decay
))
battr
=
ParamAttr
(
name
=
norm_name
+
'_offset'
,
learning_rate
=
norm_lr
*
lr_scale
,
regularizer
=
L2Decay
(
norm_decay
))
if
norm_type
in
[
'bn'
,
'sync_bn'
]:
global_stats
=
True
if
freeze_norm
else
False
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
name
=
norm_name
+
'.output.1'
,
param_attr
=
pattr
,
bias_attr
=
battr
,
moving_mean_name
=
norm_name
+
'_mean'
,
moving_variance_name
=
norm_name
+
'_variance'
,
use_global_stats
=
global_stats
)
scale
=
fluid
.
framework
.
_get_var
(
pattr
.
name
)
bias
=
fluid
.
framework
.
_get_var
(
battr
.
name
)
elif
norm_type
==
'gn'
:
out
=
fluid
.
layers
.
group_norm
(
input
=
conv
,
act
=
act
,
name
=
norm_name
+
'.output.1'
,
groups
=
norm_groups
,
param_attr
=
pattr
,
bias_attr
=
battr
)
scale
=
fluid
.
framework
.
_get_var
(
pattr
.
name
)
bias
=
fluid
.
framework
.
_get_var
(
battr
.
name
)
elif
norm_type
==
'affine_channel'
:
scale
=
fluid
.
layers
.
create_parameter
(
shape
=
[
conv
.
shape
[
1
]],
dtype
=
conv
.
dtype
,
attr
=
pattr
,
default_initializer
=
fluid
.
initializer
.
Constant
(
1.
))
bias
=
fluid
.
layers
.
create_parameter
(
shape
=
[
conv
.
shape
[
1
]],
dtype
=
conv
.
dtype
,
attr
=
battr
,
default_initializer
=
fluid
.
initializer
.
Constant
(
0.
))
out
=
fluid
.
layers
.
affine_channel
(
x
=
conv
,
scale
=
scale
,
bias
=
bias
,
act
=
act
)
if
freeze_norm
:
scale
.
stop_gradient
=
True
bias
.
stop_gradient
=
True
return
out
class
FPN
(
object
):
"""
Feature Pyramid Network, see https://arxiv.org/abs/1612.03144
Args:
num_chan (int): number of feature channels
min_level (int): lowest level of the backbone feature map to use
max_level (int): highest level of the backbone feature map to use
spatial_scale (list): feature map scaling factor
has_extra_convs (bool): whether has extral convolutions in higher levels
norm_type (str|None): normalization type, 'bn'/'sync_bn'/'affine_channel'
"""
__shared__
=
[
'norm_type'
,
'freeze_norm'
]
def
__init__
(
self
,
num_chan
=
256
,
min_level
=
2
,
max_level
=
6
,
spatial_scale
=
[
1.
/
32.
,
1.
/
16.
,
1.
/
8.
,
1.
/
4.
],
has_extra_convs
=
False
,
norm_type
=
None
,
freeze_norm
=
False
):
self
.
freeze_norm
=
freeze_norm
self
.
num_chan
=
num_chan
self
.
min_level
=
min_level
self
.
max_level
=
max_level
self
.
spatial_scale
=
spatial_scale
self
.
has_extra_convs
=
has_extra_convs
self
.
norm_type
=
norm_type
def
_add_topdown_lateral
(
self
,
body_name
,
body_input
,
upper_output
):
lateral_name
=
'fpn_inner_'
+
body_name
+
'_lateral'
topdown_name
=
'fpn_topdown_'
+
body_name
fan
=
body_input
.
shape
[
1
]
if
self
.
norm_type
:
initializer
=
Xavier
(
fan_out
=
fan
)
lateral
=
ConvNorm
(
body_input
,
self
.
num_chan
,
1
,
initializer
=
initializer
,
norm_type
=
self
.
norm_type
,
freeze_norm
=
self
.
freeze_norm
,
name
=
lateral_name
,
norm_name
=
lateral_name
)
else
:
lateral
=
fluid
.
layers
.
conv2d
(
body_input
,
self
.
num_chan
,
1
,
param_attr
=
ParamAttr
(
name
=
lateral_name
+
"_w"
,
initializer
=
Xavier
(
fan_out
=
fan
)),
bias_attr
=
ParamAttr
(
name
=
lateral_name
+
"_b"
,
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
)),
name
=
lateral_name
)
topdown
=
fluid
.
layers
.
resize_nearest
(
upper_output
,
scale
=
2.
,
name
=
topdown_name
)
return
lateral
+
topdown
def
get_output
(
self
,
body_dict
):
"""
Add FPN onto backbone.
Args:
body_dict(OrderedDict): Dictionary of variables and each element is the
output of backbone.
Return:
fpn_dict(OrderedDict): A dictionary represents the output of FPN with
their name.
spatial_scale(list): A list of multiplicative spatial scale factor.
"""
spatial_scale
=
copy
.
deepcopy
(
self
.
spatial_scale
)
body_name_list
=
list
(
body_dict
.
keys
())[::
-
1
]
num_backbone_stages
=
len
(
body_name_list
)
self
.
fpn_inner_output
=
[[]
for
_
in
range
(
num_backbone_stages
)]
fpn_inner_name
=
'fpn_inner_'
+
body_name_list
[
0
]
body_input
=
body_dict
[
body_name_list
[
0
]]
fan
=
body_input
.
shape
[
1
]
if
self
.
norm_type
:
initializer
=
Xavier
(
fan_out
=
fan
)
self
.
fpn_inner_output
[
0
]
=
ConvNorm
(
body_input
,
self
.
num_chan
,
1
,
initializer
=
initializer
,
norm_type
=
self
.
norm_type
,
freeze_norm
=
self
.
freeze_norm
,
name
=
fpn_inner_name
,
norm_name
=
fpn_inner_name
)
else
:
self
.
fpn_inner_output
[
0
]
=
fluid
.
layers
.
conv2d
(
body_input
,
self
.
num_chan
,
1
,
param_attr
=
ParamAttr
(
name
=
fpn_inner_name
+
"_w"
,
initializer
=
Xavier
(
fan_out
=
fan
)),
bias_attr
=
ParamAttr
(
name
=
fpn_inner_name
+
"_b"
,
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
)),
name
=
fpn_inner_name
)
for
i
in
range
(
1
,
num_backbone_stages
):
body_name
=
body_name_list
[
i
]
body_input
=
body_dict
[
body_name
]
top_output
=
self
.
fpn_inner_output
[
i
-
1
]
fpn_inner_single
=
self
.
_add_topdown_lateral
(
body_name
,
body_input
,
top_output
)
self
.
fpn_inner_output
[
i
]
=
fpn_inner_single
fpn_dict
=
{}
fpn_name_list
=
[]
for
i
in
range
(
num_backbone_stages
):
fpn_name
=
'fpn_'
+
body_name_list
[
i
]
fan
=
self
.
fpn_inner_output
[
i
].
shape
[
1
]
*
3
*
3
if
self
.
norm_type
:
initializer
=
Xavier
(
fan_out
=
fan
)
fpn_output
=
ConvNorm
(
self
.
fpn_inner_output
[
i
],
self
.
num_chan
,
3
,
initializer
=
initializer
,
norm_type
=
self
.
norm_type
,
freeze_norm
=
self
.
freeze_norm
,
name
=
fpn_name
,
norm_name
=
fpn_name
)
else
:
fpn_output
=
fluid
.
layers
.
conv2d
(
self
.
fpn_inner_output
[
i
],
self
.
num_chan
,
filter_size
=
3
,
padding
=
1
,
param_attr
=
ParamAttr
(
name
=
fpn_name
+
"_w"
,
initializer
=
Xavier
(
fan_out
=
fan
)),
bias_attr
=
ParamAttr
(
name
=
fpn_name
+
"_b"
,
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
)),
name
=
fpn_name
)
fpn_dict
[
fpn_name
]
=
fpn_output
fpn_name_list
.
append
(
fpn_name
)
if
not
self
.
has_extra_convs
and
self
.
max_level
-
self
.
min_level
==
len
(
spatial_scale
):
body_top_name
=
fpn_name_list
[
0
]
body_top_extension
=
fluid
.
layers
.
pool2d
(
fpn_dict
[
body_top_name
],
1
,
'max'
,
pool_stride
=
2
,
name
=
body_top_name
+
'_subsampled_2x'
)
fpn_dict
[
body_top_name
+
'_subsampled_2x'
]
=
body_top_extension
fpn_name_list
.
insert
(
0
,
body_top_name
+
'_subsampled_2x'
)
spatial_scale
.
insert
(
0
,
spatial_scale
[
0
]
*
0.5
)
# Coarser FPN levels introduced for RetinaNet
highest_backbone_level
=
self
.
min_level
+
len
(
spatial_scale
)
-
1
if
self
.
has_extra_convs
and
self
.
max_level
>
highest_backbone_level
:
fpn_blob
=
body_dict
[
body_name_list
[
0
]]
for
i
in
range
(
highest_backbone_level
+
1
,
self
.
max_level
+
1
):
fpn_blob_in
=
fpn_blob
fpn_name
=
'fpn_'
+
str
(
i
)
if
i
>
highest_backbone_level
+
1
:
fpn_blob_in
=
fluid
.
layers
.
relu
(
fpn_blob
)
fan
=
fpn_blob_in
.
shape
[
1
]
*
3
*
3
fpn_blob
=
fluid
.
layers
.
conv2d
(
input
=
fpn_blob_in
,
num_filters
=
self
.
num_chan
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
param_attr
=
ParamAttr
(
name
=
fpn_name
+
"_w"
,
initializer
=
Xavier
(
fan_out
=
fan
)),
bias_attr
=
ParamAttr
(
name
=
fpn_name
+
"_b"
,
learning_rate
=
2.
,
regularizer
=
L2Decay
(
0.
)),
name
=
fpn_name
)
fpn_dict
[
fpn_name
]
=
fpn_blob
fpn_name_list
.
insert
(
0
,
fpn_name
)
spatial_scale
.
insert
(
0
,
spatial_scale
[
0
]
*
0.5
)
res_dict
=
OrderedDict
([(
k
,
fpn_dict
[
k
])
for
k
in
fpn_name_list
])
return
res_dict
,
spatial_scale
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/module.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
ast
import
argparse
from
collections
import
OrderedDict
from
functools
import
partial
from
math
import
ceil
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddlehub.io.parser
import
txt_parser
from
paddlehub.common.paddle_helper
import
add_vars_prefix
from
faster_rcnn_resnet50_fpn_venus.processor
import
load_label_info
,
postprocess
,
base64_to_cv2
from
faster_rcnn_resnet50_fpn_venus.data_feed
import
test_reader
,
padding_minibatch
from
faster_rcnn_resnet50_fpn_venus.fpn
import
FPN
from
faster_rcnn_resnet50_fpn_venus.resnet
import
ResNet
from
faster_rcnn_resnet50_fpn_venus.rpn_head
import
AnchorGenerator
,
RPNTargetAssign
,
GenerateProposals
,
FPNRPNHead
from
faster_rcnn_resnet50_fpn_venus.bbox_head
import
MultiClassNMS
,
BBoxHead
,
TwoFCHead
from
faster_rcnn_resnet50_fpn_venus.bbox_assigner
import
BBoxAssigner
from
faster_rcnn_resnet50_fpn_venus.roi_extractor
import
FPNRoIAlign
@
moduleinfo
(
name
=
"faster_rcnn_resnet50_fpn_venus"
,
version
=
"1.0.0"
,
type
=
"cv/object_detection"
,
summary
=
"Baidu's Faster-RCNN model for object detection, whose backbone is ResNet50, processed with Feature Pyramid Networks"
,
author
=
"paddlepaddle"
,
author_email
=
"paddle-dev@baidu.com"
)
class
FasterRCNNResNet50RPN
(
hub
.
Module
):
def
_initialize
(
self
):
# default pretrained model, Faster-RCNN with backbone ResNet50, shape of input tensor is [3, 800, 1333]
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"faster_rcnn_resnet50_fpn_model"
)
def
context
(
self
,
num_classes
=
708
,
trainable
=
True
,
pretrained
=
True
,
phase
=
'train'
):
"""
Distill the Head Features, so as to perform transfer learning.
Args:
trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model.
get_prediction (bool): whether to get prediction.
phase (str): optional choices are 'train' and 'predict'.
Returns:
inputs (dict): the input variables.
outputs (dict): the output variables.
context_prog (Program): the program to execute transfer learning.
"""
context_prog
=
fluid
.
Program
()
startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
context_prog
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
-
1
,
3
,
-
1
,
-
1
],
dtype
=
'float32'
)
# backbone
backbone
=
ResNet
(
norm_type
=
'affine_channel'
,
depth
=
50
,
feature_maps
=
[
2
,
3
,
4
,
5
],
freeze_at
=
2
)
body_feats
=
backbone
(
image
)
# fpn
fpn
=
FPN
(
max_level
=
6
,
min_level
=
2
,
num_chan
=
256
,
spatial_scale
=
[
0.03125
,
0.0625
,
0.125
,
0.25
])
var_prefix
=
'@HUB_{}@'
.
format
(
self
.
name
)
im_info
=
fluid
.
layers
.
data
(
name
=
'im_info'
,
shape
=
[
3
],
dtype
=
'float32'
,
lod_level
=
0
)
im_shape
=
fluid
.
layers
.
data
(
name
=
'im_shape'
,
shape
=
[
3
],
dtype
=
'float32'
,
lod_level
=
0
)
body_feat_names
=
list
(
body_feats
.
keys
())
body_feats
,
spatial_scale
=
fpn
.
get_output
(
body_feats
)
# rpn_head: RPNHead
rpn_head
=
self
.
rpn_head
()
rois
=
rpn_head
.
get_proposals
(
body_feats
,
im_info
,
mode
=
phase
)
# train
if
phase
==
'train'
:
gt_bbox
=
fluid
.
layers
.
data
(
name
=
'gt_bbox'
,
shape
=
[
4
],
dtype
=
'float32'
,
lod_level
=
1
)
is_crowd
=
fluid
.
layers
.
data
(
name
=
'is_crowd'
,
shape
=
[
1
],
dtype
=
'int32'
,
lod_level
=
1
)
gt_class
=
fluid
.
layers
.
data
(
name
=
'gt_class'
,
shape
=
[
1
],
dtype
=
'int32'
,
lod_level
=
1
)
rpn_loss
=
rpn_head
.
get_loss
(
im_info
,
gt_bbox
,
is_crowd
)
# bbox_assigner: BBoxAssigner
bbox_assigner
=
self
.
bbox_assigner
(
num_classes
)
outs
=
fluid
.
layers
.
generate_proposal_labels
(
rpn_rois
=
rois
,
gt_classes
=
gt_class
,
is_crowd
=
is_crowd
,
gt_boxes
=
gt_bbox
,
im_info
=
im_info
,
batch_size_per_im
=
bbox_assigner
.
batch_size_per_im
,
fg_fraction
=
bbox_assigner
.
fg_fraction
,
fg_thresh
=
bbox_assigner
.
fg_thresh
,
bg_thresh_hi
=
bbox_assigner
.
bg_thresh_hi
,
bg_thresh_lo
=
bbox_assigner
.
bg_thresh_lo
,
bbox_reg_weights
=
bbox_assigner
.
bbox_reg_weights
,
class_nums
=
bbox_assigner
.
class_nums
,
use_random
=
bbox_assigner
.
use_random
)
rois
=
outs
[
0
]
roi_extractor
=
self
.
roi_extractor
()
roi_feat
=
roi_extractor
(
head_inputs
=
body_feats
,
rois
=
rois
,
spatial_scale
=
spatial_scale
)
# head_feat
bbox_head
=
self
.
bbox_head
(
num_classes
)
head_feat
=
bbox_head
.
head
(
roi_feat
)
if
isinstance
(
head_feat
,
OrderedDict
):
head_feat
=
list
(
head_feat
.
values
())[
0
]
if
phase
==
'train'
:
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_info'
:
var_prefix
+
im_info
.
name
,
'im_shape'
:
var_prefix
+
im_shape
.
name
,
'gt_class'
:
var_prefix
+
gt_class
.
name
,
'gt_bbox'
:
var_prefix
+
gt_bbox
.
name
,
'is_crowd'
:
var_prefix
+
is_crowd
.
name
}
outputs
=
{
'head_features'
:
var_prefix
+
head_feat
.
name
,
'rpn_cls_loss'
:
var_prefix
+
rpn_loss
[
'rpn_cls_loss'
].
name
,
'rpn_reg_loss'
:
var_prefix
+
rpn_loss
[
'rpn_reg_loss'
].
name
,
'generate_proposal_labels'
:
[
var_prefix
+
var
.
name
for
var
in
outs
]
}
elif
phase
==
'predict'
:
pred
=
bbox_head
.
get_prediction
(
roi_feat
,
rois
,
im_info
,
im_shape
)
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_info'
:
var_prefix
+
im_info
.
name
,
'im_shape'
:
var_prefix
+
im_shape
.
name
}
outputs
=
{
'head_features'
:
var_prefix
+
head_feat
.
name
,
'rois'
:
var_prefix
+
rois
.
name
,
'bbox_out'
:
var_prefix
+
pred
.
name
}
add_vars_prefix
(
context_prog
,
var_prefix
)
add_vars_prefix
(
startup_program
,
var_prefix
)
global_vars
=
context_prog
.
global_block
().
vars
inputs
=
{
key
:
global_vars
[
value
]
for
key
,
value
in
inputs
.
items
()
}
outputs
=
{
key
:
global_vars
[
value
]
if
not
isinstance
(
value
,
list
)
else
[
global_vars
[
var
]
for
var
in
value
]
for
key
,
value
in
outputs
.
items
()
}
for
param
in
context_prog
.
global_block
().
iter_parameters
():
param
.
trainable
=
trainable
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_program
)
if
pretrained
:
def
_if_exist
(
var
):
if
num_classes
!=
81
:
if
'bbox_pred'
in
var
.
name
or
'cls_score'
in
var
.
name
:
return
False
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
return
inputs
,
outputs
,
context_prog
def
rpn_head
(
self
):
return
FPNRPNHead
(
anchor_generator
=
AnchorGenerator
(
anchor_sizes
=
[
32
,
64
,
128
,
256
,
512
],
aspect_ratios
=
[
0.5
,
1.0
,
2.0
],
stride
=
[
16.0
,
16.0
],
variance
=
[
1.0
,
1.0
,
1.0
,
1.0
]),
rpn_target_assign
=
RPNTargetAssign
(
rpn_batch_size_per_im
=
256
,
rpn_fg_fraction
=
0.5
,
rpn_negative_overlap
=
0.3
,
rpn_positive_overlap
=
0.7
,
rpn_straddle_thresh
=
0.0
),
train_proposal
=
GenerateProposals
(
min_size
=
0.0
,
nms_thresh
=
0.7
,
post_nms_top_n
=
2000
,
pre_nms_top_n
=
2000
),
test_proposal
=
GenerateProposals
(
min_size
=
0.0
,
nms_thresh
=
0.7
,
post_nms_top_n
=
1000
,
pre_nms_top_n
=
1000
),
anchor_start_size
=
32
,
num_chan
=
256
,
min_level
=
2
,
max_level
=
6
)
def
roi_extractor
(
self
):
return
FPNRoIAlign
(
canconical_level
=
4
,
canonical_size
=
224
,
max_level
=
5
,
min_level
=
2
,
box_resolution
=
7
,
sampling_ratio
=
2
)
def
bbox_head
(
self
,
num_classes
):
return
BBoxHead
(
head
=
TwoFCHead
(
mlp_dim
=
1024
),
nms
=
MultiClassNMS
(
keep_top_k
=
100
,
nms_threshold
=
0.5
,
score_threshold
=
0.05
),
num_classes
=
num_classes
)
def
bbox_assigner
(
self
,
num_classes
):
return
BBoxAssigner
(
batch_size_per_im
=
512
,
bbox_reg_weights
=
[
0.1
,
0.1
,
0.2
,
0.2
],
bg_thresh_hi
=
0.5
,
bg_thresh_lo
=
0.0
,
fg_fraction
=
0.25
,
fg_thresh
=
0.5
,
class_nums
=
num_classes
)
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/name_adapter.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
class
NameAdapter
(
object
):
"""Fix the backbones variable names for pretrained weight"""
def
__init__
(
self
,
model
):
super
(
NameAdapter
,
self
).
__init__
()
self
.
model
=
model
@
property
def
model_type
(
self
):
return
getattr
(
self
.
model
,
'_model_type'
,
''
)
@
property
def
variant
(
self
):
return
getattr
(
self
.
model
,
'variant'
,
''
)
def
fix_conv_norm_name
(
self
,
name
):
if
name
==
"conv1"
:
bn_name
=
"bn_"
+
name
else
:
bn_name
=
"bn"
+
name
[
3
:]
# the naming rule is same as pretrained weight
if
self
.
model_type
==
'SEResNeXt'
:
bn_name
=
name
+
"_bn"
return
bn_name
def
fix_shortcut_name
(
self
,
name
):
if
self
.
model_type
==
'SEResNeXt'
:
name
=
'conv'
+
name
+
'_prj'
return
name
def
fix_bottleneck_name
(
self
,
name
):
if
self
.
model_type
==
'SEResNeXt'
:
conv_name1
=
'conv'
+
name
+
'_x1'
conv_name2
=
'conv'
+
name
+
'_x2'
conv_name3
=
'conv'
+
name
+
'_x3'
shortcut_name
=
name
else
:
conv_name1
=
name
+
"_branch2a"
conv_name2
=
name
+
"_branch2b"
conv_name3
=
name
+
"_branch2c"
shortcut_name
=
name
+
"_branch1"
return
conv_name1
,
conv_name2
,
conv_name3
,
shortcut_name
def
fix_layer_warp_name
(
self
,
stage_num
,
count
,
i
):
name
=
'res'
+
str
(
stage_num
)
if
count
>
10
and
stage_num
==
4
:
if
i
==
0
:
conv_name
=
name
+
"a"
else
:
conv_name
=
name
+
"b"
+
str
(
i
)
else
:
conv_name
=
name
+
chr
(
ord
(
"a"
)
+
i
)
if
self
.
model_type
==
'SEResNeXt'
:
conv_name
=
str
(
stage_num
+
2
)
+
'_'
+
str
(
i
+
1
)
return
conv_name
def
fix_c1_stage_name
(
self
):
return
"res_conv1"
if
self
.
model_type
==
'ResNeXt'
else
"conv1"
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/nonlocal_helper.py
0 → 100644
浏览文件 @
9b82f2fb
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
paddle.fluid
as
fluid
from
paddle.fluid
import
ParamAttr
nonlocal_params
=
{
"use_zero_init_conv"
:
False
,
"conv_init_std"
:
0.01
,
"no_bias"
:
True
,
"use_maxpool"
:
False
,
"use_softmax"
:
True
,
"use_bn"
:
False
,
"use_scale"
:
True
,
# vital for the model prformance!!!
"use_affine"
:
False
,
"bn_momentum"
:
0.9
,
"bn_epsilon"
:
1.0000001e-5
,
"bn_init_gamma"
:
0.9
,
"weight_decay_bn"
:
1.e-4
,
}
def
space_nonlocal
(
input
,
dim_in
,
dim_out
,
prefix
,
dim_inner
,
max_pool_stride
=
2
):
cur
=
input
theta
=
fluid
.
layers
.
conv2d
(
input
=
cur
,
num_filters
=
dim_inner
,
\
filter_size
=
[
1
,
1
],
stride
=
[
1
,
1
],
\
padding
=
[
0
,
0
],
\
param_attr
=
ParamAttr
(
name
=
prefix
+
'_theta'
+
"_w"
,
\
initializer
=
fluid
.
initializer
.
Normal
(
loc
=
0.0
,
scale
=
nonlocal_params
[
"conv_init_std"
])),
\
bias_attr
=
ParamAttr
(
name
=
prefix
+
'_theta'
+
"_b"
,
\
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.
))
\
if
not
nonlocal_params
[
"no_bias"
]
else
False
,
\
name
=
prefix
+
'_theta'
)
theta_shape
=
theta
.
shape
theta_shape_op
=
fluid
.
layers
.
shape
(
theta
)
theta_shape_op
.
stop_gradient
=
True
if
nonlocal_params
[
"use_maxpool"
]:
max_pool
=
fluid
.
layers
.
pool2d
(
input
=
cur
,
\
pool_size
=
[
max_pool_stride
,
max_pool_stride
],
\
pool_type
=
'max'
,
\
pool_stride
=
[
max_pool_stride
,
max_pool_stride
],
\
pool_padding
=
[
0
,
0
],
\
name
=
prefix
+
'_pool'
)
else
:
max_pool
=
cur
phi
=
fluid
.
layers
.
conv2d
(
input
=
max_pool
,
num_filters
=
dim_inner
,
\
filter_size
=
[
1
,
1
],
stride
=
[
1
,
1
],
\
padding
=
[
0
,
0
],
\
param_attr
=
ParamAttr
(
name
=
prefix
+
'_phi'
+
"_w"
,
\
initializer
=
fluid
.
initializer
.
Normal
(
loc
=
0.0
,
scale
=
nonlocal_params
[
"conv_init_std"
])),
\
bias_attr
=
ParamAttr
(
name
=
prefix
+
'_phi'
+
"_b"
,
\
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.
))
\
if
(
nonlocal_params
[
"no_bias"
]
==
0
)
else
False
,
\
name
=
prefix
+
'_phi'
)
phi_shape
=
phi
.
shape
g
=
fluid
.
layers
.
conv2d
(
input
=
max_pool
,
num_filters
=
dim_inner
,
\
filter_size
=
[
1
,
1
],
stride
=
[
1
,
1
],
\
padding
=
[
0
,
0
],
\
param_attr
=
ParamAttr
(
name
=
prefix
+
'_g'
+
"_w"
,
\
initializer
=
fluid
.
initializer
.
Normal
(
loc
=
0.0
,
scale
=
nonlocal_params
[
"conv_init_std"
])),
\
bias_attr
=
ParamAttr
(
name
=
prefix
+
'_g'
+
"_b"
,
\
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.
))
if
(
nonlocal_params
[
"no_bias"
]
==
0
)
else
False
,
\
name
=
prefix
+
'_g'
)
g_shape
=
g
.
shape
# we have to use explicit batch size (to support arbitrary spacetime size)
# e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784)
theta
=
fluid
.
layers
.
reshape
(
theta
,
shape
=
(
0
,
0
,
-
1
))
theta
=
fluid
.
layers
.
transpose
(
theta
,
[
0
,
2
,
1
])
phi
=
fluid
.
layers
.
reshape
(
phi
,
[
0
,
0
,
-
1
])
theta_phi
=
fluid
.
layers
.
matmul
(
theta
,
phi
,
name
=
prefix
+
'_affinity'
)
g
=
fluid
.
layers
.
reshape
(
g
,
[
0
,
0
,
-
1
])
if
nonlocal_params
[
"use_softmax"
]:
if
nonlocal_params
[
"use_scale"
]:
theta_phi_sc
=
fluid
.
layers
.
scale
(
theta_phi
,
scale
=
dim_inner
**-
.
5
)
else
:
theta_phi_sc
=
theta_phi
p
=
fluid
.
layers
.
softmax
(
theta_phi_sc
,
name
=
prefix
+
'_affinity'
+
'_prob'
)
else
:
# not clear about what is doing in xlw's code
p
=
None
# not implemented
raise
"Not implemented when not use softmax"
# note g's axis[2] corresponds to p's axis[2]
# e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
p
=
fluid
.
layers
.
transpose
(
p
,
[
0
,
2
,
1
])
t
=
fluid
.
layers
.
matmul
(
g
,
p
,
name
=
prefix
+
'_y'
)
# reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
t_shape
=
t
.
shape
t_re
=
fluid
.
layers
.
reshape
(
t
,
shape
=
list
(
theta_shape
),
actual_shape
=
theta_shape_op
)
blob_out
=
t_re
blob_out
=
fluid
.
layers
.
conv2d
(
input
=
blob_out
,
num_filters
=
dim_out
,
\
filter_size
=
[
1
,
1
],
stride
=
[
1
,
1
],
padding
=
[
0
,
0
],
\
param_attr
=
ParamAttr
(
name
=
prefix
+
'_out'
+
"_w"
,
\
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.
)
\
if
nonlocal_params
[
"use_zero_init_conv"
]
\
else
fluid
.
initializer
.
Normal
(
loc
=
0.0
,
scale
=
nonlocal_params
[
"conv_init_std"
])),
\
bias_attr
=
ParamAttr
(
name
=
prefix
+
'_out'
+
"_b"
,
\
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.
))
\
if
(
nonlocal_params
[
"no_bias"
]
==
0
)
else
False
,
\
name
=
prefix
+
'_out'
)
blob_out_shape
=
blob_out
.
shape
if
nonlocal_params
[
"use_bn"
]:
bn_name
=
prefix
+
"_bn"
blob_out
=
fluid
.
layers
.
batch_norm
(
blob_out
,
\
# is_test = test_mode, \
momentum
=
nonlocal_params
[
"bn_momentum"
],
\
epsilon
=
nonlocal_params
[
"bn_epsilon"
],
\
name
=
bn_name
,
\
param_attr
=
ParamAttr
(
name
=
bn_name
+
"_s"
,
\
initializer
=
fluid
.
initializer
.
Constant
(
value
=
nonlocal_params
[
"bn_init_gamma"
]),
\
regularizer
=
fluid
.
regularizer
.
L2Decay
(
nonlocal_params
[
"weight_decay_bn"
])),
\
bias_attr
=
ParamAttr
(
name
=
bn_name
+
"_b"
,
\
regularizer
=
fluid
.
regularizer
.
L2Decay
(
nonlocal_params
[
"weight_decay_bn"
])),
\
moving_mean_name
=
bn_name
+
"_rm"
,
\
moving_variance_name
=
bn_name
+
"_riv"
)
# add bn
if
nonlocal_params
[
"use_affine"
]:
affine_scale
=
fluid
.
layers
.
create_parameter
(
\
shape
=
[
blob_out_shape
[
1
]],
dtype
=
blob_out
.
dtype
,
\
attr
=
ParamAttr
(
name
=
prefix
+
'_affine'
+
'_s'
),
\
default_initializer
=
fluid
.
initializer
.
Constant
(
value
=
1.
))
affine_bias
=
fluid
.
layers
.
create_parameter
(
\
shape
=
[
blob_out_shape
[
1
]],
dtype
=
blob_out
.
dtype
,
\
attr
=
ParamAttr
(
name
=
prefix
+
'_affine'
+
'_b'
),
\
default_initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.
))
blob_out
=
fluid
.
layers
.
affine_channel
(
blob_out
,
scale
=
affine_scale
,
\
bias
=
affine_bias
,
name
=
prefix
+
'_affine'
)
# add affine
return
blob_out
def
add_space_nonlocal
(
input
,
dim_in
,
dim_out
,
prefix
,
dim_inner
):
'''
add_space_nonlocal:
Non-local Neural Networks: see https://arxiv.org/abs/1711.07971
'''
conv
=
space_nonlocal
(
input
,
dim_in
,
dim_out
,
prefix
,
dim_inner
)
output
=
fluid
.
layers
.
elementwise_add
(
input
,
conv
,
name
=
prefix
+
'_sum'
)
return
output
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/processor.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
import
base64
import
os
import
cv2
import
numpy
as
np
from
PIL
import
Image
,
ImageDraw
__all__
=
[
'base64_to_cv2'
,
'load_label_info'
,
'postprocess'
,
]
def
base64_to_cv2
(
b64str
):
data
=
base64
.
b64decode
(
b64str
.
encode
(
'utf8'
))
data
=
np
.
fromstring
(
data
,
np
.
uint8
)
data
=
cv2
.
imdecode
(
data
,
cv2
.
IMREAD_COLOR
)
return
data
def
get_save_image_name
(
img
,
output_dir
,
image_path
):
"""Get save image name from source image path.
"""
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
image_name
=
os
.
path
.
split
(
image_path
)[
-
1
]
name
,
ext
=
os
.
path
.
splitext
(
image_name
)
if
ext
==
''
:
if
img
.
format
==
'PNG'
:
ext
=
'.png'
elif
img
.
format
==
'JPEG'
:
ext
=
'.jpg'
elif
img
.
format
==
'BMP'
:
ext
=
'.bmp'
else
:
if
img
.
mode
==
"RGB"
or
img
.
mode
==
"L"
:
ext
=
".jpg"
elif
img
.
mode
==
"RGBA"
or
img
.
mode
==
"P"
:
ext
=
'.png'
return
os
.
path
.
join
(
output_dir
,
"{}"
.
format
(
name
))
+
ext
def
draw_bounding_box_on_image
(
image_path
,
data_list
,
save_dir
):
image
=
Image
.
open
(
image_path
)
draw
=
ImageDraw
.
Draw
(
image
)
for
data
in
data_list
:
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
# draw bbox
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
# draw label
if
image
.
mode
==
'RGB'
:
text
=
data
[
'label'
]
+
": %.2f%%"
%
(
100
*
data
[
'confidence'
])
textsize_width
,
textsize_height
=
draw
.
textsize
(
text
=
text
)
draw
.
rectangle
(
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
draw
.
text
(
xy
=
(
left
,
top
-
15
),
text
=
text
,
fill
=
(
0
,
0
,
0
))
save_name
=
get_save_image_name
(
image
,
save_dir
,
image_path
)
if
os
.
path
.
exists
(
save_name
):
os
.
remove
(
save_name
)
image
.
save
(
save_name
)
return
save_name
def
clip_bbox
(
bbox
,
img_width
,
img_height
):
xmin
=
max
(
min
(
bbox
[
0
],
img_width
),
0.
)
ymin
=
max
(
min
(
bbox
[
1
],
img_height
),
0.
)
xmax
=
max
(
min
(
bbox
[
2
],
img_width
),
0.
)
ymax
=
max
(
min
(
bbox
[
3
],
img_height
),
0.
)
return
xmin
,
ymin
,
xmax
,
ymax
def
load_label_info
(
file_path
):
with
open
(
file_path
,
'r'
)
as
fr
:
text
=
fr
.
readlines
()
label_names
=
[]
for
info
in
text
:
label_names
.
append
(
info
.
strip
())
return
label_names
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
"""
postprocess the lod_tensor produced by fluid.Executor.run
Args:
paths (list[str]): the path of images.
images (list(numpy.ndarray)): list of images, shape of each is [H, W, C].
data_out (lod_tensor): data produced by executor.run.
score_thresh (float): the low limit of bounding box.
label_names (list[str]): label names.
output_dir (str): output directory.
handle_id (int): The number of images that have been handled.
visualization (bool): whether to save as images.
Returns:
res (list[dict]): The result of vehicles detecion. keys include 'data', 'save_path', the corresponding value is:
data (dict): the result of object detection, keys include 'left', 'top', 'right', 'bottom', 'label', 'confidence', the corresponding value is:
left (float): The X coordinate of the upper left corner of the bounding box;
top (float): The Y coordinate of the upper left corner of the bounding box;
right (float): The X coordinate of the lower right corner of the bounding box;
bottom (float): The Y coordinate of the lower right corner of the bounding box;
label (str): The label of detection result;
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor
=
data_out
[
0
]
lod
=
lod_tensor
.
lod
[
0
]
results
=
lod_tensor
.
as_ndarray
()
if
handle_id
<
len
(
paths
):
unhandled_paths
=
paths
[
handle_id
:]
unhandled_paths_num
=
len
(
unhandled_paths
)
else
:
unhandled_paths_num
=
0
output
=
[]
for
index
in
range
(
len
(
lod
)
-
1
):
output_i
=
{
'data'
:
[]}
if
index
<
unhandled_paths_num
:
org_img_path
=
unhandled_paths
[
index
]
org_img
=
Image
.
open
(
org_img_path
)
output_i
[
'path'
]
=
org_img_path
else
:
org_img
=
images
[
index
-
unhandled_paths_num
]
org_img
=
org_img
.
astype
(
np
.
uint8
)
org_img
=
Image
.
fromarray
(
org_img
[:,
:,
::
-
1
])
if
visualization
:
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
(
(
handle_id
+
index
)))
org_img
.
save
(
org_img_path
)
org_img_height
=
org_img
.
height
org_img_width
=
org_img
.
width
result_i
=
results
[
lod
[
index
]:
lod
[
index
+
1
]]
for
row
in
result_i
:
if
len
(
row
)
!=
6
:
continue
if
row
[
1
]
<
score_thresh
:
continue
category_id
=
int
(
row
[
0
])
confidence
=
row
[
1
]
bbox
=
row
[
2
:]
dt
=
{}
dt
[
'label'
]
=
label_names
[
category_id
]
dt
[
'confidence'
]
=
confidence
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
output_i
[
'data'
].
append
(
dt
)
output
.
append
(
output_i
)
if
visualization
:
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
return
output
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/resnet.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
math
from
collections
import
OrderedDict
from
numbers
import
Integral
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.framework
import
Variable
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.initializer
import
Constant
from
.nonlocal_helper
import
add_space_nonlocal
from
.name_adapter
import
NameAdapter
__all__
=
[
'ResNet'
,
'ResNetC5'
]
class
ResNet
(
object
):
"""
Residual Network, see https://arxiv.org/abs/1512.03385
Args:
depth (int): ResNet depth, should be 34, 50.
freeze_at (int): freeze the backbone at which stage
norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel'
freeze_norm (bool): freeze normalization layers
norm_decay (float): weight decay for normalization layer weights
variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2
nonlocal_stages (list): index of stages who select nonlocal networks
"""
__shared__
=
[
'norm_type'
,
'freeze_norm'
,
'weight_prefix_name'
]
def
__init__
(
self
,
depth
=
50
,
freeze_at
=
0
,
norm_type
=
'sync_bn'
,
freeze_norm
=
False
,
norm_decay
=
0.
,
variant
=
'b'
,
feature_maps
=
[
3
,
4
,
5
],
dcn_v2_stages
=
[],
weight_prefix_name
=
''
,
nonlocal_stages
=
[],
get_prediction
=
False
,
class_dim
=
1000
):
super
(
ResNet
,
self
).
__init__
()
if
isinstance
(
feature_maps
,
Integral
):
feature_maps
=
[
feature_maps
]
assert
depth
in
[
34
,
50
],
\
"depth {} not in [34, 50]"
assert
variant
in
[
'a'
,
'b'
,
'c'
,
'd'
],
"invalid ResNet variant"
assert
0
<=
freeze_at
<=
4
,
"freeze_at should be 0, 1, 2, 3 or 4"
assert
len
(
feature_maps
)
>
0
,
"need one or more feature maps"
assert
norm_type
in
[
'bn'
,
'sync_bn'
,
'affine_channel'
]
assert
not
(
len
(
nonlocal_stages
)
>
0
and
depth
<
50
),
\
"non-local is not supported for resnet18 or resnet34"
self
.
depth
=
depth
self
.
freeze_at
=
freeze_at
self
.
norm_type
=
norm_type
self
.
norm_decay
=
norm_decay
self
.
freeze_norm
=
freeze_norm
self
.
variant
=
variant
self
.
_model_type
=
'ResNet'
self
.
feature_maps
=
feature_maps
self
.
dcn_v2_stages
=
dcn_v2_stages
self
.
depth_cfg
=
{
34
:
([
3
,
4
,
6
,
3
],
self
.
basicblock
),
50
:
([
3
,
4
,
6
,
3
],
self
.
bottleneck
),
}
self
.
stage_filters
=
[
64
,
128
,
256
,
512
]
self
.
_c1_out_chan_num
=
64
self
.
na
=
NameAdapter
(
self
)
self
.
prefix_name
=
weight_prefix_name
self
.
nonlocal_stages
=
nonlocal_stages
self
.
nonlocal_mod_cfg
=
{
50
:
2
,
101
:
5
,
152
:
8
,
200
:
12
,
}
self
.
get_prediction
=
get_prediction
self
.
class_dim
=
class_dim
def
_conv_offset
(
self
,
input
,
filter_size
,
stride
,
padding
,
act
=
None
,
name
=
None
):
out_channel
=
filter_size
*
filter_size
*
3
out
=
fluid
.
layers
.
conv2d
(
input
,
num_filters
=
out_channel
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
param_attr
=
ParamAttr
(
initializer
=
Constant
(
0.0
),
name
=
name
+
".w_0"
),
bias_attr
=
ParamAttr
(
initializer
=
Constant
(
0.0
),
name
=
name
+
".b_0"
),
act
=
act
,
name
=
name
)
return
out
def
_conv_norm
(
self
,
input
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
None
,
name
=
None
,
dcn_v2
=
False
):
_name
=
self
.
prefix_name
+
name
if
self
.
prefix_name
!=
''
else
name
if
not
dcn_v2
:
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
_name
+
"_weights"
),
bias_attr
=
False
,
name
=
_name
+
'.conv2d.output.1'
)
else
:
# select deformable conv"
offset_mask
=
self
.
_conv_offset
(
input
=
input
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
act
=
None
,
name
=
_name
+
"_conv_offset"
)
offset_channel
=
filter_size
**
2
*
2
mask_channel
=
filter_size
**
2
offset
,
mask
=
fluid
.
layers
.
split
(
input
=
offset_mask
,
num_or_sections
=
[
offset_channel
,
mask_channel
],
dim
=
1
)
mask
=
fluid
.
layers
.
sigmoid
(
mask
)
conv
=
fluid
.
layers
.
deformable_conv
(
input
=
input
,
offset
=
offset
,
mask
=
mask
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
deformable_groups
=
1
,
im2col_step
=
1
,
param_attr
=
ParamAttr
(
name
=
_name
+
"_weights"
),
bias_attr
=
False
,
name
=
_name
+
".conv2d.output.1"
)
bn_name
=
self
.
na
.
fix_conv_norm_name
(
name
)
bn_name
=
self
.
prefix_name
+
bn_name
if
self
.
prefix_name
!=
''
else
bn_name
norm_lr
=
0.
if
self
.
freeze_norm
else
1.
norm_decay
=
self
.
norm_decay
pattr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
,
learning_rate
=
norm_lr
,
regularizer
=
L2Decay
(
norm_decay
))
battr
=
ParamAttr
(
name
=
bn_name
+
'_offset'
,
learning_rate
=
norm_lr
,
regularizer
=
L2Decay
(
norm_decay
))
if
self
.
norm_type
in
[
'bn'
,
'sync_bn'
]:
global_stats
=
True
if
self
.
freeze_norm
else
False
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
name
=
bn_name
+
'.output.1'
,
param_attr
=
pattr
,
bias_attr
=
battr
,
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
,
use_global_stats
=
global_stats
)
scale
=
fluid
.
framework
.
_get_var
(
pattr
.
name
)
bias
=
fluid
.
framework
.
_get_var
(
battr
.
name
)
elif
self
.
norm_type
==
'affine_channel'
:
scale
=
fluid
.
layers
.
create_parameter
(
shape
=
[
conv
.
shape
[
1
]],
dtype
=
conv
.
dtype
,
attr
=
pattr
,
default_initializer
=
fluid
.
initializer
.
Constant
(
1.
))
bias
=
fluid
.
layers
.
create_parameter
(
shape
=
[
conv
.
shape
[
1
]],
dtype
=
conv
.
dtype
,
attr
=
battr
,
default_initializer
=
fluid
.
initializer
.
Constant
(
0.
))
out
=
fluid
.
layers
.
affine_channel
(
x
=
conv
,
scale
=
scale
,
bias
=
bias
,
act
=
act
)
if
self
.
freeze_norm
:
scale
.
stop_gradient
=
True
bias
.
stop_gradient
=
True
return
out
def
_shortcut
(
self
,
input
,
ch_out
,
stride
,
is_first
,
name
):
max_pooling_in_short_cut
=
self
.
variant
==
'd'
ch_in
=
input
.
shape
[
1
]
# the naming rule is same as pretrained weight
name
=
self
.
na
.
fix_shortcut_name
(
name
)
std_senet
=
getattr
(
self
,
'std_senet'
,
False
)
if
ch_in
!=
ch_out
or
stride
!=
1
or
(
self
.
depth
<
50
and
is_first
):
if
std_senet
:
if
is_first
:
return
self
.
_conv_norm
(
input
,
ch_out
,
1
,
stride
,
name
=
name
)
else
:
return
self
.
_conv_norm
(
input
,
ch_out
,
3
,
stride
,
name
=
name
)
if
max_pooling_in_short_cut
and
not
is_first
:
input
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
2
,
pool_stride
=
2
,
pool_padding
=
0
,
ceil_mode
=
True
,
pool_type
=
'avg'
)
return
self
.
_conv_norm
(
input
,
ch_out
,
1
,
1
,
name
=
name
)
return
self
.
_conv_norm
(
input
,
ch_out
,
1
,
stride
,
name
=
name
)
else
:
return
input
def
bottleneck
(
self
,
input
,
num_filters
,
stride
,
is_first
,
name
,
dcn_v2
=
False
):
if
self
.
variant
==
'a'
:
stride1
,
stride2
=
stride
,
1
else
:
stride1
,
stride2
=
1
,
stride
# ResNeXt
groups
=
getattr
(
self
,
'groups'
,
1
)
group_width
=
getattr
(
self
,
'group_width'
,
-
1
)
if
groups
==
1
:
expand
=
4
elif
(
groups
*
group_width
)
==
256
:
expand
=
1
else
:
# FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
num_filters
=
num_filters
//
2
expand
=
2
conv_name1
,
conv_name2
,
conv_name3
,
\
shortcut_name
=
self
.
na
.
fix_bottleneck_name
(
name
)
std_senet
=
getattr
(
self
,
'std_senet'
,
False
)
if
std_senet
:
conv_def
=
[[
int
(
num_filters
/
2
),
1
,
stride1
,
'relu'
,
1
,
conv_name1
],
[
num_filters
,
3
,
stride2
,
'relu'
,
groups
,
conv_name2
],
[
num_filters
*
expand
,
1
,
1
,
None
,
1
,
conv_name3
]]
else
:
conv_def
=
[[
num_filters
,
1
,
stride1
,
'relu'
,
1
,
conv_name1
],
[
num_filters
,
3
,
stride2
,
'relu'
,
groups
,
conv_name2
],
[
num_filters
*
expand
,
1
,
1
,
None
,
1
,
conv_name3
]]
residual
=
input
for
i
,
(
c
,
k
,
s
,
act
,
g
,
_name
)
in
enumerate
(
conv_def
):
residual
=
self
.
_conv_norm
(
input
=
residual
,
num_filters
=
c
,
filter_size
=
k
,
stride
=
s
,
act
=
act
,
groups
=
g
,
name
=
_name
,
dcn_v2
=
(
i
==
1
and
dcn_v2
))
short
=
self
.
_shortcut
(
input
,
num_filters
*
expand
,
stride
,
is_first
=
is_first
,
name
=
shortcut_name
)
# Squeeze-and-Excitation
if
callable
(
getattr
(
self
,
'_squeeze_excitation'
,
None
)):
residual
=
self
.
_squeeze_excitation
(
input
=
residual
,
num_channels
=
num_filters
,
name
=
'fc'
+
name
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
residual
,
act
=
'relu'
,
name
=
name
+
".add.output.5"
)
def
basicblock
(
self
,
input
,
num_filters
,
stride
,
is_first
,
name
,
dcn_v2
=
False
):
assert
dcn_v2
is
False
,
"Not implemented yet."
conv0
=
self
.
_conv_norm
(
input
=
input
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
'relu'
,
stride
=
stride
,
name
=
name
+
"_branch2a"
)
conv1
=
self
.
_conv_norm
(
input
=
conv0
,
num_filters
=
num_filters
,
filter_size
=
3
,
act
=
None
,
name
=
name
+
"_branch2b"
)
short
=
self
.
_shortcut
(
input
,
num_filters
,
stride
,
is_first
,
name
=
name
+
"_branch1"
)
return
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
conv1
,
act
=
'relu'
)
def
layer_warp
(
self
,
input
,
stage_num
):
"""
Args:
input (Variable): input variable.
stage_num (int): the stage number, should be 2, 3, 4, 5
Returns:
The last variable in endpoint-th stage.
"""
assert
stage_num
in
[
2
,
3
,
4
,
5
]
stages
,
block_func
=
self
.
depth_cfg
[
self
.
depth
]
count
=
stages
[
stage_num
-
2
]
ch_out
=
self
.
stage_filters
[
stage_num
-
2
]
is_first
=
False
if
stage_num
!=
2
else
True
dcn_v2
=
True
if
stage_num
in
self
.
dcn_v2_stages
else
False
nonlocal_mod
=
1000
if
stage_num
in
self
.
nonlocal_stages
:
nonlocal_mod
=
self
.
nonlocal_mod_cfg
[
self
.
depth
]
if
stage_num
==
4
else
2
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv
=
input
for
i
in
range
(
count
):
conv_name
=
self
.
na
.
fix_layer_warp_name
(
stage_num
,
count
,
i
)
if
self
.
depth
<
50
:
is_first
=
True
if
i
==
0
and
stage_num
==
2
else
False
conv
=
block_func
(
input
=
conv
,
num_filters
=
ch_out
,
stride
=
2
if
i
==
0
and
stage_num
!=
2
else
1
,
is_first
=
is_first
,
name
=
conv_name
,
dcn_v2
=
dcn_v2
)
# add non local model
dim_in
=
conv
.
shape
[
1
]
nonlocal_name
=
"nonlocal_conv{}"
.
format
(
stage_num
)
if
i
%
nonlocal_mod
==
nonlocal_mod
-
1
:
conv
=
add_space_nonlocal
(
conv
,
dim_in
,
dim_in
,
nonlocal_name
+
'_{}'
.
format
(
i
),
int
(
dim_in
/
2
))
return
conv
def
c1_stage
(
self
,
input
):
out_chan
=
self
.
_c1_out_chan_num
conv1_name
=
self
.
na
.
fix_c1_stage_name
()
if
self
.
variant
in
[
'c'
,
'd'
]:
conv_def
=
[
[
out_chan
//
2
,
3
,
2
,
"conv1_1"
],
[
out_chan
//
2
,
3
,
1
,
"conv1_2"
],
[
out_chan
,
3
,
1
,
"conv1_3"
],
]
else
:
conv_def
=
[[
out_chan
,
7
,
2
,
conv1_name
]]
for
(
c
,
k
,
s
,
_name
)
in
conv_def
:
input
=
self
.
_conv_norm
(
input
=
input
,
num_filters
=
c
,
filter_size
=
k
,
stride
=
s
,
act
=
'relu'
,
name
=
_name
)
output
=
fluid
.
layers
.
pool2d
(
input
=
input
,
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
return
output
def
__call__
(
self
,
input
):
assert
isinstance
(
input
,
Variable
)
assert
not
(
set
(
self
.
feature_maps
)
-
set
([
2
,
3
,
4
,
5
])),
\
"feature maps {} not in [2, 3, 4, 5]"
.
format
(
self
.
feature_maps
)
res_endpoints
=
[]
res
=
input
feature_maps
=
self
.
feature_maps
severed_head
=
getattr
(
self
,
'severed_head'
,
False
)
if
not
severed_head
:
res
=
self
.
c1_stage
(
res
)
feature_maps
=
range
(
2
,
max
(
self
.
feature_maps
)
+
1
)
for
i
in
feature_maps
:
res
=
self
.
layer_warp
(
res
,
i
)
if
i
in
self
.
feature_maps
:
res_endpoints
.
append
(
res
)
if
self
.
freeze_at
>=
i
:
res
.
stop_gradient
=
True
if
self
.
get_prediction
:
pool
=
fluid
.
layers
.
pool2d
(
input
=
res
,
pool_type
=
'avg'
,
global_pooling
=
True
)
stdv
=
1.0
/
math
.
sqrt
(
pool
.
shape
[
1
]
*
1.0
)
out
=
fluid
.
layers
.
fc
(
input
=
pool
,
size
=
self
.
class_dim
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
)))
out
=
fluid
.
layers
.
softmax
(
out
)
return
out
return
OrderedDict
([(
'res{}_sum'
.
format
(
self
.
feature_maps
[
idx
]),
feat
)
for
idx
,
feat
in
enumerate
(
res_endpoints
)])
class
ResNetC5
(
ResNet
):
def
__init__
(
self
,
depth
=
50
,
freeze_at
=
2
,
norm_type
=
'affine_channel'
,
freeze_norm
=
True
,
norm_decay
=
0.
,
variant
=
'b'
,
feature_maps
=
[
5
],
weight_prefix_name
=
''
):
super
(
ResNetC5
,
self
).
__init__
(
depth
,
freeze_at
,
norm_type
,
freeze_norm
,
norm_decay
,
variant
,
feature_maps
)
self
.
severed_head
=
True
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/roi_extractor.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
import
paddle.fluid
as
fluid
__all__
=
[
'FPNRoIAlign'
]
class
FPNRoIAlign
(
object
):
"""
RoI align pooling for FPN feature maps
Args:
sampling_ratio (int): number of sampling points
min_level (int): lowest level of FPN layer
max_level (int): highest level of FPN layer
canconical_level (int): the canconical FPN feature map level
canonical_size (int): the canconical FPN feature map size
box_resolution (int): box resolution
mask_resolution (int): mask roi resolution
"""
def
__init__
(
self
,
sampling_ratio
=
0
,
min_level
=
2
,
max_level
=
5
,
canconical_level
=
4
,
canonical_size
=
224
,
box_resolution
=
7
,
mask_resolution
=
14
):
super
(
FPNRoIAlign
,
self
).
__init__
()
self
.
sampling_ratio
=
sampling_ratio
self
.
min_level
=
min_level
self
.
max_level
=
max_level
self
.
canconical_level
=
canconical_level
self
.
canonical_size
=
canonical_size
self
.
box_resolution
=
box_resolution
self
.
mask_resolution
=
mask_resolution
def
__call__
(
self
,
head_inputs
,
rois
,
spatial_scale
,
is_mask
=
False
):
"""
Adopt RoI align onto several level of feature maps to get RoI features.
Distribute RoIs to different levels by area and get a list of RoI
features by distributed RoIs and their corresponding feature maps.
Returns:
roi_feat(Variable): RoI features with shape of [M, C, R, R],
where M is the number of RoIs and R is RoI resolution
"""
k_min
=
self
.
min_level
k_max
=
self
.
max_level
num_roi_lvls
=
k_max
-
k_min
+
1
name_list
=
list
(
head_inputs
.
keys
())
input_name_list
=
name_list
[
-
num_roi_lvls
:]
spatial_scale
=
spatial_scale
[
-
num_roi_lvls
:]
rois_dist
,
restore_index
=
fluid
.
layers
.
distribute_fpn_proposals
(
rois
,
k_min
,
k_max
,
self
.
canconical_level
,
self
.
canonical_size
)
# rois_dist is in ascend order
roi_out_list
=
[]
resolution
=
is_mask
and
self
.
mask_resolution
or
self
.
box_resolution
for
lvl
in
range
(
num_roi_lvls
):
name_index
=
num_roi_lvls
-
lvl
-
1
rois_input
=
rois_dist
[
lvl
]
head_input
=
head_inputs
[
input_name_list
[
name_index
]]
sc
=
spatial_scale
[
name_index
]
roi_out
=
fluid
.
layers
.
roi_align
(
input
=
head_input
,
rois
=
rois_input
,
pooled_height
=
resolution
,
pooled_width
=
resolution
,
spatial_scale
=
sc
,
sampling_ratio
=
self
.
sampling_ratio
)
roi_out_list
.
append
(
roi_out
)
roi_feat_shuffle
=
fluid
.
layers
.
concat
(
roi_out_list
)
roi_feat_
=
fluid
.
layers
.
gather
(
roi_feat_shuffle
,
restore_index
)
roi_feat
=
fluid
.
layers
.
lod_reset
(
roi_feat_
,
rois
)
return
roi_feat
hub_module/modules/image/object_detection/faster_rcnn_resnet50_fpn_venus/rpn_head.py
0 → 100644
浏览文件 @
9b82f2fb
此差异已折叠。
点击以展开。
hub_module/modules/image/object_detection/yolov3_darknet53_venus/README.md
0 → 100644
浏览文件 @
9b82f2fb
## 命令行预测
```
shell
$
hub run yolov3_darknet53_venus
--input_path
"/PATH/TO/IMAGE"
```
## API
```
python
def
context
(
trainable
=
True
,
pretrained
=
True
,
get_prediction
=
False
)
```
提取特征,用于迁移学习。
**参数**
*
trainable(bool): 参数是否可训练;
*
pretrained (bool): 是否加载预训练模型;
*
get
\_
prediction (bool): 是否执行预测。
**返回**
*
inputs (dict): 模型的输入,keys 包括 'image', 'im
\_
size',相应的取值为:
*
image (Variable): 图像变量
*
im
\_
size (Variable): 图片的尺寸
*
outputs (dict): 模型的输出。如果 get
\_
prediction 为 False,输出 'head
\_
features'、'body
\_
features',否则输出 'bbox
\_
out'。
*
context
\_
prog (Program): 用于迁移学习的 Program.
```
python
def
save_inference_model
(
dirname
,
model_filename
=
None
,
params_filename
=
None
,
combined
=
True
)
```
将模型保存到指定路径。
**参数**
*
dirname: 存在模型的目录名称
*
model
\_
filename: 模型文件名称,默认为
\_\_
model
\_\_
*
params
\_
filename: 参数文件名称,默认为
\_\_
params
\_\_
(仅当
`combined`
为True时生效)
*
combined: 是否将参数保存到统一的一个文件中
### 依赖
paddlepaddle >= 1.6.2
paddlehub >= 1.6.0
hub_module/modules/image/object_detection/yolov3_darknet53_venus/__init__.py
0 → 100644
浏览文件 @
9b82f2fb
hub_module/modules/image/object_detection/yolov3_darknet53_venus/darknet.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
six
import
math
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
__all__
=
[
'DarkNet'
]
class
DarkNet
(
object
):
"""DarkNet, see https://pjreddie.com/darknet/yolo/
Args:
depth (int): network depth, currently only darknet 53 is supported
norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
norm_decay (float): weight decay for normalization layer weights
get_prediction (bool): whether to get prediction
class_dim (int): number of class while classification
"""
def
__init__
(
self
,
depth
=
53
,
norm_type
=
'sync_bn'
,
norm_decay
=
0.
,
weight_prefix_name
=
''
,
get_prediction
=
False
,
class_dim
=
1000
):
assert
depth
in
[
53
],
"unsupported depth value"
self
.
depth
=
depth
self
.
norm_type
=
norm_type
self
.
norm_decay
=
norm_decay
self
.
depth_cfg
=
{
53
:
([
1
,
2
,
8
,
8
,
4
],
self
.
basicblock
)}
self
.
prefix_name
=
weight_prefix_name
self
.
class_dim
=
class_dim
self
.
get_prediction
=
get_prediction
def
_conv_norm
(
self
,
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'leaky'
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
".conv.weights"
),
bias_attr
=
False
)
bn_name
=
name
+
".bn"
bn_param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
float
(
self
.
norm_decay
)),
name
=
bn_name
+
'.scale'
)
bn_bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
float
(
self
.
norm_decay
)),
name
=
bn_name
+
'.offset'
)
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
None
,
param_attr
=
bn_param_attr
,
bias_attr
=
bn_bias_attr
,
moving_mean_name
=
bn_name
+
'.mean'
,
moving_variance_name
=
bn_name
+
'.var'
)
# leaky relu here has `alpha` as 0.1, can not be set by
# `act` param in fluid.layers.batch_norm above.
if
act
==
'leaky'
:
out
=
fluid
.
layers
.
leaky_relu
(
x
=
out
,
alpha
=
0.1
)
return
out
def
_downsample
(
self
,
input
,
ch_out
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
name
=
None
):
return
self
.
_conv_norm
(
input
,
ch_out
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
name
=
name
)
def
basicblock
(
self
,
input
,
ch_out
,
name
=
None
):
conv1
=
self
.
_conv_norm
(
input
,
ch_out
=
ch_out
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
name
=
name
+
".0"
)
conv2
=
self
.
_conv_norm
(
conv1
,
ch_out
=
ch_out
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
name
=
name
+
".1"
)
out
=
fluid
.
layers
.
elementwise_add
(
x
=
input
,
y
=
conv2
,
act
=
None
)
return
out
def
layer_warp
(
self
,
block_func
,
input
,
ch_out
,
count
,
name
=
None
):
out
=
block_func
(
input
,
ch_out
=
ch_out
,
name
=
'{}.0'
.
format
(
name
))
for
j
in
six
.
moves
.
xrange
(
1
,
count
):
out
=
block_func
(
out
,
ch_out
=
ch_out
,
name
=
'{}.{}'
.
format
(
name
,
j
))
return
out
def
__call__
(
self
,
input
):
"""
Get the backbone of DarkNet, that is output for the 5 stages.
"""
stages
,
block_func
=
self
.
depth_cfg
[
self
.
depth
]
stages
=
stages
[
0
:
5
]
conv
=
self
.
_conv_norm
(
input
=
input
,
ch_out
=
32
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
name
=
self
.
prefix_name
+
"yolo_input"
)
downsample_
=
self
.
_downsample
(
input
=
conv
,
ch_out
=
conv
.
shape
[
1
]
*
2
,
name
=
self
.
prefix_name
+
"yolo_input.downsample"
)
blocks
=
[]
for
i
,
stage
in
enumerate
(
stages
):
block
=
self
.
layer_warp
(
block_func
=
block_func
,
input
=
downsample_
,
ch_out
=
32
*
2
**
i
,
count
=
stage
,
name
=
self
.
prefix_name
+
"stage.{}"
.
format
(
i
))
blocks
.
append
(
block
)
if
i
<
len
(
stages
)
-
1
:
# do not downsaple in the last stage
downsample_
=
self
.
_downsample
(
input
=
block
,
ch_out
=
block
.
shape
[
1
]
*
2
,
name
=
self
.
prefix_name
+
"stage.{}.downsample"
.
format
(
i
))
if
self
.
get_prediction
:
pool
=
fluid
.
layers
.
pool2d
(
input
=
block
,
pool_type
=
'avg'
,
global_pooling
=
True
)
stdv
=
1.0
/
math
.
sqrt
(
pool
.
shape
[
1
]
*
1.0
)
out
=
fluid
.
layers
.
fc
(
input
=
pool
,
size
=
self
.
class_dim
,
param_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
),
name
=
'fc_weights'
),
bias_attr
=
ParamAttr
(
name
=
'fc_offset'
))
out
=
fluid
.
layers
.
softmax
(
out
)
return
out
else
:
return
blocks
hub_module/modules/image/object_detection/yolov3_darknet53_venus/data_feed.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
from
__future__
import
absolute_import
from
__future__
import
print_function
from
__future__
import
division
import
os
import
cv2
import
numpy
as
np
__all__
=
[
'reader'
]
def
reader
(
paths
=
[],
images
=
None
):
"""
data generator
Args:
paths (list[str]): paths to images.
images (list(numpy.ndarray)): data of images, shape of each is [H, W, C]
Yield:
res (list): preprocessed image and the size of original image.
"""
img_list
=
[]
if
paths
:
assert
type
(
paths
)
is
list
,
"type(paths) is not list."
for
img_path
in
paths
:
assert
os
.
path
.
isfile
(
img_path
),
"The {} isn't a valid file path."
.
format
(
img_path
)
img
=
cv2
.
imread
(
img_path
).
astype
(
'float32'
)
img_list
.
append
(
img
)
if
images
is
not
None
:
for
img
in
images
:
img_list
.
append
(
img
)
for
im
in
img_list
:
# im_size
im_shape
=
im
.
shape
im_size
=
np
.
array
([
im_shape
[
0
],
im_shape
[
1
]],
dtype
=
np
.
int32
)
# decode image
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
# resize image
target_size
=
608
im_size_min
=
np
.
min
(
im_shape
[
0
:
2
])
im_size_max
=
np
.
max
(
im_shape
[
0
:
2
])
if
float
(
im_size_min
)
==
0
:
raise
ZeroDivisionError
(
'min size of image is 0'
)
im_scale_x
=
float
(
target_size
)
/
float
(
im_shape
[
1
])
im_scale_y
=
float
(
target_size
)
/
float
(
im_shape
[
0
])
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
2
)
# normalize image
mean
=
[
0.485
,
0.456
,
0.406
]
std
=
[
0.229
,
0.224
,
0.225
]
im
=
im
.
astype
(
np
.
float32
,
copy
=
False
)
mean
=
np
.
array
(
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
im
=
im
/
255.0
im
-=
mean
im
/=
std
# permute
im
=
np
.
swapaxes
(
im
,
1
,
2
)
im
=
np
.
swapaxes
(
im
,
1
,
0
)
yield
[
im
,
im_size
]
hub_module/modules/image/object_detection/yolov3_darknet53_venus/module.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
from
__future__
import
absolute_import
import
ast
import
argparse
import
os
from
functools
import
partial
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddlehub
as
hub
from
paddle.fluid.core
import
PaddleTensor
,
AnalysisConfig
,
create_paddle_predictor
from
paddlehub.module.module
import
moduleinfo
,
runnable
,
serving
from
paddlehub.common.paddle_helper
import
add_vars_prefix
from
yolov3_darknet53_venus.darknet
import
DarkNet
from
yolov3_darknet53_venus.processor
import
load_label_info
,
postprocess
,
base64_to_cv2
from
yolov3_darknet53_venus.data_feed
import
reader
from
yolov3_darknet53_venus.yolo_head
import
MultiClassNMS
,
YOLOv3Head
@
moduleinfo
(
name
=
"yolov3_darknet53_venus"
,
version
=
"1.1.0"
,
type
=
"CV/object_detection"
,
summary
=
"Baidu's YOLOv3 model for object detection, with backbone DarkNet53, trained with Baidu self-built dataset."
,
author
=
"paddlepaddle"
,
author_email
=
"paddle-dev@baidu.com"
)
class
YOLOv3DarkNet53Venus
(
hub
.
Module
):
def
_initialize
(
self
):
self
.
default_pretrained_model_path
=
os
.
path
.
join
(
self
.
directory
,
"yolov3_darknet53_model"
)
def
context
(
self
,
trainable
=
True
,
pretrained
=
True
,
get_prediction
=
False
):
"""
Distill the Head Features, so as to perform transfer learning.
Args:
trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model.
get_prediction (bool): whether to get prediction.
Returns:
inputs(dict): the input variables.
outputs(dict): the output variables.
context_prog (Program): the program to execute transfer learning.
"""
context_prog
=
fluid
.
Program
()
startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
context_prog
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
# image
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
608
,
608
],
dtype
=
'float32'
)
# backbone
backbone
=
DarkNet
(
norm_type
=
'bn'
,
norm_decay
=
0.
,
depth
=
53
)
# body_feats
body_feats
=
backbone
(
image
)
# im_size
im_size
=
fluid
.
layers
.
data
(
name
=
'im_size'
,
shape
=
[
2
],
dtype
=
'int32'
)
# yolo_head
yolo_head
=
YOLOv3Head
(
num_classes
=
708
)
# head_features
head_features
,
body_features
=
yolo_head
.
_get_outputs
(
body_feats
,
is_train
=
trainable
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
# var_prefix
var_prefix
=
'@HUB_{}@'
.
format
(
self
.
name
)
# name of inputs
inputs
=
{
'image'
:
var_prefix
+
image
.
name
,
'im_size'
:
var_prefix
+
im_size
.
name
}
# name of outputs
if
get_prediction
:
bbox_out
=
yolo_head
.
get_prediction
(
head_features
,
im_size
)
outputs
=
{
'bbox_out'
:
[
var_prefix
+
bbox_out
.
name
]}
else
:
outputs
=
{
'head_features'
:
[
var_prefix
+
var
.
name
for
var
in
head_features
],
'body_features'
:
[
var_prefix
+
var
.
name
for
var
in
body_features
]
}
# add_vars_prefix
add_vars_prefix
(
context_prog
,
var_prefix
)
add_vars_prefix
(
fluid
.
default_startup_program
(),
var_prefix
)
# inputs
inputs
=
{
key
:
context_prog
.
global_block
().
vars
[
value
]
for
key
,
value
in
inputs
.
items
()
}
# outputs
outputs
=
{
key
:
[
context_prog
.
global_block
().
vars
[
varname
]
for
varname
in
value
]
for
key
,
value
in
outputs
.
items
()
}
# trainable
for
param
in
context_prog
.
global_block
().
iter_parameters
():
param
.
trainable
=
trainable
# pretrained
if
pretrained
:
def
_if_exist
(
var
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
default_pretrained_model_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
self
.
default_pretrained_model_path
,
predicate
=
_if_exist
)
else
:
exe
.
run
(
startup_program
)
return
inputs
,
outputs
,
context_prog
hub_module/modules/image/object_detection/yolov3_darknet53_venus/processor.py
0 → 100644
浏览文件 @
9b82f2fb
# coding=utf-8
import
base64
import
os
import
cv2
import
numpy
as
np
from
PIL
import
Image
,
ImageDraw
__all__
=
[
'base64_to_cv2'
,
'load_label_info'
,
'postprocess'
]
def
base64_to_cv2
(
b64str
):
data
=
base64
.
b64decode
(
b64str
.
encode
(
'utf8'
))
data
=
np
.
fromstring
(
data
,
np
.
uint8
)
data
=
cv2
.
imdecode
(
data
,
cv2
.
IMREAD_COLOR
)
return
data
def
check_dir
(
dir_path
):
if
not
os
.
path
.
exists
(
dir_path
):
os
.
makedirs
(
dir_path
)
elif
os
.
path
.
isfile
(
dir_path
):
os
.
remove
(
dir_path
)
os
.
makedirs
(
dir_path
)
def
get_save_image_name
(
img
,
output_dir
,
image_path
):
"""Get save image name from source image path.
"""
image_name
=
os
.
path
.
split
(
image_path
)[
-
1
]
name
,
ext
=
os
.
path
.
splitext
(
image_name
)
if
ext
==
''
:
if
img
.
format
==
'PNG'
:
ext
=
'.png'
elif
img
.
format
==
'JPEG'
:
ext
=
'.jpg'
elif
img
.
format
==
'BMP'
:
ext
=
'.bmp'
else
:
if
img
.
mode
==
"RGB"
or
img
.
mode
==
"L"
:
ext
=
".jpg"
elif
img
.
mode
==
"RGBA"
or
img
.
mode
==
"P"
:
ext
=
'.png'
return
os
.
path
.
join
(
output_dir
,
"{}"
.
format
(
name
))
+
ext
def
draw_bounding_box_on_image
(
image_path
,
data_list
,
save_dir
):
image
=
Image
.
open
(
image_path
)
draw
=
ImageDraw
.
Draw
(
image
)
for
data
in
data_list
:
left
,
right
,
top
,
bottom
=
data
[
'left'
],
data
[
'right'
],
data
[
'top'
],
data
[
'bottom'
]
# draw bbox
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
left
,
top
)],
width
=
2
,
fill
=
'red'
)
# draw label
if
image
.
mode
==
'RGB'
:
text
=
data
[
'label'
]
+
": %.2f%%"
%
(
100
*
data
[
'confidence'
])
textsize_width
,
textsize_height
=
draw
.
textsize
(
text
=
text
)
draw
.
rectangle
(
xy
=
(
left
,
top
-
(
textsize_height
+
5
),
left
+
textsize_width
+
10
,
top
),
fill
=
(
255
,
255
,
255
))
draw
.
text
(
xy
=
(
left
,
top
-
15
),
text
=
text
,
fill
=
(
0
,
0
,
0
))
save_name
=
get_save_image_name
(
image
,
save_dir
,
image_path
)
if
os
.
path
.
exists
(
save_name
):
os
.
remove
(
save_name
)
image
.
save
(
save_name
)
return
save_name
def
clip_bbox
(
bbox
,
img_width
,
img_height
):
xmin
=
max
(
min
(
bbox
[
0
],
img_width
),
0.
)
ymin
=
max
(
min
(
bbox
[
1
],
img_height
),
0.
)
xmax
=
max
(
min
(
bbox
[
2
],
img_width
),
0.
)
ymax
=
max
(
min
(
bbox
[
3
],
img_height
),
0.
)
return
xmin
,
ymin
,
xmax
,
ymax
def
load_label_info
(
file_path
):
with
open
(
file_path
,
'r'
)
as
fr
:
text
=
fr
.
readlines
()
label_names
=
[]
for
info
in
text
:
label_names
.
append
(
info
.
strip
())
return
label_names
def
postprocess
(
paths
,
images
,
data_out
,
score_thresh
,
label_names
,
output_dir
,
handle_id
,
visualization
=
True
):
"""
postprocess the lod_tensor produced by fluid.Executor.run
Args:
paths (list[str]): The paths of images.
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
data_out (lod_tensor): data output of predictor.
batch_size (int): batch size.
use_gpu (bool): Whether to use gpu.
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): the low limit of bounding box.
label_names (list[str]): label names.
handle_id (int): The number of images that have been handled.
Returns:
res (list[dict]): The result of vehicles detecion. keys include 'data', 'save_path', the corresponding value is:
data (dict): the result of object detection, keys include 'left', 'top', 'right', 'bottom', 'label', 'confidence', the corresponding value is:
left (float): The X coordinate of the upper left corner of the bounding box;
top (float): The Y coordinate of the upper left corner of the bounding box;
right (float): The X coordinate of the lower right corner of the bounding box;
bottom (float): The Y coordinate of the lower right corner of the bounding box;
label (str): The label of detection result;
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor
=
data_out
[
0
]
lod
=
lod_tensor
.
lod
[
0
]
results
=
lod_tensor
.
as_ndarray
()
check_dir
(
output_dir
)
assert
type
(
paths
)
is
list
,
"type(paths) is not list."
if
handle_id
<
len
(
paths
):
unhandled_paths
=
paths
[
handle_id
:]
unhandled_paths_num
=
len
(
unhandled_paths
)
else
:
unhandled_paths_num
=
0
output
=
list
()
for
index
in
range
(
len
(
lod
)
-
1
):
output_i
=
{
'data'
:
[]}
if
index
<
unhandled_paths_num
:
org_img_path
=
unhandled_paths
[
index
]
org_img
=
Image
.
open
(
org_img_path
)
else
:
org_img
=
images
[
index
-
unhandled_paths_num
]
org_img
=
org_img
.
astype
(
np
.
uint8
)
org_img
=
Image
.
fromarray
(
org_img
[:,
:,
::
-
1
])
if
visualization
:
org_img_path
=
get_save_image_name
(
org_img
,
output_dir
,
'image_numpy_{}'
.
format
(
(
handle_id
+
index
)))
org_img
.
save
(
org_img_path
)
org_img_height
=
org_img
.
height
org_img_width
=
org_img
.
width
result_i
=
results
[
lod
[
index
]:
lod
[
index
+
1
]]
for
row
in
result_i
:
if
len
(
row
)
!=
6
:
continue
if
row
[
1
]
<
score_thresh
:
continue
category_id
=
int
(
row
[
0
])
confidence
=
row
[
1
]
bbox
=
row
[
2
:]
dt
=
{}
dt
[
'label'
]
=
label_names
[
category_id
]
dt
[
'confidence'
]
=
confidence
dt
[
'left'
],
dt
[
'top'
],
dt
[
'right'
],
dt
[
'bottom'
]
=
clip_bbox
(
bbox
,
org_img_width
,
org_img_height
)
output_i
[
'data'
].
append
(
dt
)
output
.
append
(
output_i
)
if
visualization
:
output_i
[
'save_path'
]
=
draw_bounding_box_on_image
(
org_img_path
,
output_i
[
'data'
],
output_dir
)
return
output
hub_module/modules/image/object_detection/yolov3_darknet53_venus/yolo_head.py
0 → 100644
浏览文件 @
9b82f2fb
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
collections
import
OrderedDict
from
paddle
import
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
__all__
=
[
'MultiClassNMS'
,
'YOLOv3Head'
]
class
MultiClassNMS
(
object
):
# __op__ = fluid.layers.multiclass_nms
def
__init__
(
self
,
background_label
,
keep_top_k
,
nms_threshold
,
nms_top_k
,
normalized
,
score_threshold
):
super
(
MultiClassNMS
,
self
).
__init__
()
self
.
background_label
=
background_label
self
.
keep_top_k
=
keep_top_k
self
.
nms_threshold
=
nms_threshold
self
.
nms_top_k
=
nms_top_k
self
.
normalized
=
normalized
self
.
score_threshold
=
score_threshold
class
YOLOv3Head
(
object
):
"""Head block for YOLOv3 network
Args:
norm_decay (float): weight decay for normalization layer weights
num_classes (int): number of output classes
ignore_thresh (float): threshold to ignore confidence loss
label_smooth (bool): whether to use label smoothing
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
"""
def
__init__
(
self
,
norm_decay
=
0.
,
num_classes
=
80
,
ignore_thresh
=
0.7
,
label_smooth
=
True
,
anchors
=
[[
10
,
13
],
[
16
,
30
],
[
33
,
23
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
116
,
90
],
[
156
,
198
],
[
373
,
326
]],
anchor_masks
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]],
nms
=
MultiClassNMS
(
background_label
=-
1
,
keep_top_k
=
100
,
nms_threshold
=
0.45
,
nms_top_k
=
1000
,
normalized
=
True
,
score_threshold
=
0.01
),
weight_prefix_name
=
''
):
self
.
norm_decay
=
norm_decay
self
.
num_classes
=
num_classes
self
.
ignore_thresh
=
ignore_thresh
self
.
label_smooth
=
label_smooth
self
.
anchor_masks
=
anchor_masks
self
.
_parse_anchors
(
anchors
)
self
.
nms
=
nms
self
.
prefix_name
=
weight_prefix_name
def
_conv_bn
(
self
,
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'leaky'
,
is_test
=
True
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
".conv.weights"
),
bias_attr
=
False
)
bn_name
=
name
+
".bn"
bn_param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
self
.
norm_decay
),
name
=
bn_name
+
'.scale'
)
bn_bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
self
.
norm_decay
),
name
=
bn_name
+
'.offset'
)
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
None
,
is_test
=
is_test
,
param_attr
=
bn_param_attr
,
bias_attr
=
bn_bias_attr
,
moving_mean_name
=
bn_name
+
'.mean'
,
moving_variance_name
=
bn_name
+
'.var'
)
if
act
==
'leaky'
:
out
=
fluid
.
layers
.
leaky_relu
(
x
=
out
,
alpha
=
0.1
)
return
out
def
_detection_block
(
self
,
input
,
channel
,
is_test
=
True
,
name
=
None
):
assert
channel
%
2
==
0
,
\
"channel {} cannot be divided by 2 in detection block {}"
\
.
format
(
channel
,
name
)
conv
=
input
for
j
in
range
(
2
):
conv
=
self
.
_conv_bn
(
conv
,
channel
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
is_test
,
name
=
'{}.{}.0'
.
format
(
name
,
j
))
conv
=
self
.
_conv_bn
(
conv
,
channel
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
is_test
=
is_test
,
name
=
'{}.{}.1'
.
format
(
name
,
j
))
route
=
self
.
_conv_bn
(
conv
,
channel
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
is_test
,
name
=
'{}.2'
.
format
(
name
))
tip
=
self
.
_conv_bn
(
route
,
channel
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
is_test
=
is_test
,
name
=
'{}.tip'
.
format
(
name
))
return
route
,
tip
def
_upsample
(
self
,
input
,
scale
=
2
,
name
=
None
):
out
=
fluid
.
layers
.
resize_nearest
(
input
=
input
,
scale
=
float
(
scale
),
name
=
name
)
return
out
def
_parse_anchors
(
self
,
anchors
):
"""
Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors
"""
self
.
anchors
=
[]
self
.
mask_anchors
=
[]
assert
len
(
anchors
)
>
0
,
"ANCHORS not set."
assert
len
(
self
.
anchor_masks
)
>
0
,
"ANCHOR_MASKS not set."
for
anchor
in
anchors
:
assert
len
(
anchor
)
==
2
,
"anchor {} len should be 2"
.
format
(
anchor
)
self
.
anchors
.
extend
(
anchor
)
anchor_num
=
len
(
anchors
)
for
masks
in
self
.
anchor_masks
:
self
.
mask_anchors
.
append
([])
for
mask
in
masks
:
assert
mask
<
anchor_num
,
"anchor mask index overflow"
self
.
mask_anchors
[
-
1
].
extend
(
anchors
[
mask
])
def
_get_outputs
(
self
,
input
,
is_train
=
True
):
"""
Get YOLOv3 head output
Args:
input (list): List of Variables, output of backbone stages
is_train (bool): whether in train or test mode
Returns:
outputs (list): Variables of each output layer
"""
outputs
=
[]
# get last out_layer_num blocks in reverse order
out_layer_num
=
len
(
self
.
anchor_masks
)
if
isinstance
(
input
,
OrderedDict
):
blocks
=
list
(
input
.
values
())[
-
1
:
-
out_layer_num
-
1
:
-
1
]
else
:
blocks
=
input
[
-
1
:
-
out_layer_num
-
1
:
-
1
]
route
=
None
for
i
,
block
in
enumerate
(
blocks
):
if
i
>
0
:
# perform concat in first 2 detection_block
block
=
fluid
.
layers
.
concat
(
input
=
[
route
,
block
],
axis
=
1
)
route
,
tip
=
self
.
_detection_block
(
block
,
channel
=
512
//
(
2
**
i
),
is_test
=
(
not
is_train
),
name
=
self
.
prefix_name
+
"yolo_block.{}"
.
format
(
i
))
# out channel number = mask_num * (5 + class_num)
num_filters
=
len
(
self
.
anchor_masks
[
i
])
*
(
self
.
num_classes
+
5
)
block_out
=
fluid
.
layers
.
conv2d
(
input
=
tip
,
num_filters
=
num_filters
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
self
.
prefix_name
+
"yolo_output.{}.conv.weights"
.
format
(
i
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.
),
name
=
self
.
prefix_name
+
"yolo_output.{}.conv.bias"
.
format
(
i
)))
outputs
.
append
(
block_out
)
if
i
<
len
(
blocks
)
-
1
:
# do not perform upsample in the last detection_block
route
=
self
.
_conv_bn
(
input
=
route
,
ch_out
=
256
//
(
2
**
i
),
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
is_test
=
(
not
is_train
),
name
=
self
.
prefix_name
+
"yolo_transition.{}"
.
format
(
i
))
# upsample
route
=
self
.
_upsample
(
route
)
return
outputs
,
blocks
def
get_prediction
(
self
,
outputs
,
im_size
):
"""
Get prediction result of YOLOv3 network
Args:
outputs (list): list of Variables, return from _get_outputs
im_size (Variable): Variable of size([h, w]) of each image
Returns:
pred (Variable): The prediction result after non-max suppress.
"""
boxes
=
[]
scores
=
[]
downsample
=
32
for
i
,
output
in
enumerate
(
outputs
):
box
,
score
=
fluid
.
layers
.
yolo_box
(
x
=
output
,
img_size
=
im_size
,
anchors
=
self
.
mask_anchors
[
i
],
class_num
=
self
.
num_classes
,
conf_thresh
=
self
.
nms
.
score_threshold
,
downsample_ratio
=
downsample
,
name
=
self
.
prefix_name
+
"yolo_box"
+
str
(
i
))
boxes
.
append
(
box
)
scores
.
append
(
fluid
.
layers
.
transpose
(
score
,
perm
=
[
0
,
2
,
1
]))
downsample
//=
2
yolo_boxes
=
fluid
.
layers
.
concat
(
boxes
,
axis
=
1
)
yolo_scores
=
fluid
.
layers
.
concat
(
scores
,
axis
=
2
)
pred
=
fluid
.
layers
.
multiclass_nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
,
score_threshold
=
self
.
nms
.
score_threshold
,
nms_top_k
=
self
.
nms
.
nms_top_k
,
keep_top_k
=
self
.
nms
.
keep_top_k
,
nms_threshold
=
self
.
nms
.
nms_threshold
,
background_label
=
self
.
nms
.
background_label
,
normalized
=
self
.
nms
.
normalized
,
name
=
"multiclass_nms"
)
return
pred
hub_module/scripts/configs/faster_rcnn_resnet50_fpn_venus.yml
0 → 100644
浏览文件 @
9b82f2fb
name
:
faster_rcnn_resnet50_fpn_venus
dir
:
"
modules/image/object_detection/faster_rcnn_resnet50_fpn_venus"
# resources:
# -
# url: https://paddlehub.bj.bcebos.com/model/cv/faster_rcnn_resnet50_fpn_model.tar.gz
# dest: faster_rcnn_resnet50_fpn_model
# uncompress: True
hub_module/scripts/configs/yolov3_darknet53_venus.yml
0 → 100644
浏览文件 @
9b82f2fb
name
:
yolov3_darknet53_venus
dir
:
"
modules/image/object_detection/yolov3_darknet53_venus"
# resources:
# -
# url: https://paddlehub.bj.bcebos.com/model/cv/yolov3_darknet53_model.tar.gz
# dest: yolov3_darknet53_model
# uncompress: True
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录