提交 9b82f2fb 编写于 作者: W wuzewu

Add detection module - venus

上级 59499f1b
## 命令行预测
```shell
$ hub run faster_rcnn_resnet50_fpn_venus --input_path "/PATH/TO/IMAGE"
```
## API
```python
def context(num_classes=81,
trainable=True,
pretrained=True,
phase='train')
```
提取特征,用于迁移学习。
**参数**
* num\_classes (int): 类别数;
* trainable(bool): 参数是否可训练;
* pretrained (bool): 是否加载预训练模型;
* phase (str): 可选值为 'train'/'predict','trian' 用于训练,'predict' 用于预测。
**返回**
* inputs (dict): 模型的输入,相应的取值为:
当 phase 为 'train'时,包含:
* image (Variable): 图像变量
* im\_size (Variable): 图像的尺寸
* im\_info (Variable): 图像缩放信息
* gt\_class (Variable): 检测框类别
* gt\_box (Variable): 检测框坐标
* is\_crowd (Variable): 单个框内是否包含多个物体
当 phase 为 'predict'时,包含:
* image (Variable): 图像变量
* im\_size (Variable): 图像的尺寸
* im\_info (Variable): 图像缩放信息
* outputs (dict): 模型的输出,相应的取值为:
当 phase 为 'train'时,包含:
* head_features (Variable): 所提取的特征
* rpn\_cls\_loss (Variable): 检测框分类损失
* rpn\_reg\_loss (Variable): 检测框回归损失
* generate\_proposal\_labels (Variable): 图像信息
当 phase 为 'predict'时,包含:
* head_features (Variable): 所提取的特征
* rois (Variable): 提取的roi
* bbox\_out (Variable): 预测结果
* context\_prog (Program): 用于迁移学习的 Program。
```python
def save_inference_model(dirname,
model_filename=None,
params_filename=None,
combined=True)
```
将模型保存到指定路径。
**参数**
* dirname: 存在模型的目录名称
* model\_filename: 模型文件名称,默认为\_\_model\_\_
* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中
### 依赖
paddlepaddle >= 1.6.2
paddlehub >= 1.6.0
class BBoxAssigner(object):
# __op__ = fluid.layers.generate_proposal_labels
def __init__(self,
batch_size_per_im=512,
fg_fraction=.25,
fg_thresh=.5,
bg_thresh_hi=.5,
bg_thresh_lo=0.,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=81,
shuffle_before_sample=True):
super(BBoxAssigner, self).__init__()
self.batch_size_per_im = batch_size_per_im
self.fg_fraction = fg_fraction
self.fg_thresh = fg_thresh
self.bg_thresh_hi = bg_thresh_hi
self.bg_thresh_lo = bg_thresh_lo
self.bbox_reg_weights = bbox_reg_weights
self.class_nums = class_nums
self.use_random = shuffle_before_sample
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import MSRA
class MultiClassNMS(object):
# __op__ = fluid.layers.multiclass_nms
def __init__(self,
score_threshold=.05,
nms_top_k=-1,
keep_top_k=100,
nms_threshold=.5,
normalized=False,
nms_eta=1.0,
background_label=0):
super(MultiClassNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
self.nms_threshold = nms_threshold
self.normalized = normalized
self.nms_eta = nms_eta
self.background_label = background_label
class SmoothL1Loss(object):
'''
Smooth L1 loss
Args:
sigma (float): hyper param in smooth l1 loss
'''
def __init__(self, sigma=1.0):
super(SmoothL1Loss, self).__init__()
self.sigma = sigma
def __call__(self, x, y, inside_weight=None, outside_weight=None):
return fluid.layers.smooth_l1(
x,
y,
inside_weight=inside_weight,
outside_weight=outside_weight,
sigma=self.sigma)
class BoxCoder(object):
def __init__(self,
prior_box_var=[0.1, 0.1, 0.2, 0.2],
code_type='decode_center_size',
box_normalized=False,
axis=1):
super(BoxCoder, self).__init__()
self.prior_box_var = prior_box_var
self.code_type = code_type
self.box_normalized = box_normalized
self.axis = axis
class TwoFCHead(object):
"""
RCNN head with two Fully Connected layers
Args:
mlp_dim (int): num of filters for the fc layers
"""
def __init__(self, mlp_dim=1024):
super(TwoFCHead, self).__init__()
self.mlp_dim = mlp_dim
def __call__(self, roi_feat):
fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3]
fc6 = fluid.layers.fc(
input=roi_feat,
size=self.mlp_dim,
act='relu',
name='fc6',
param_attr=ParamAttr(name='fc6_w', initializer=Xavier(fan_out=fan)),
bias_attr=ParamAttr(
name='fc6_b', learning_rate=2., regularizer=L2Decay(0.)))
head_feat = fluid.layers.fc(
input=fc6,
size=self.mlp_dim,
act='relu',
name='fc7',
param_attr=ParamAttr(name='fc7_w', initializer=Xavier()),
bias_attr=ParamAttr(
name='fc7_b', learning_rate=2., regularizer=L2Decay(0.)))
return head_feat
class BBoxHead(object):
"""
RCNN bbox head
Args:
head (object): the head module instance, e.g., `ResNetC5`, `TwoFCHead`
box_coder (object): `BoxCoder` instance
nms (object): `MultiClassNMS` instance
num_classes: number of output classes
"""
__inject__ = ['head', 'box_coder', 'nms', 'bbox_loss']
__shared__ = ['num_classes']
def __init__(self,
head,
box_coder=BoxCoder(),
nms=MultiClassNMS(),
bbox_loss=SmoothL1Loss(),
num_classes=81):
super(BBoxHead, self).__init__()
self.head = head
self.num_classes = num_classes
self.box_coder = box_coder
self.nms = nms
self.bbox_loss = bbox_loss
self.head_feat = None
def get_head_feat(self, input=None):
"""
Get the bbox head feature map.
"""
if input is not None:
feat = self.head(input)
if isinstance(feat, OrderedDict):
feat = list(feat.values())[0]
self.head_feat = feat
return self.head_feat
def _get_output(self, roi_feat):
"""
Get bbox head output.
Args:
roi_feat (Variable): RoI feature from RoIExtractor.
Returns:
cls_score(Variable): Output of rpn head with shape of
[N, num_anchors, H, W].
bbox_pred(Variable): Output of rpn head with shape of
[N, num_anchors * 4, H, W].
"""
head_feat = self.get_head_feat(roi_feat)
# when ResNetC5 output a single feature map
if not isinstance(self.head, TwoFCHead):
head_feat = fluid.layers.pool2d(
head_feat, pool_type='avg', global_pooling=True)
cls_score = fluid.layers.fc(
input=head_feat,
size=self.num_classes,
act=None,
name='cls_score',
param_attr=ParamAttr(
name='cls_score_w', initializer=Normal(loc=0.0, scale=0.01)),
bias_attr=ParamAttr(
name='cls_score_b', learning_rate=2., regularizer=L2Decay(0.)))
bbox_pred = fluid.layers.fc(
input=head_feat,
size=4 * self.num_classes,
act=None,
name='bbox_pred',
param_attr=ParamAttr(
name='bbox_pred_w', initializer=Normal(loc=0.0, scale=0.001)),
bias_attr=ParamAttr(
name='bbox_pred_b', learning_rate=2., regularizer=L2Decay(0.)))
return cls_score, bbox_pred
def get_loss(self, roi_feat, labels_int32, bbox_targets,
bbox_inside_weights, bbox_outside_weights):
"""
Get bbox_head loss.
Args:
roi_feat (Variable): RoI feature from RoIExtractor.
labels_int32(Variable): Class label of a RoI with shape [P, 1].
P is the number of RoI.
bbox_targets(Variable): Box label of a RoI with shape
[P, 4 * class_nums].
bbox_inside_weights(Variable): Indicates whether a box should
contribute to loss. Same shape as bbox_targets.
bbox_outside_weights(Variable): Indicates whether a box should
contribute to loss. Same shape as bbox_targets.
Return:
Type: Dict
loss_cls(Variable): bbox_head loss.
loss_bbox(Variable): bbox_head loss.
"""
cls_score, bbox_pred = self._get_output(roi_feat)
labels_int64 = fluid.layers.cast(x=labels_int32, dtype='int64')
labels_int64.stop_gradient = True
loss_cls = fluid.layers.softmax_with_cross_entropy(
logits=cls_score, label=labels_int64, numeric_stable_mode=True)
loss_cls = fluid.layers.reduce_mean(loss_cls)
loss_bbox = self.bbox_loss(
x=bbox_pred,
y=bbox_targets,
inside_weight=bbox_inside_weights,
outside_weight=bbox_outside_weights)
loss_bbox = fluid.layers.reduce_mean(loss_bbox)
return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}
def get_prediction(self,
roi_feat,
rois,
im_info,
im_shape,
return_box_score=False):
"""
Get prediction bounding box in test stage.
Args:
roi_feat (Variable): RoI feature from RoIExtractor.
rois (Variable): Output of generate_proposals in rpn head.
im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the
number of input images, each element consists of im_height,
im_width, im_scale.
im_shape (Variable): Actual shape of original image with shape
[B, 3]. B is the number of images, each element consists of
original_height, original_width, 1
Returns:
pred_result(Variable): Prediction result with shape [N, 6]. Each
row has 6 values: [label, confidence, xmin, ymin, xmax, ymax].
N is the total number of prediction.
"""
cls_score, bbox_pred = self._get_output(roi_feat)
im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, rois)
boxes = rois / im_scale
cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
bbox_pred = fluid.layers.reshape(bbox_pred, (-1, self.num_classes, 4))
# self.box_coder
decoded_box = fluid.layers.box_coder(
prior_box=boxes,
target_box=bbox_pred,
prior_box_var=self.box_coder.prior_box_var,
code_type=self.box_coder.code_type,
box_normalized=self.box_coder.box_normalized,
axis=self.box_coder.axis)
cliped_box = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
if return_box_score:
return {'bbox': cliped_box, 'score': cls_prob}
# self.nms
pred_result = fluid.layers.multiclass_nms(
bboxes=cliped_box,
scores=cls_prob,
score_threshold=self.nms.score_threshold,
nms_top_k=self.nms.nms_top_k,
keep_top_k=self.nms.keep_top_k,
nms_threshold=self.nms.nms_threshold,
normalized=self.nms.normalized,
nms_eta=self.nms.nms_eta,
background_label=self.nms.background_label)
return pred_result
# coding=utf-8
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image, ImageEnhance
from paddle import fluid
__all__ = ['test_reader']
def test_reader(paths=None, images=None):
"""
data generator
Args:
paths (list[str]): paths to images.
images (list(numpy.ndarray)): data of images, shape of each is [H, W, C]
Yield:
res (dict): key contains 'image', 'im_info', 'im_shape', the corresponding values is:
image (numpy.ndarray): the image to be fed into network
im_info (numpy.ndarray): the info about the preprocessed.
im_shape (numpy.ndarray): the shape of image.
"""
img_list = list()
if paths:
for img_path in paths:
assert os.path.isfile(
img_path), "The {} isn't a valid file path.".format(img_path)
img = cv2.imread(img_path).astype('float32')
img_list.append(img)
if images is not None:
for img in images:
img_list.append(img)
for im in img_list:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = im.astype(np.float32, copy=False)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
mean = np.array(mean)[np.newaxis, np.newaxis, :]
std = np.array(std)[np.newaxis, np.newaxis, :]
im = im / 255.0
im -= mean
im /= std
target_size = 800
max_size = 1333
shape = im.shape
# im_shape holds the original shape of image.
im_shape = np.array([shape[0], shape[1], 1.0]).astype('float32')
im_size_min = np.min(shape[0:2])
im_size_max = np.max(shape[0:2])
im_scale = float(target_size) / float(im_size_min)
if np.round(im_scale * im_size_max) > max_size:
im_scale = float(max_size) / float(im_size_max)
resize_w = np.round(im_scale * float(shape[1]))
resize_h = np.round(im_scale * float(shape[0]))
# im_info holds the resize info of image.
im_info = np.array([resize_h, resize_w, im_scale]).astype('float32')
im = cv2.resize(
im,
None,
None,
fx=im_scale,
fy=im_scale,
interpolation=cv2.INTER_LINEAR)
# HWC --> CHW
im = np.swapaxes(im, 1, 2)
im = np.swapaxes(im, 1, 0)
yield {'image': im, 'im_info': im_info, 'im_shape': im_shape}
def padding_minibatch(batch_data, coarsest_stride=0, use_padded_im_info=True):
max_shape_org = np.array(
[data['image'].shape for data in batch_data]).max(axis=0)
if coarsest_stride > 0:
max_shape = np.zeros((3)).astype('int32')
max_shape[1] = int(
np.ceil(max_shape_org[1] / coarsest_stride) * coarsest_stride)
max_shape[2] = int(
np.ceil(max_shape_org[2] / coarsest_stride) * coarsest_stride)
else:
max_shape = max_shape_org.astype('int32')
padding_image = list()
padding_info = list()
padding_shape = list()
for data in batch_data:
im_c, im_h, im_w = data['image'].shape
# image
padding_im = np.zeros((im_c, max_shape[1], max_shape[2]),
dtype=np.float32)
padding_im[:, 0:im_h, 0:im_w] = data['image']
padding_image.append(padding_im)
# im_info
data['im_info'][
0] = max_shape[1] if use_padded_im_info else max_shape_org[1]
data['im_info'][
1] = max_shape[2] if use_padded_im_info else max_shape_org[2]
padding_info.append(data['im_info'])
padding_shape.append(data['im_shape'])
padding_image = np.array(padding_image).astype('float32')
padding_info = np.array(padding_info).astype('float32')
padding_shape = np.array(padding_shape).astype('float32')
return padding_image, padding_info, padding_shape
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Xavier
from paddle.fluid.regularizer import L2Decay
__all__ = ['ConvNorm', 'FPN']
def ConvNorm(input,
num_filters,
filter_size,
stride=1,
groups=1,
norm_decay=0.,
norm_type='affine_channel',
norm_groups=32,
dilation=1,
lr_scale=1,
freeze_norm=False,
act=None,
norm_name=None,
initializer=None,
name=None):
fan = num_filters
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=((filter_size - 1) // 2) * dilation,
dilation=dilation,
groups=groups,
act=None,
param_attr=ParamAttr(
name=name + "_weights",
initializer=initializer,
learning_rate=lr_scale),
bias_attr=False,
name=name + '.conv2d.output.1')
norm_lr = 0. if freeze_norm else 1.
pattr = ParamAttr(
name=norm_name + '_scale',
learning_rate=norm_lr * lr_scale,
regularizer=L2Decay(norm_decay))
battr = ParamAttr(
name=norm_name + '_offset',
learning_rate=norm_lr * lr_scale,
regularizer=L2Decay(norm_decay))
if norm_type in ['bn', 'sync_bn']:
global_stats = True if freeze_norm else False
out = fluid.layers.batch_norm(
input=conv,
act=act,
name=norm_name + '.output.1',
param_attr=pattr,
bias_attr=battr,
moving_mean_name=norm_name + '_mean',
moving_variance_name=norm_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif norm_type == 'gn':
out = fluid.layers.group_norm(
input=conv,
act=act,
name=norm_name + '.output.1',
groups=norm_groups,
param_attr=pattr,
bias_attr=battr)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif norm_type == 'affine_channel':
scale = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=pattr,
default_initializer=fluid.initializer.Constant(1.))
bias = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=battr,
default_initializer=fluid.initializer.Constant(0.))
out = fluid.layers.affine_channel(
x=conv, scale=scale, bias=bias, act=act)
if freeze_norm:
scale.stop_gradient = True
bias.stop_gradient = True
return out
class FPN(object):
"""
Feature Pyramid Network, see https://arxiv.org/abs/1612.03144
Args:
num_chan (int): number of feature channels
min_level (int): lowest level of the backbone feature map to use
max_level (int): highest level of the backbone feature map to use
spatial_scale (list): feature map scaling factor
has_extra_convs (bool): whether has extral convolutions in higher levels
norm_type (str|None): normalization type, 'bn'/'sync_bn'/'affine_channel'
"""
__shared__ = ['norm_type', 'freeze_norm']
def __init__(self,
num_chan=256,
min_level=2,
max_level=6,
spatial_scale=[1. / 32., 1. / 16., 1. / 8., 1. / 4.],
has_extra_convs=False,
norm_type=None,
freeze_norm=False):
self.freeze_norm = freeze_norm
self.num_chan = num_chan
self.min_level = min_level
self.max_level = max_level
self.spatial_scale = spatial_scale
self.has_extra_convs = has_extra_convs
self.norm_type = norm_type
def _add_topdown_lateral(self, body_name, body_input, upper_output):
lateral_name = 'fpn_inner_' + body_name + '_lateral'
topdown_name = 'fpn_topdown_' + body_name
fan = body_input.shape[1]
if self.norm_type:
initializer = Xavier(fan_out=fan)
lateral = ConvNorm(
body_input,
self.num_chan,
1,
initializer=initializer,
norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=lateral_name,
norm_name=lateral_name)
else:
lateral = fluid.layers.conv2d(
body_input,
self.num_chan,
1,
param_attr=ParamAttr(
name=lateral_name + "_w", initializer=Xavier(fan_out=fan)),
bias_attr=ParamAttr(
name=lateral_name + "_b",
learning_rate=2.,
regularizer=L2Decay(0.)),
name=lateral_name)
topdown = fluid.layers.resize_nearest(
upper_output, scale=2., name=topdown_name)
return lateral + topdown
def get_output(self, body_dict):
"""
Add FPN onto backbone.
Args:
body_dict(OrderedDict): Dictionary of variables and each element is the
output of backbone.
Return:
fpn_dict(OrderedDict): A dictionary represents the output of FPN with
their name.
spatial_scale(list): A list of multiplicative spatial scale factor.
"""
spatial_scale = copy.deepcopy(self.spatial_scale)
body_name_list = list(body_dict.keys())[::-1]
num_backbone_stages = len(body_name_list)
self.fpn_inner_output = [[] for _ in range(num_backbone_stages)]
fpn_inner_name = 'fpn_inner_' + body_name_list[0]
body_input = body_dict[body_name_list[0]]
fan = body_input.shape[1]
if self.norm_type:
initializer = Xavier(fan_out=fan)
self.fpn_inner_output[0] = ConvNorm(
body_input,
self.num_chan,
1,
initializer=initializer,
norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=fpn_inner_name,
norm_name=fpn_inner_name)
else:
self.fpn_inner_output[0] = fluid.layers.conv2d(
body_input,
self.num_chan,
1,
param_attr=ParamAttr(
name=fpn_inner_name + "_w",
initializer=Xavier(fan_out=fan)),
bias_attr=ParamAttr(
name=fpn_inner_name + "_b",
learning_rate=2.,
regularizer=L2Decay(0.)),
name=fpn_inner_name)
for i in range(1, num_backbone_stages):
body_name = body_name_list[i]
body_input = body_dict[body_name]
top_output = self.fpn_inner_output[i - 1]
fpn_inner_single = self._add_topdown_lateral(
body_name, body_input, top_output)
self.fpn_inner_output[i] = fpn_inner_single
fpn_dict = {}
fpn_name_list = []
for i in range(num_backbone_stages):
fpn_name = 'fpn_' + body_name_list[i]
fan = self.fpn_inner_output[i].shape[1] * 3 * 3
if self.norm_type:
initializer = Xavier(fan_out=fan)
fpn_output = ConvNorm(
self.fpn_inner_output[i],
self.num_chan,
3,
initializer=initializer,
norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=fpn_name,
norm_name=fpn_name)
else:
fpn_output = fluid.layers.conv2d(
self.fpn_inner_output[i],
self.num_chan,
filter_size=3,
padding=1,
param_attr=ParamAttr(
name=fpn_name + "_w", initializer=Xavier(fan_out=fan)),
bias_attr=ParamAttr(
name=fpn_name + "_b",
learning_rate=2.,
regularizer=L2Decay(0.)),
name=fpn_name)
fpn_dict[fpn_name] = fpn_output
fpn_name_list.append(fpn_name)
if not self.has_extra_convs and self.max_level - self.min_level == len(
spatial_scale):
body_top_name = fpn_name_list[0]
body_top_extension = fluid.layers.pool2d(
fpn_dict[body_top_name],
1,
'max',
pool_stride=2,
name=body_top_name + '_subsampled_2x')
fpn_dict[body_top_name + '_subsampled_2x'] = body_top_extension
fpn_name_list.insert(0, body_top_name + '_subsampled_2x')
spatial_scale.insert(0, spatial_scale[0] * 0.5)
# Coarser FPN levels introduced for RetinaNet
highest_backbone_level = self.min_level + len(spatial_scale) - 1
if self.has_extra_convs and self.max_level > highest_backbone_level:
fpn_blob = body_dict[body_name_list[0]]
for i in range(highest_backbone_level + 1, self.max_level + 1):
fpn_blob_in = fpn_blob
fpn_name = 'fpn_' + str(i)
if i > highest_backbone_level + 1:
fpn_blob_in = fluid.layers.relu(fpn_blob)
fan = fpn_blob_in.shape[1] * 3 * 3
fpn_blob = fluid.layers.conv2d(
input=fpn_blob_in,
num_filters=self.num_chan,
filter_size=3,
stride=2,
padding=1,
param_attr=ParamAttr(
name=fpn_name + "_w", initializer=Xavier(fan_out=fan)),
bias_attr=ParamAttr(
name=fpn_name + "_b",
learning_rate=2.,
regularizer=L2Decay(0.)),
name=fpn_name)
fpn_dict[fpn_name] = fpn_blob
fpn_name_list.insert(0, fpn_name)
spatial_scale.insert(0, spatial_scale[0] * 0.5)
res_dict = OrderedDict([(k, fpn_dict[k]) for k in fpn_name_list])
return res_dict, spatial_scale
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import ast
import argparse
from collections import OrderedDict
from functools import partial
from math import ceil
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable, serving
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.io.parser import txt_parser
from paddlehub.common.paddle_helper import add_vars_prefix
from faster_rcnn_resnet50_fpn_venus.processor import load_label_info, postprocess, base64_to_cv2
from faster_rcnn_resnet50_fpn_venus.data_feed import test_reader, padding_minibatch
from faster_rcnn_resnet50_fpn_venus.fpn import FPN
from faster_rcnn_resnet50_fpn_venus.resnet import ResNet
from faster_rcnn_resnet50_fpn_venus.rpn_head import AnchorGenerator, RPNTargetAssign, GenerateProposals, FPNRPNHead
from faster_rcnn_resnet50_fpn_venus.bbox_head import MultiClassNMS, BBoxHead, TwoFCHead
from faster_rcnn_resnet50_fpn_venus.bbox_assigner import BBoxAssigner
from faster_rcnn_resnet50_fpn_venus.roi_extractor import FPNRoIAlign
@moduleinfo(
name="faster_rcnn_resnet50_fpn_venus",
version="1.0.0",
type="cv/object_detection",
summary=
"Baidu's Faster-RCNN model for object detection, whose backbone is ResNet50, processed with Feature Pyramid Networks",
author="paddlepaddle",
author_email="paddle-dev@baidu.com")
class FasterRCNNResNet50RPN(hub.Module):
def _initialize(self):
# default pretrained model, Faster-RCNN with backbone ResNet50, shape of input tensor is [3, 800, 1333]
self.default_pretrained_model_path = os.path.join(
self.directory, "faster_rcnn_resnet50_fpn_model")
def context(self,
num_classes=708,
trainable=True,
pretrained=True,
phase='train'):
"""
Distill the Head Features, so as to perform transfer learning.
Args:
trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model.
get_prediction (bool): whether to get prediction.
phase (str): optional choices are 'train' and 'predict'.
Returns:
inputs (dict): the input variables.
outputs (dict): the output variables.
context_prog (Program): the program to execute transfer learning.
"""
context_prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(
name='image', shape=[-1, 3, -1, -1], dtype='float32')
# backbone
backbone = ResNet(
norm_type='affine_channel',
depth=50,
feature_maps=[2, 3, 4, 5],
freeze_at=2)
body_feats = backbone(image)
# fpn
fpn = FPN(
max_level=6,
min_level=2,
num_chan=256,
spatial_scale=[0.03125, 0.0625, 0.125, 0.25])
var_prefix = '@HUB_{}@'.format(self.name)
im_info = fluid.layers.data(
name='im_info', shape=[3], dtype='float32', lod_level=0)
im_shape = fluid.layers.data(
name='im_shape', shape=[3], dtype='float32', lod_level=0)
body_feat_names = list(body_feats.keys())
body_feats, spatial_scale = fpn.get_output(body_feats)
# rpn_head: RPNHead
rpn_head = self.rpn_head()
rois = rpn_head.get_proposals(body_feats, im_info, mode=phase)
# train
if phase == 'train':
gt_bbox = fluid.layers.data(
name='gt_bbox', shape=[4], dtype='float32', lod_level=1)
is_crowd = fluid.layers.data(
name='is_crowd', shape=[1], dtype='int32', lod_level=1)
gt_class = fluid.layers.data(
name='gt_class', shape=[1], dtype='int32', lod_level=1)
rpn_loss = rpn_head.get_loss(im_info, gt_bbox, is_crowd)
# bbox_assigner: BBoxAssigner
bbox_assigner = self.bbox_assigner(num_classes)
outs = fluid.layers.generate_proposal_labels(
rpn_rois=rois,
gt_classes=gt_class,
is_crowd=is_crowd,
gt_boxes=gt_bbox,
im_info=im_info,
batch_size_per_im=bbox_assigner.batch_size_per_im,
fg_fraction=bbox_assigner.fg_fraction,
fg_thresh=bbox_assigner.fg_thresh,
bg_thresh_hi=bbox_assigner.bg_thresh_hi,
bg_thresh_lo=bbox_assigner.bg_thresh_lo,
bbox_reg_weights=bbox_assigner.bbox_reg_weights,
class_nums=bbox_assigner.class_nums,
use_random=bbox_assigner.use_random)
rois = outs[0]
roi_extractor = self.roi_extractor()
roi_feat = roi_extractor(
head_inputs=body_feats,
rois=rois,
spatial_scale=spatial_scale)
# head_feat
bbox_head = self.bbox_head(num_classes)
head_feat = bbox_head.head(roi_feat)
if isinstance(head_feat, OrderedDict):
head_feat = list(head_feat.values())[0]
if phase == 'train':
inputs = {
'image': var_prefix + image.name,
'im_info': var_prefix + im_info.name,
'im_shape': var_prefix + im_shape.name,
'gt_class': var_prefix + gt_class.name,
'gt_bbox': var_prefix + gt_bbox.name,
'is_crowd': var_prefix + is_crowd.name
}
outputs = {
'head_features':
var_prefix + head_feat.name,
'rpn_cls_loss':
var_prefix + rpn_loss['rpn_cls_loss'].name,
'rpn_reg_loss':
var_prefix + rpn_loss['rpn_reg_loss'].name,
'generate_proposal_labels':
[var_prefix + var.name for var in outs]
}
elif phase == 'predict':
pred = bbox_head.get_prediction(roi_feat, rois, im_info,
im_shape)
inputs = {
'image': var_prefix + image.name,
'im_info': var_prefix + im_info.name,
'im_shape': var_prefix + im_shape.name
}
outputs = {
'head_features': var_prefix + head_feat.name,
'rois': var_prefix + rois.name,
'bbox_out': var_prefix + pred.name
}
add_vars_prefix(context_prog, var_prefix)
add_vars_prefix(startup_program, var_prefix)
global_vars = context_prog.global_block().vars
inputs = {
key: global_vars[value]
for key, value in inputs.items()
}
outputs = {
key: global_vars[value] if not isinstance(value, list) else
[global_vars[var] for var in value]
for key, value in outputs.items()
}
for param in context_prog.global_block().iter_parameters():
param.trainable = trainable
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
if pretrained:
def _if_exist(var):
if num_classes != 81:
if 'bbox_pred' in var.name or 'cls_score' in var.name:
return False
return os.path.exists(
os.path.join(self.default_pretrained_model_path,
var.name))
fluid.io.load_vars(
exe,
self.default_pretrained_model_path,
predicate=_if_exist)
return inputs, outputs, context_prog
def rpn_head(self):
return FPNRPNHead(
anchor_generator=AnchorGenerator(
anchor_sizes=[32, 64, 128, 256, 512],
aspect_ratios=[0.5, 1.0, 2.0],
stride=[16.0, 16.0],
variance=[1.0, 1.0, 1.0, 1.0]),
rpn_target_assign=RPNTargetAssign(
rpn_batch_size_per_im=256,
rpn_fg_fraction=0.5,
rpn_negative_overlap=0.3,
rpn_positive_overlap=0.7,
rpn_straddle_thresh=0.0),
train_proposal=GenerateProposals(
min_size=0.0,
nms_thresh=0.7,
post_nms_top_n=2000,
pre_nms_top_n=2000),
test_proposal=GenerateProposals(
min_size=0.0,
nms_thresh=0.7,
post_nms_top_n=1000,
pre_nms_top_n=1000),
anchor_start_size=32,
num_chan=256,
min_level=2,
max_level=6)
def roi_extractor(self):
return FPNRoIAlign(
canconical_level=4,
canonical_size=224,
max_level=5,
min_level=2,
box_resolution=7,
sampling_ratio=2)
def bbox_head(self, num_classes):
return BBoxHead(
head=TwoFCHead(mlp_dim=1024),
nms=MultiClassNMS(
keep_top_k=100, nms_threshold=0.5, score_threshold=0.05),
num_classes=num_classes)
def bbox_assigner(self, num_classes):
return BBoxAssigner(
batch_size_per_im=512,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
fg_fraction=0.25,
fg_thresh=0.5,
class_nums=num_classes)
# coding=utf-8
class NameAdapter(object):
"""Fix the backbones variable names for pretrained weight"""
def __init__(self, model):
super(NameAdapter, self).__init__()
self.model = model
@property
def model_type(self):
return getattr(self.model, '_model_type', '')
@property
def variant(self):
return getattr(self.model, 'variant', '')
def fix_conv_norm_name(self, name):
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
# the naming rule is same as pretrained weight
if self.model_type == 'SEResNeXt':
bn_name = name + "_bn"
return bn_name
def fix_shortcut_name(self, name):
if self.model_type == 'SEResNeXt':
name = 'conv' + name + '_prj'
return name
def fix_bottleneck_name(self, name):
if self.model_type == 'SEResNeXt':
conv_name1 = 'conv' + name + '_x1'
conv_name2 = 'conv' + name + '_x2'
conv_name3 = 'conv' + name + '_x3'
shortcut_name = name
else:
conv_name1 = name + "_branch2a"
conv_name2 = name + "_branch2b"
conv_name3 = name + "_branch2c"
shortcut_name = name + "_branch1"
return conv_name1, conv_name2, conv_name3, shortcut_name
def fix_layer_warp_name(self, stage_num, count, i):
name = 'res' + str(stage_num)
if count > 10 and stage_num == 4:
if i == 0:
conv_name = name + "a"
else:
conv_name = name + "b" + str(i)
else:
conv_name = name + chr(ord("a") + i)
if self.model_type == 'SEResNeXt':
conv_name = str(stage_num + 2) + '_' + str(i + 1)
return conv_name
def fix_c1_stage_name(self):
return "res_conv1" if self.model_type == 'ResNeXt' else "conv1"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
nonlocal_params = {
"use_zero_init_conv": False,
"conv_init_std": 0.01,
"no_bias": True,
"use_maxpool": False,
"use_softmax": True,
"use_bn": False,
"use_scale": True, # vital for the model prformance!!!
"use_affine": False,
"bn_momentum": 0.9,
"bn_epsilon": 1.0000001e-5,
"bn_init_gamma": 0.9,
"weight_decay_bn": 1.e-4,
}
def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner,
max_pool_stride=2):
cur = input
theta = fluid.layers.conv2d(input = cur, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr=ParamAttr(name = prefix + '_theta' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_theta' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if not nonlocal_params["no_bias"] else False, \
name = prefix + '_theta')
theta_shape = theta.shape
theta_shape_op = fluid.layers.shape(theta)
theta_shape_op.stop_gradient = True
if nonlocal_params["use_maxpool"]:
max_pool = fluid.layers.pool2d(input = cur, \
pool_size = [max_pool_stride, max_pool_stride], \
pool_type = 'max', \
pool_stride = [max_pool_stride, max_pool_stride], \
pool_padding = [0, 0], \
name = prefix + '_pool')
else:
max_pool = cur
phi = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_phi' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_phi' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_phi')
phi_shape = phi.shape
g = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_g' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0, scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_g' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_g')
g_shape = g.shape
# we have to use explicit batch size (to support arbitrary spacetime size)
# e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784)
theta = fluid.layers.reshape(theta, shape=(0, 0, -1))
theta = fluid.layers.transpose(theta, [0, 2, 1])
phi = fluid.layers.reshape(phi, [0, 0, -1])
theta_phi = fluid.layers.matmul(theta, phi, name=prefix + '_affinity')
g = fluid.layers.reshape(g, [0, 0, -1])
if nonlocal_params["use_softmax"]:
if nonlocal_params["use_scale"]:
theta_phi_sc = fluid.layers.scale(theta_phi, scale=dim_inner**-.5)
else:
theta_phi_sc = theta_phi
p = fluid.layers.softmax(
theta_phi_sc, name=prefix + '_affinity' + '_prob')
else:
# not clear about what is doing in xlw's code
p = None # not implemented
raise "Not implemented when not use softmax"
# note g's axis[2] corresponds to p's axis[2]
# e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
p = fluid.layers.transpose(p, [0, 2, 1])
t = fluid.layers.matmul(g, p, name=prefix + '_y')
# reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
t_shape = t.shape
t_re = fluid.layers.reshape(
t, shape=list(theta_shape), actual_shape=theta_shape_op)
blob_out = t_re
blob_out = fluid.layers.conv2d(input = blob_out, num_filters = dim_out, \
filter_size = [1, 1], stride = [1, 1], padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_out' + "_w", \
initializer = fluid.initializer.Constant(value = 0.) \
if nonlocal_params["use_zero_init_conv"] \
else fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_out' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_out')
blob_out_shape = blob_out.shape
if nonlocal_params["use_bn"]:
bn_name = prefix + "_bn"
blob_out = fluid.layers.batch_norm(blob_out, \
# is_test = test_mode, \
momentum = nonlocal_params["bn_momentum"], \
epsilon = nonlocal_params["bn_epsilon"], \
name = bn_name, \
param_attr = ParamAttr(name = bn_name + "_s", \
initializer = fluid.initializer.Constant(value = nonlocal_params["bn_init_gamma"]), \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
bias_attr = ParamAttr(name = bn_name + "_b", \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
moving_mean_name = bn_name + "_rm", \
moving_variance_name = bn_name + "_riv") # add bn
if nonlocal_params["use_affine"]:
affine_scale = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_s'), \
default_initializer = fluid.initializer.Constant(value = 1.))
affine_bias = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_b'), \
default_initializer = fluid.initializer.Constant(value = 0.))
blob_out = fluid.layers.affine_channel(blob_out, scale = affine_scale, \
bias = affine_bias, name = prefix + '_affine') # add affine
return blob_out
def add_space_nonlocal(input, dim_in, dim_out, prefix, dim_inner):
'''
add_space_nonlocal:
Non-local Neural Networks: see https://arxiv.org/abs/1711.07971
'''
conv = space_nonlocal(input, dim_in, dim_out, prefix, dim_inner)
output = fluid.layers.elementwise_add(input, conv, name=prefix + '_sum')
return output
# coding=utf-8
import base64
import os
import cv2
import numpy as np
from PIL import Image, ImageDraw
__all__ = [
'base64_to_cv2',
'load_label_info',
'postprocess',
]
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
def get_save_image_name(img, output_dir, image_path):
"""Get save image name from source image path.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
image_name = os.path.split(image_path)[-1]
name, ext = os.path.splitext(image_name)
if ext == '':
if img.format == 'PNG':
ext = '.png'
elif img.format == 'JPEG':
ext = '.jpg'
elif img.format == 'BMP':
ext = '.bmp'
else:
if img.mode == "RGB" or img.mode == "L":
ext = ".jpg"
elif img.mode == "RGBA" or img.mode == "P":
ext = '.png'
return os.path.join(output_dir, "{}".format(name)) + ext
def draw_bounding_box_on_image(image_path, data_list, save_dir):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
for data in data_list:
left, right, top, bottom = data['left'], data['right'], data[
'top'], data['bottom']
# draw bbox
draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=2,
fill='red')
# draw label
if image.mode == 'RGB':
text = data['label'] + ": %.2f%%" % (100 * data['confidence'])
textsize_width, textsize_height = draw.textsize(text=text)
draw.rectangle(
xy=(left, top - (textsize_height + 5),
left + textsize_width + 10, top),
fill=(255, 255, 255))
draw.text(xy=(left, top - 15), text=text, fill=(0, 0, 0))
save_name = get_save_image_name(image, save_dir, image_path)
if os.path.exists(save_name):
os.remove(save_name)
image.save(save_name)
return save_name
def clip_bbox(bbox, img_width, img_height):
xmin = max(min(bbox[0], img_width), 0.)
ymin = max(min(bbox[1], img_height), 0.)
xmax = max(min(bbox[2], img_width), 0.)
ymax = max(min(bbox[3], img_height), 0.)
return xmin, ymin, xmax, ymax
def load_label_info(file_path):
with open(file_path, 'r') as fr:
text = fr.readlines()
label_names = []
for info in text:
label_names.append(info.strip())
return label_names
def postprocess(paths,
images,
data_out,
score_thresh,
label_names,
output_dir,
handle_id,
visualization=True):
"""
postprocess the lod_tensor produced by fluid.Executor.run
Args:
paths (list[str]): the path of images.
images (list(numpy.ndarray)): list of images, shape of each is [H, W, C].
data_out (lod_tensor): data produced by executor.run.
score_thresh (float): the low limit of bounding box.
label_names (list[str]): label names.
output_dir (str): output directory.
handle_id (int): The number of images that have been handled.
visualization (bool): whether to save as images.
Returns:
res (list[dict]): The result of vehicles detecion. keys include 'data', 'save_path', the corresponding value is:
data (dict): the result of object detection, keys include 'left', 'top', 'right', 'bottom', 'label', 'confidence', the corresponding value is:
left (float): The X coordinate of the upper left corner of the bounding box;
top (float): The Y coordinate of the upper left corner of the bounding box;
right (float): The X coordinate of the lower right corner of the bounding box;
bottom (float): The Y coordinate of the lower right corner of the bounding box;
label (str): The label of detection result;
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor = data_out[0]
lod = lod_tensor.lod[0]
results = lod_tensor.as_ndarray()
if handle_id < len(paths):
unhandled_paths = paths[handle_id:]
unhandled_paths_num = len(unhandled_paths)
else:
unhandled_paths_num = 0
output = []
for index in range(len(lod) - 1):
output_i = {'data': []}
if index < unhandled_paths_num:
org_img_path = unhandled_paths[index]
org_img = Image.open(org_img_path)
output_i['path'] = org_img_path
else:
org_img = images[index - unhandled_paths_num]
org_img = org_img.astype(np.uint8)
org_img = Image.fromarray(org_img[:, :, ::-1])
if visualization:
org_img_path = get_save_image_name(
org_img, output_dir, 'image_numpy_{}'.format(
(handle_id + index)))
org_img.save(org_img_path)
org_img_height = org_img.height
org_img_width = org_img.width
result_i = results[lod[index]:lod[index + 1]]
for row in result_i:
if len(row) != 6:
continue
if row[1] < score_thresh:
continue
category_id = int(row[0])
confidence = row[1]
bbox = row[2:]
dt = {}
dt['label'] = label_names[category_id]
dt['confidence'] = confidence
dt['left'], dt['top'], dt['right'], dt['bottom'] = clip_bbox(
bbox, org_img_width, org_img_height)
output_i['data'].append(dt)
output.append(output_i)
if visualization:
output_i['save_path'] = draw_bounding_box_on_image(
org_img_path, output_i['data'], output_dir)
return output
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from collections import OrderedDict
from numbers import Integral
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant
from .nonlocal_helper import add_space_nonlocal
from .name_adapter import NameAdapter
__all__ = ['ResNet', 'ResNetC5']
class ResNet(object):
"""
Residual Network, see https://arxiv.org/abs/1512.03385
Args:
depth (int): ResNet depth, should be 34, 50.
freeze_at (int): freeze the backbone at which stage
norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel'
freeze_norm (bool): freeze normalization layers
norm_decay (float): weight decay for normalization layer weights
variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2
nonlocal_stages (list): index of stages who select nonlocal networks
"""
__shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name']
def __init__(self,
depth=50,
freeze_at=0,
norm_type='sync_bn',
freeze_norm=False,
norm_decay=0.,
variant='b',
feature_maps=[3, 4, 5],
dcn_v2_stages=[],
weight_prefix_name='',
nonlocal_stages=[],
get_prediction=False,
class_dim=1000):
super(ResNet, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
assert depth in [34, 50], \
"depth {} not in [34, 50]"
assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant"
assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
assert len(feature_maps) > 0, "need one or more feature maps"
assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert not (len(nonlocal_stages)>0 and depth<50), \
"non-local is not supported for resnet18 or resnet34"
self.depth = depth
self.freeze_at = freeze_at
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.variant = variant
self._model_type = 'ResNet'
self.feature_maps = feature_maps
self.dcn_v2_stages = dcn_v2_stages
self.depth_cfg = {
34: ([3, 4, 6, 3], self.basicblock),
50: ([3, 4, 6, 3], self.bottleneck),
}
self.stage_filters = [64, 128, 256, 512]
self._c1_out_chan_num = 64
self.na = NameAdapter(self)
self.prefix_name = weight_prefix_name
self.nonlocal_stages = nonlocal_stages
self.nonlocal_mod_cfg = {
50: 2,
101: 5,
152: 8,
200: 12,
}
self.get_prediction = get_prediction
self.class_dim = class_dim
def _conv_offset(self,
input,
filter_size,
stride,
padding,
act=None,
name=None):
out_channel = filter_size * filter_size * 3
out = fluid.layers.conv2d(
input,
num_filters=out_channel,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=ParamAttr(initializer=Constant(0.0), name=name + ".w_0"),
bias_attr=ParamAttr(initializer=Constant(0.0), name=name + ".b_0"),
act=act,
name=name)
return out
def _conv_norm(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None,
dcn_v2=False):
_name = self.prefix_name + name if self.prefix_name != '' else name
if not dcn_v2:
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=_name + "_weights"),
bias_attr=False,
name=_name + '.conv2d.output.1')
else:
# select deformable conv"
offset_mask = self._conv_offset(
input=input,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
act=None,
name=_name + "_conv_offset")
offset_channel = filter_size**2 * 2
mask_channel = filter_size**2
offset, mask = fluid.layers.split(
input=offset_mask,
num_or_sections=[offset_channel, mask_channel],
dim=1)
mask = fluid.layers.sigmoid(mask)
conv = fluid.layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
deformable_groups=1,
im2col_step=1,
param_attr=ParamAttr(name=_name + "_weights"),
bias_attr=False,
name=_name + ".conv2d.output.1")
bn_name = self.na.fix_conv_norm_name(name)
bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name
norm_lr = 0. if self.freeze_norm else 1.
norm_decay = self.norm_decay
pattr = ParamAttr(
name=bn_name + '_scale',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
battr = ParamAttr(
name=bn_name + '_offset',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
if self.norm_type in ['bn', 'sync_bn']:
global_stats = True if self.freeze_norm else False
out = fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=pattr,
bias_attr=battr,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif self.norm_type == 'affine_channel':
scale = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=pattr,
default_initializer=fluid.initializer.Constant(1.))
bias = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=battr,
default_initializer=fluid.initializer.Constant(0.))
out = fluid.layers.affine_channel(
x=conv, scale=scale, bias=bias, act=act)
if self.freeze_norm:
scale.stop_gradient = True
bias.stop_gradient = True
return out
def _shortcut(self, input, ch_out, stride, is_first, name):
max_pooling_in_short_cut = self.variant == 'd'
ch_in = input.shape[1]
# the naming rule is same as pretrained weight
name = self.na.fix_shortcut_name(name)
std_senet = getattr(self, 'std_senet', False)
if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first):
if std_senet:
if is_first:
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return self._conv_norm(input, ch_out, 3, stride, name=name)
if max_pooling_in_short_cut and not is_first:
input = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
ceil_mode=True,
pool_type='avg')
return self._conv_norm(input, ch_out, 1, 1, name=name)
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck(self,
input,
num_filters,
stride,
is_first,
name,
dcn_v2=False):
if self.variant == 'a':
stride1, stride2 = stride, 1
else:
stride1, stride2 = 1, stride
# ResNeXt
groups = getattr(self, 'groups', 1)
group_width = getattr(self, 'group_width', -1)
if groups == 1:
expand = 4
elif (groups * group_width) == 256:
expand = 1
else: # FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
num_filters = num_filters // 2
expand = 2
conv_name1, conv_name2, conv_name3, \
shortcut_name = self.na.fix_bottleneck_name(name)
std_senet = getattr(self, 'std_senet', False)
if std_senet:
conv_def = [[
int(num_filters / 2), 1, stride1, 'relu', 1, conv_name1
], [num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
else:
conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1],
[num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
residual = input
for i, (c, k, s, act, g, _name) in enumerate(conv_def):
residual = self._conv_norm(
input=residual,
num_filters=c,
filter_size=k,
stride=s,
act=act,
groups=g,
name=_name,
dcn_v2=(i == 1 and dcn_v2))
short = self._shortcut(
input,
num_filters * expand,
stride,
is_first=is_first,
name=shortcut_name)
# Squeeze-and-Excitation
if callable(getattr(self, '_squeeze_excitation', None)):
residual = self._squeeze_excitation(
input=residual, num_channels=num_filters, name='fc' + name)
return fluid.layers.elementwise_add(
x=short, y=residual, act='relu', name=name + ".add.output.5")
def basicblock(self,
input,
num_filters,
stride,
is_first,
name,
dcn_v2=False):
assert dcn_v2 is False, "Not implemented yet."
conv0 = self._conv_norm(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
conv1 = self._conv_norm(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
short = self._shortcut(
input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def layer_warp(self, input, stage_num):
"""
Args:
input (Variable): input variable.
stage_num (int): the stage number, should be 2, 3, 4, 5
Returns:
The last variable in endpoint-th stage.
"""
assert stage_num in [2, 3, 4, 5]
stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2]
ch_out = self.stage_filters[stage_num - 2]
is_first = False if stage_num != 2 else True
dcn_v2 = True if stage_num in self.dcn_v2_stages else False
nonlocal_mod = 1000
if stage_num in self.nonlocal_stages:
nonlocal_mod = self.nonlocal_mod_cfg[
self.depth] if stage_num == 4 else 2
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv = input
for i in range(count):
conv_name = self.na.fix_layer_warp_name(stage_num, count, i)
if self.depth < 50:
is_first = True if i == 0 and stage_num == 2 else False
conv = block_func(
input=conv,
num_filters=ch_out,
stride=2 if i == 0 and stage_num != 2 else 1,
is_first=is_first,
name=conv_name,
dcn_v2=dcn_v2)
# add non local model
dim_in = conv.shape[1]
nonlocal_name = "nonlocal_conv{}".format(stage_num)
if i % nonlocal_mod == nonlocal_mod - 1:
conv = add_space_nonlocal(conv, dim_in, dim_in,
nonlocal_name + '_{}'.format(i),
int(dim_in / 2))
return conv
def c1_stage(self, input):
out_chan = self._c1_out_chan_num
conv1_name = self.na.fix_c1_stage_name()
if self.variant in ['c', 'd']:
conv_def = [
[out_chan // 2, 3, 2, "conv1_1"],
[out_chan // 2, 3, 1, "conv1_2"],
[out_chan, 3, 1, "conv1_3"],
]
else:
conv_def = [[out_chan, 7, 2, conv1_name]]
for (c, k, s, _name) in conv_def:
input = self._conv_norm(
input=input,
num_filters=c,
filter_size=k,
stride=s,
act='relu',
name=_name)
output = fluid.layers.pool2d(
input=input,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
return output
def __call__(self, input):
assert isinstance(input, Variable)
assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
"feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
res_endpoints = []
res = input
feature_maps = self.feature_maps
severed_head = getattr(self, 'severed_head', False)
if not severed_head:
res = self.c1_stage(res)
feature_maps = range(2, max(self.feature_maps) + 1)
for i in feature_maps:
res = self.layer_warp(res, i)
if i in self.feature_maps:
res_endpoints.append(res)
if self.freeze_at >= i:
res.stop_gradient = True
if self.get_prediction:
pool = fluid.layers.pool2d(
input=res, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=self.class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
out = fluid.layers.softmax(out)
return out
return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
for idx, feat in enumerate(res_endpoints)])
class ResNetC5(ResNet):
def __init__(self,
depth=50,
freeze_at=2,
norm_type='affine_channel',
freeze_norm=True,
norm_decay=0.,
variant='b',
feature_maps=[5],
weight_prefix_name=''):
super(ResNetC5, self).__init__(depth, freeze_at, norm_type, freeze_norm,
norm_decay, variant, feature_maps)
self.severed_head = True
# coding=utf-8
import paddle.fluid as fluid
__all__ = ['FPNRoIAlign']
class FPNRoIAlign(object):
"""
RoI align pooling for FPN feature maps
Args:
sampling_ratio (int): number of sampling points
min_level (int): lowest level of FPN layer
max_level (int): highest level of FPN layer
canconical_level (int): the canconical FPN feature map level
canonical_size (int): the canconical FPN feature map size
box_resolution (int): box resolution
mask_resolution (int): mask roi resolution
"""
def __init__(self,
sampling_ratio=0,
min_level=2,
max_level=5,
canconical_level=4,
canonical_size=224,
box_resolution=7,
mask_resolution=14):
super(FPNRoIAlign, self).__init__()
self.sampling_ratio = sampling_ratio
self.min_level = min_level
self.max_level = max_level
self.canconical_level = canconical_level
self.canonical_size = canonical_size
self.box_resolution = box_resolution
self.mask_resolution = mask_resolution
def __call__(self, head_inputs, rois, spatial_scale, is_mask=False):
"""
Adopt RoI align onto several level of feature maps to get RoI features.
Distribute RoIs to different levels by area and get a list of RoI
features by distributed RoIs and their corresponding feature maps.
Returns:
roi_feat(Variable): RoI features with shape of [M, C, R, R],
where M is the number of RoIs and R is RoI resolution
"""
k_min = self.min_level
k_max = self.max_level
num_roi_lvls = k_max - k_min + 1
name_list = list(head_inputs.keys())
input_name_list = name_list[-num_roi_lvls:]
spatial_scale = spatial_scale[-num_roi_lvls:]
rois_dist, restore_index = fluid.layers.distribute_fpn_proposals(
rois, k_min, k_max, self.canconical_level, self.canonical_size)
# rois_dist is in ascend order
roi_out_list = []
resolution = is_mask and self.mask_resolution or self.box_resolution
for lvl in range(num_roi_lvls):
name_index = num_roi_lvls - lvl - 1
rois_input = rois_dist[lvl]
head_input = head_inputs[input_name_list[name_index]]
sc = spatial_scale[name_index]
roi_out = fluid.layers.roi_align(
input=head_input,
rois=rois_input,
pooled_height=resolution,
pooled_width=resolution,
spatial_scale=sc,
sampling_ratio=self.sampling_ratio)
roi_out_list.append(roi_out)
roi_feat_shuffle = fluid.layers.concat(roi_out_list)
roi_feat_ = fluid.layers.gather(roi_feat_shuffle, restore_index)
roi_feat = fluid.layers.lod_reset(roi_feat_, rois)
return roi_feat
## 命令行预测
```shell
$ hub run yolov3_darknet53_venus --input_path "/PATH/TO/IMAGE"
```
## API
```python
def context(trainable=True,
pretrained=True,
get_prediction=False)
```
提取特征,用于迁移学习。
**参数**
* trainable(bool): 参数是否可训练;
* pretrained (bool): 是否加载预训练模型;
* get\_prediction (bool): 是否执行预测。
**返回**
* inputs (dict): 模型的输入,keys 包括 'image', 'im\_size',相应的取值为:
* image (Variable): 图像变量
* im\_size (Variable): 图片的尺寸
* outputs (dict): 模型的输出。如果 get\_prediction 为 False,输出 'head\_features'、'body\_features',否则输出 'bbox\_out'。
* context\_prog (Program): 用于迁移学习的 Program.
```python
def save_inference_model(dirname,
model_filename=None,
params_filename=None,
combined=True)
```
将模型保存到指定路径。
**参数**
* dirname: 存在模型的目录名称
* model\_filename: 模型文件名称,默认为\_\_model\_\_
* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中
### 依赖
paddlepaddle >= 1.6.2
paddlehub >= 1.6.0
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import math
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
__all__ = ['DarkNet']
class DarkNet(object):
"""DarkNet, see https://pjreddie.com/darknet/yolo/
Args:
depth (int): network depth, currently only darknet 53 is supported
norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
norm_decay (float): weight decay for normalization layer weights
get_prediction (bool): whether to get prediction
class_dim (int): number of class while classification
"""
def __init__(self,
depth=53,
norm_type='sync_bn',
norm_decay=0.,
weight_prefix_name='',
get_prediction=False,
class_dim=1000):
assert depth in [53], "unsupported depth value"
self.depth = depth
self.norm_type = norm_type
self.norm_decay = norm_decay
self.depth_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)}
self.prefix_name = weight_prefix_name
self.class_dim = class_dim
self.get_prediction = get_prediction
def _conv_norm(self,
input,
ch_out,
filter_size,
stride,
padding,
act='leaky',
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + ".conv.weights"),
bias_attr=False)
bn_name = name + ".bn"
bn_param_attr = ParamAttr(
regularizer=L2Decay(float(self.norm_decay)),
name=bn_name + '.scale')
bn_bias_attr = ParamAttr(
regularizer=L2Decay(float(self.norm_decay)),
name=bn_name + '.offset')
out = fluid.layers.batch_norm(
input=conv,
act=None,
param_attr=bn_param_attr,
bias_attr=bn_bias_attr,
moving_mean_name=bn_name + '.mean',
moving_variance_name=bn_name + '.var')
# leaky relu here has `alpha` as 0.1, can not be set by
# `act` param in fluid.layers.batch_norm above.
if act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
def _downsample(self,
input,
ch_out,
filter_size=3,
stride=2,
padding=1,
name=None):
return self._conv_norm(
input,
ch_out=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
name=name)
def basicblock(self, input, ch_out, name=None):
conv1 = self._conv_norm(
input,
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0,
name=name + ".0")
conv2 = self._conv_norm(
conv1,
ch_out=ch_out * 2,
filter_size=3,
stride=1,
padding=1,
name=name + ".1")
out = fluid.layers.elementwise_add(x=input, y=conv2, act=None)
return out
def layer_warp(self, block_func, input, ch_out, count, name=None):
out = block_func(input, ch_out=ch_out, name='{}.0'.format(name))
for j in six.moves.xrange(1, count):
out = block_func(out, ch_out=ch_out, name='{}.{}'.format(name, j))
return out
def __call__(self, input):
"""
Get the backbone of DarkNet, that is output for the 5 stages.
"""
stages, block_func = self.depth_cfg[self.depth]
stages = stages[0:5]
conv = self._conv_norm(
input=input,
ch_out=32,
filter_size=3,
stride=1,
padding=1,
name=self.prefix_name + "yolo_input")
downsample_ = self._downsample(
input=conv,
ch_out=conv.shape[1] * 2,
name=self.prefix_name + "yolo_input.downsample")
blocks = []
for i, stage in enumerate(stages):
block = self.layer_warp(
block_func=block_func,
input=downsample_,
ch_out=32 * 2**i,
count=stage,
name=self.prefix_name + "stage.{}".format(i))
blocks.append(block)
if i < len(stages) - 1: # do not downsaple in the last stage
downsample_ = self._downsample(
input=block,
ch_out=block.shape[1] * 2,
name=self.prefix_name + "stage.{}.downsample".format(i))
if self.get_prediction:
pool = fluid.layers.pool2d(
input=block, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=self.class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset'))
out = fluid.layers.softmax(out)
return out
else:
return blocks
# coding=utf-8
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
import cv2
import numpy as np
__all__ = ['reader']
def reader(paths=[], images=None):
"""
data generator
Args:
paths (list[str]): paths to images.
images (list(numpy.ndarray)): data of images, shape of each is [H, W, C]
Yield:
res (list): preprocessed image and the size of original image.
"""
img_list = []
if paths:
assert type(paths) is list, "type(paths) is not list."
for img_path in paths:
assert os.path.isfile(
img_path), "The {} isn't a valid file path.".format(img_path)
img = cv2.imread(img_path).astype('float32')
img_list.append(img)
if images is not None:
for img in images:
img_list.append(img)
for im in img_list:
# im_size
im_shape = im.shape
im_size = np.array([im_shape[0], im_shape[1]], dtype=np.int32)
# decode image
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# resize image
target_size = 608
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
if float(im_size_min) == 0:
raise ZeroDivisionError('min size of image is 0')
im_scale_x = float(target_size) / float(im_shape[1])
im_scale_y = float(target_size) / float(im_shape[0])
im = cv2.resize(
im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
# normalize image
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
im = im.astype(np.float32, copy=False)
mean = np.array(mean)[np.newaxis, np.newaxis, :]
std = np.array(std)[np.newaxis, np.newaxis, :]
im = im / 255.0
im -= mean
im /= std
# permute
im = np.swapaxes(im, 1, 2)
im = np.swapaxes(im, 1, 0)
yield [im, im_size]
# coding=utf-8
from __future__ import absolute_import
import ast
import argparse
import os
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix
from yolov3_darknet53_venus.darknet import DarkNet
from yolov3_darknet53_venus.processor import load_label_info, postprocess, base64_to_cv2
from yolov3_darknet53_venus.data_feed import reader
from yolov3_darknet53_venus.yolo_head import MultiClassNMS, YOLOv3Head
@moduleinfo(
name="yolov3_darknet53_venus",
version="1.1.0",
type="CV/object_detection",
summary=
"Baidu's YOLOv3 model for object detection, with backbone DarkNet53, trained with Baidu self-built dataset.",
author="paddlepaddle",
author_email="paddle-dev@baidu.com")
class YOLOv3DarkNet53Venus(hub.Module):
def _initialize(self):
self.default_pretrained_model_path = os.path.join(
self.directory, "yolov3_darknet53_model")
def context(self, trainable=True, pretrained=True, get_prediction=False):
"""
Distill the Head Features, so as to perform transfer learning.
Args:
trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model.
get_prediction (bool): whether to get prediction.
Returns:
inputs(dict): the input variables.
outputs(dict): the output variables.
context_prog (Program): the program to execute transfer learning.
"""
context_prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
# image
image = fluid.layers.data(
name='image', shape=[3, 608, 608], dtype='float32')
# backbone
backbone = DarkNet(norm_type='bn', norm_decay=0., depth=53)
# body_feats
body_feats = backbone(image)
# im_size
im_size = fluid.layers.data(
name='im_size', shape=[2], dtype='int32')
# yolo_head
yolo_head = YOLOv3Head(num_classes=708)
# head_features
head_features, body_features = yolo_head._get_outputs(
body_feats, is_train=trainable)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# var_prefix
var_prefix = '@HUB_{}@'.format(self.name)
# name of inputs
inputs = {
'image': var_prefix + image.name,
'im_size': var_prefix + im_size.name
}
# name of outputs
if get_prediction:
bbox_out = yolo_head.get_prediction(head_features, im_size)
outputs = {'bbox_out': [var_prefix + bbox_out.name]}
else:
outputs = {
'head_features':
[var_prefix + var.name for var in head_features],
'body_features':
[var_prefix + var.name for var in body_features]
}
# add_vars_prefix
add_vars_prefix(context_prog, var_prefix)
add_vars_prefix(fluid.default_startup_program(), var_prefix)
# inputs
inputs = {
key: context_prog.global_block().vars[value]
for key, value in inputs.items()
}
# outputs
outputs = {
key: [
context_prog.global_block().vars[varname]
for varname in value
]
for key, value in outputs.items()
}
# trainable
for param in context_prog.global_block().iter_parameters():
param.trainable = trainable
# pretrained
if pretrained:
def _if_exist(var):
return os.path.exists(
os.path.join(self.default_pretrained_model_path,
var.name))
fluid.io.load_vars(
exe,
self.default_pretrained_model_path,
predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog
# coding=utf-8
import base64
import os
import cv2
import numpy as np
from PIL import Image, ImageDraw
__all__ = ['base64_to_cv2', 'load_label_info', 'postprocess']
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
def check_dir(dir_path):
if not os.path.exists(dir_path):
os.makedirs(dir_path)
elif os.path.isfile(dir_path):
os.remove(dir_path)
os.makedirs(dir_path)
def get_save_image_name(img, output_dir, image_path):
"""Get save image name from source image path.
"""
image_name = os.path.split(image_path)[-1]
name, ext = os.path.splitext(image_name)
if ext == '':
if img.format == 'PNG':
ext = '.png'
elif img.format == 'JPEG':
ext = '.jpg'
elif img.format == 'BMP':
ext = '.bmp'
else:
if img.mode == "RGB" or img.mode == "L":
ext = ".jpg"
elif img.mode == "RGBA" or img.mode == "P":
ext = '.png'
return os.path.join(output_dir, "{}".format(name)) + ext
def draw_bounding_box_on_image(image_path, data_list, save_dir):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
for data in data_list:
left, right, top, bottom = data['left'], data['right'], data[
'top'], data['bottom']
# draw bbox
draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=2,
fill='red')
# draw label
if image.mode == 'RGB':
text = data['label'] + ": %.2f%%" % (100 * data['confidence'])
textsize_width, textsize_height = draw.textsize(text=text)
draw.rectangle(
xy=(left, top - (textsize_height + 5),
left + textsize_width + 10, top),
fill=(255, 255, 255))
draw.text(xy=(left, top - 15), text=text, fill=(0, 0, 0))
save_name = get_save_image_name(image, save_dir, image_path)
if os.path.exists(save_name):
os.remove(save_name)
image.save(save_name)
return save_name
def clip_bbox(bbox, img_width, img_height):
xmin = max(min(bbox[0], img_width), 0.)
ymin = max(min(bbox[1], img_height), 0.)
xmax = max(min(bbox[2], img_width), 0.)
ymax = max(min(bbox[3], img_height), 0.)
return xmin, ymin, xmax, ymax
def load_label_info(file_path):
with open(file_path, 'r') as fr:
text = fr.readlines()
label_names = []
for info in text:
label_names.append(info.strip())
return label_names
def postprocess(paths,
images,
data_out,
score_thresh,
label_names,
output_dir,
handle_id,
visualization=True):
"""
postprocess the lod_tensor produced by fluid.Executor.run
Args:
paths (list[str]): The paths of images.
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
data_out (lod_tensor): data output of predictor.
batch_size (int): batch size.
use_gpu (bool): Whether to use gpu.
output_dir (str): The path to store output images.
visualization (bool): Whether to save image or not.
score_thresh (float): the low limit of bounding box.
label_names (list[str]): label names.
handle_id (int): The number of images that have been handled.
Returns:
res (list[dict]): The result of vehicles detecion. keys include 'data', 'save_path', the corresponding value is:
data (dict): the result of object detection, keys include 'left', 'top', 'right', 'bottom', 'label', 'confidence', the corresponding value is:
left (float): The X coordinate of the upper left corner of the bounding box;
top (float): The Y coordinate of the upper left corner of the bounding box;
right (float): The X coordinate of the lower right corner of the bounding box;
bottom (float): The Y coordinate of the lower right corner of the bounding box;
label (str): The label of detection result;
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor = data_out[0]
lod = lod_tensor.lod[0]
results = lod_tensor.as_ndarray()
check_dir(output_dir)
assert type(paths) is list, "type(paths) is not list."
if handle_id < len(paths):
unhandled_paths = paths[handle_id:]
unhandled_paths_num = len(unhandled_paths)
else:
unhandled_paths_num = 0
output = list()
for index in range(len(lod) - 1):
output_i = {'data': []}
if index < unhandled_paths_num:
org_img_path = unhandled_paths[index]
org_img = Image.open(org_img_path)
else:
org_img = images[index - unhandled_paths_num]
org_img = org_img.astype(np.uint8)
org_img = Image.fromarray(org_img[:, :, ::-1])
if visualization:
org_img_path = get_save_image_name(
org_img, output_dir, 'image_numpy_{}'.format(
(handle_id + index)))
org_img.save(org_img_path)
org_img_height = org_img.height
org_img_width = org_img.width
result_i = results[lod[index]:lod[index + 1]]
for row in result_i:
if len(row) != 6:
continue
if row[1] < score_thresh:
continue
category_id = int(row[0])
confidence = row[1]
bbox = row[2:]
dt = {}
dt['label'] = label_names[category_id]
dt['confidence'] = confidence
dt['left'], dt['top'], dt['right'], dt['bottom'] = clip_bbox(
bbox, org_img_width, org_img_height)
output_i['data'].append(dt)
output.append(output_i)
if visualization:
output_i['save_path'] = draw_bounding_box_on_image(
org_img_path, output_i['data'], output_dir)
return output
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
__all__ = ['MultiClassNMS', 'YOLOv3Head']
class MultiClassNMS(object):
# __op__ = fluid.layers.multiclass_nms
def __init__(self, background_label, keep_top_k, nms_threshold, nms_top_k,
normalized, score_threshold):
super(MultiClassNMS, self).__init__()
self.background_label = background_label
self.keep_top_k = keep_top_k
self.nms_threshold = nms_threshold
self.nms_top_k = nms_top_k
self.normalized = normalized
self.score_threshold = score_threshold
class YOLOv3Head(object):
"""Head block for YOLOv3 network
Args:
norm_decay (float): weight decay for normalization layer weights
num_classes (int): number of output classes
ignore_thresh (float): threshold to ignore confidence loss
label_smooth (bool): whether to use label smoothing
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
"""
def __init__(self,
norm_decay=0.,
num_classes=80,
ignore_thresh=0.7,
label_smooth=True,
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
nms=MultiClassNMS(
background_label=-1,
keep_top_k=100,
nms_threshold=0.45,
nms_top_k=1000,
normalized=True,
score_threshold=0.01),
weight_prefix_name=''):
self.norm_decay = norm_decay
self.num_classes = num_classes
self.ignore_thresh = ignore_thresh
self.label_smooth = label_smooth
self.anchor_masks = anchor_masks
self._parse_anchors(anchors)
self.nms = nms
self.prefix_name = weight_prefix_name
def _conv_bn(self,
input,
ch_out,
filter_size,
stride,
padding,
act='leaky',
is_test=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + ".conv.weights"),
bias_attr=False)
bn_name = name + ".bn"
bn_param_attr = ParamAttr(
regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
bn_bias_attr = ParamAttr(
regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
out = fluid.layers.batch_norm(
input=conv,
act=None,
is_test=is_test,
param_attr=bn_param_attr,
bias_attr=bn_bias_attr,
moving_mean_name=bn_name + '.mean',
moving_variance_name=bn_name + '.var')
if act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
def _detection_block(self, input, channel, is_test=True, name=None):
assert channel % 2 == 0, \
"channel {} cannot be divided by 2 in detection block {}" \
.format(channel, name)
conv = input
for j in range(2):
conv = self._conv_bn(
conv,
channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test,
name='{}.{}.0'.format(name, j))
conv = self._conv_bn(
conv,
channel * 2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test,
name='{}.{}.1'.format(name, j))
route = self._conv_bn(
conv,
channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test,
name='{}.2'.format(name))
tip = self._conv_bn(
route,
channel * 2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test,
name='{}.tip'.format(name))
return route, tip
def _upsample(self, input, scale=2, name=None):
out = fluid.layers.resize_nearest(
input=input, scale=float(scale), name=name)
return out
def _parse_anchors(self, anchors):
"""
Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors
"""
self.anchors = []
self.mask_anchors = []
assert len(anchors) > 0, "ANCHORS not set."
assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
for anchor in anchors:
assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
self.anchors.extend(anchor)
anchor_num = len(anchors)
for masks in self.anchor_masks:
self.mask_anchors.append([])
for mask in masks:
assert mask < anchor_num, "anchor mask index overflow"
self.mask_anchors[-1].extend(anchors[mask])
def _get_outputs(self, input, is_train=True):
"""
Get YOLOv3 head output
Args:
input (list): List of Variables, output of backbone stages
is_train (bool): whether in train or test mode
Returns:
outputs (list): Variables of each output layer
"""
outputs = []
# get last out_layer_num blocks in reverse order
out_layer_num = len(self.anchor_masks)
if isinstance(input, OrderedDict):
blocks = list(input.values())[-1:-out_layer_num - 1:-1]
else:
blocks = input[-1:-out_layer_num - 1:-1]
route = None
for i, block in enumerate(blocks):
if i > 0: # perform concat in first 2 detection_block
block = fluid.layers.concat(input=[route, block], axis=1)
route, tip = self._detection_block(
block,
channel=512 // (2**i),
is_test=(not is_train),
name=self.prefix_name + "yolo_block.{}".format(i))
# out channel number = mask_num * (5 + class_num)
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
block_out = fluid.layers.conv2d(
input=tip,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name=self.prefix_name +
"yolo_output.{}.conv.weights".format(i)),
bias_attr=ParamAttr(
regularizer=L2Decay(0.),
name=self.prefix_name +
"yolo_output.{}.conv.bias".format(i)))
outputs.append(block_out)
if i < len(blocks) - 1:
# do not perform upsample in the last detection_block
route = self._conv_bn(
input=route,
ch_out=256 // (2**i),
filter_size=1,
stride=1,
padding=0,
is_test=(not is_train),
name=self.prefix_name + "yolo_transition.{}".format(i))
# upsample
route = self._upsample(route)
return outputs, blocks
def get_prediction(self, outputs, im_size):
"""
Get prediction result of YOLOv3 network
Args:
outputs (list): list of Variables, return from _get_outputs
im_size (Variable): Variable of size([h, w]) of each image
Returns:
pred (Variable): The prediction result after non-max suppress.
"""
boxes = []
scores = []
downsample = 32
for i, output in enumerate(outputs):
box, score = fluid.layers.yolo_box(
x=output,
img_size=im_size,
anchors=self.mask_anchors[i],
class_num=self.num_classes,
conf_thresh=self.nms.score_threshold,
downsample_ratio=downsample,
name=self.prefix_name + "yolo_box" + str(i))
boxes.append(box)
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
downsample //= 2
yolo_boxes = fluid.layers.concat(boxes, axis=1)
yolo_scores = fluid.layers.concat(scores, axis=2)
pred = fluid.layers.multiclass_nms(
bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=self.nms.score_threshold,
nms_top_k=self.nms.nms_top_k,
keep_top_k=self.nms.keep_top_k,
nms_threshold=self.nms.nms_threshold,
background_label=self.nms.background_label,
normalized=self.nms.normalized,
name="multiclass_nms")
return pred
name: faster_rcnn_resnet50_fpn_venus
dir: "modules/image/object_detection/faster_rcnn_resnet50_fpn_venus"
# resources:
# -
# url: https://paddlehub.bj.bcebos.com/model/cv/faster_rcnn_resnet50_fpn_model.tar.gz
# dest: faster_rcnn_resnet50_fpn_model
# uncompress: True
name: yolov3_darknet53_venus
dir: "modules/image/object_detection/yolov3_darknet53_venus"
# resources:
# -
# url: https://paddlehub.bj.bcebos.com/model/cv/yolov3_darknet53_model.tar.gz
# dest: yolov3_darknet53_model
# uncompress: True
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册