未验证 提交 9d700dd6 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update yolov3_resnet50_vd_coco2017 (#1954)

* update yolov3_resnet50_vd_coco2017

* update unittest

* update gpu config

* update

* add clean func

* update save inference model
上级 12a6bf92
......@@ -100,20 +100,13 @@
- save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在)
- ```python
def save_inference_model(dirname,
model_filename=None,
params_filename=None,
combined=True)
def save_inference_model(dirname)
```
- 将模型保存到指定路径。
- **参数**
- dirname: 存在模型的目录名称; <br/>
- model\_filename: 模型文件名称,默认为\_\_model\_\_; <br/>
- params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效);<br/>
- combined: 是否将参数保存到统一的一个文件中。
- dirname: 模型保存路径 <br/>
## 四、服务部署
......@@ -166,6 +159,10 @@
修复numpy数据读取问题
* 1.1.0
移除 fluid api
- ```shell
$ hub install yolov3_resnet50_vd_coco2017==1.0.2
$ hub install yolov3_resnet50_vd_coco2017==1.1.0
```
......@@ -99,19 +99,13 @@
- save\_path (str, optional): output path for saving results
- ```python
def save_inference_model(dirname,
model_filename=None,
params_filename=None,
combined=True)
def save_inference_model(dirname)
```
- Save model to specific path
- **Parameters**
- dirname: output dir for saving model
- model\_filename: filename for saving model
- params\_filename: filename for saving parameters
- combined: whether save parameters into one file
- dirname: save model path
## IV.Server Deployment
......@@ -165,6 +159,10 @@
Fix the problem of reading numpy
* 1.1.0
Remove fluid api
- ```shell
$ hub install yolov3_resnet50_vd_coco2017==1.0.2
$ hub install yolov3_resnet50_vd_coco2017==1.1.0
```
......@@ -6,44 +6,43 @@ import argparse
import os
from functools import partial
import paddle
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
import paddle.static
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix
from yolov3_resnet50_vd_coco2017.resnet import ResNet
from yolov3_resnet50_vd_coco2017.processor import load_label_info, postprocess, base64_to_cv2
from yolov3_resnet50_vd_coco2017.data_feed import reader
from yolov3_resnet50_vd_coco2017.yolo_head import MultiClassNMS, YOLOv3Head
from .processor import load_label_info, postprocess, base64_to_cv2
from .data_feed import reader
@moduleinfo(
name="yolov3_resnet50_vd_coco2017",
version="1.0.2",
version="1.1.0",
type="CV/object_detection",
summary=
"Baidu's YOLOv3 model for object detection with backbone ResNet50, trained with dataset coco2017.",
author="paddlepaddle",
author_email="paddle-dev@baidu.com")
class YOLOv3ResNet50Coco2017(hub.Module):
def _initialize(self):
class YOLOv3ResNet50Coco2017:
def __init__(self):
self.default_pretrained_model_path = os.path.join(
self.directory, "yolov3_resnet50_model")
self.directory, "yolov3_resnet50_model", "model")
self.label_names = load_label_info(
os.path.join(self.directory, "label_file.txt"))
self._set_config()
def _set_config(self):
"""
predictor config setting.
"""
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
cpu_config.switch_ir_optim(False)
self.cpu_predictor = create_paddle_predictor(cpu_config)
self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
......@@ -52,110 +51,10 @@ class YOLOv3ResNet50Coco2017(hub.Module):
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path)
gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
def context(self, trainable=True, pretrained=True, get_prediction=False):
"""
Distill the Head Features, so as to perform transfer learning.
Args:
trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model.
get_prediction (bool): whether to get prediction.
Returns:
inputs(dict): the input variables.
outputs(dict): the output variables.
context_prog (Program): the program to execute transfer learning.
"""
context_prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
# image
image = fluid.layers.data(
name='image', shape=[3, 608, 608], dtype='float32')
# backbone
backbone = ResNet(
norm_type='sync_bn',
freeze_at=0,
freeze_norm=False,
norm_decay=0.,
dcn_v2_stages=[5],
depth=50,
variant='d',
feature_maps=[3, 4, 5])
# body_feats
body_feats = backbone(image)
# im_size
im_size = fluid.layers.data(
name='im_size', shape=[2], dtype='int32')
# yolo_head
yolo_head = YOLOv3Head(num_classes=80)
# head_features
head_features, body_features = yolo_head._get_outputs(
body_feats, is_train=trainable)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# var_prefix
var_prefix = '@HUB_{}@'.format(self.name)
# name of inputs
inputs = {
'image': var_prefix + image.name,
'im_size': var_prefix + im_size.name
}
# name of outputs
if get_prediction:
bbox_out = yolo_head.get_prediction(head_features, im_size)
outputs = {'bbox_out': [var_prefix + bbox_out.name]}
else:
outputs = {
'head_features':
[var_prefix + var.name for var in head_features],
'body_features':
[var_prefix + var.name for var in body_features]
}
# add_vars_prefix
add_vars_prefix(context_prog, var_prefix)
add_vars_prefix(fluid.default_startup_program(), var_prefix)
# inputs
inputs = {
key: context_prog.global_block().vars[value]
for key, value in inputs.items()
}
# outputs
outputs = {
key: [
context_prog.global_block().vars[varname]
for varname in value
]
for key, value in outputs.items()
}
# trainable
for param in context_prog.global_block().iter_parameters():
param.trainable = trainable
# pretrained
if pretrained:
def _if_exist(var):
return os.path.exists(
os.path.join(self.default_pretrained_model_path,
var.name))
fluid.io.load_vars(
exe,
self.default_pretrained_model_path,
predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog
self.gpu_predictor = create_predictor(gpu_config)
def object_detection(self,
paths=None,
......@@ -198,54 +97,33 @@ class YOLOv3ResNet50Coco2017(hub.Module):
paths = paths if paths else list()
data_reader = partial(reader, paths, images)
batch_reader = fluid.io.batch(data_reader, batch_size=batch_size)
batch_reader = paddle.batch(data_reader, batch_size=batch_size)
res = []
for iter_id, feed_data in enumerate(batch_reader()):
feed_data = np.array(feed_data)
image_tensor = PaddleTensor(np.array(list(feed_data[:, 0])))
im_size_tensor = PaddleTensor(np.array(list(feed_data[:, 1])))
if use_gpu:
data_out = self.gpu_predictor.run(
[image_tensor, im_size_tensor])
else:
data_out = self.cpu_predictor.run(
[image_tensor, im_size_tensor])
output = postprocess(
paths=paths,
images=images,
data_out=data_out,
score_thresh=score_thresh,
label_names=self.label_names,
output_dir=output_dir,
handle_id=iter_id * batch_size,
visualization=visualization)
predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(np.array(list(feed_data[:, 0])))
input_handle = predictor.get_input_handle(input_names[1])
input_handle.copy_from_cpu(np.array(list(feed_data[:, 1])))
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output = postprocess(paths=paths,
images=images,
data_out=output_handle,
score_thresh=score_thresh,
label_names=self.label_names,
output_dir=output_dir,
handle_id=iter_id * batch_size,
visualization=visualization)
res.extend(output)
return res
def save_inference_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving
def serving_method(self, images, **kwargs):
"""
......
# coding=utf-8
class NameAdapter(object):
"""Fix the backbones variable names for pretrained weight"""
def __init__(self, model):
super(NameAdapter, self).__init__()
self.model = model
@property
def model_type(self):
return getattr(self.model, '_model_type', '')
@property
def variant(self):
return getattr(self.model, 'variant', '')
def fix_conv_norm_name(self, name):
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
# the naming rule is same as pretrained weight
if self.model_type == 'SEResNeXt':
bn_name = name + "_bn"
return bn_name
def fix_shortcut_name(self, name):
if self.model_type == 'SEResNeXt':
name = 'conv' + name + '_prj'
return name
def fix_bottleneck_name(self, name):
if self.model_type == 'SEResNeXt':
conv_name1 = 'conv' + name + '_x1'
conv_name2 = 'conv' + name + '_x2'
conv_name3 = 'conv' + name + '_x3'
shortcut_name = name
else:
conv_name1 = name + "_branch2a"
conv_name2 = name + "_branch2b"
conv_name3 = name + "_branch2c"
shortcut_name = name + "_branch1"
return conv_name1, conv_name2, conv_name3, shortcut_name
def fix_layer_warp_name(self, stage_num, count, i):
name = 'res' + str(stage_num)
if count > 10 and stage_num == 4:
if i == 0:
conv_name = name + "a"
else:
conv_name = name + "b" + str(i)
else:
conv_name = name + chr(ord("a") + i)
if self.model_type == 'SEResNeXt':
conv_name = str(stage_num + 2) + '_' + str(i + 1)
return conv_name
def fix_c1_stage_name(self):
return "res_conv1" if self.model_type == 'ResNeXt' else "conv1"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
nonlocal_params = {
"use_zero_init_conv": False,
"conv_init_std": 0.01,
"no_bias": True,
"use_maxpool": False,
"use_softmax": True,
"use_bn": False,
"use_scale": True, # vital for the model prformance!!!
"use_affine": False,
"bn_momentum": 0.9,
"bn_epsilon": 1.0000001e-5,
"bn_init_gamma": 0.9,
"weight_decay_bn": 1.e-4,
}
def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner,
max_pool_stride=2):
cur = input
theta = fluid.layers.conv2d(input = cur, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr=ParamAttr(name = prefix + '_theta' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_theta' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if not nonlocal_params["no_bias"] else False, \
name = prefix + '_theta')
theta_shape = theta.shape
theta_shape_op = fluid.layers.shape(theta)
theta_shape_op.stop_gradient = True
if nonlocal_params["use_maxpool"]:
max_pool = fluid.layers.pool2d(input = cur, \
pool_size = [max_pool_stride, max_pool_stride], \
pool_type = 'max', \
pool_stride = [max_pool_stride, max_pool_stride], \
pool_padding = [0, 0], \
name = prefix + '_pool')
else:
max_pool = cur
phi = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_phi' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_phi' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_phi')
phi_shape = phi.shape
g = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_g' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0, scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_g' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_g')
g_shape = g.shape
# we have to use explicit batch size (to support arbitrary spacetime size)
# e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784)
theta = fluid.layers.reshape(theta, shape=(0, 0, -1))
theta = fluid.layers.transpose(theta, [0, 2, 1])
phi = fluid.layers.reshape(phi, [0, 0, -1])
theta_phi = fluid.layers.matmul(theta, phi, name=prefix + '_affinity')
g = fluid.layers.reshape(g, [0, 0, -1])
if nonlocal_params["use_softmax"]:
if nonlocal_params["use_scale"]:
theta_phi_sc = fluid.layers.scale(theta_phi, scale=dim_inner**-.5)
else:
theta_phi_sc = theta_phi
p = fluid.layers.softmax(
theta_phi_sc, name=prefix + '_affinity' + '_prob')
else:
# not clear about what is doing in xlw's code
p = None # not implemented
raise "Not implemented when not use softmax"
# note g's axis[2] corresponds to p's axis[2]
# e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
p = fluid.layers.transpose(p, [0, 2, 1])
t = fluid.layers.matmul(g, p, name=prefix + '_y')
# reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
t_shape = t.shape
t_re = fluid.layers.reshape(
t, shape=list(theta_shape), actual_shape=theta_shape_op)
blob_out = t_re
blob_out = fluid.layers.conv2d(input = blob_out, num_filters = dim_out, \
filter_size = [1, 1], stride = [1, 1], padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_out' + "_w", \
initializer = fluid.initializer.Constant(value = 0.) \
if nonlocal_params["use_zero_init_conv"] \
else fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_out' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_out')
blob_out_shape = blob_out.shape
if nonlocal_params["use_bn"]:
bn_name = prefix + "_bn"
blob_out = fluid.layers.batch_norm(blob_out, \
# is_test = test_mode, \
momentum = nonlocal_params["bn_momentum"], \
epsilon = nonlocal_params["bn_epsilon"], \
name = bn_name, \
param_attr = ParamAttr(name = bn_name + "_s", \
initializer = fluid.initializer.Constant(value = nonlocal_params["bn_init_gamma"]), \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
bias_attr = ParamAttr(name = bn_name + "_b", \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
moving_mean_name = bn_name + "_rm", \
moving_variance_name = bn_name + "_riv") # add bn
if nonlocal_params["use_affine"]:
affine_scale = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_s'), \
default_initializer = fluid.initializer.Constant(value = 1.))
affine_bias = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_b'), \
default_initializer = fluid.initializer.Constant(value = 0.))
blob_out = fluid.layers.affine_channel(blob_out, scale = affine_scale, \
bias = affine_bias, name = prefix + '_affine') # add affine
return blob_out
def add_space_nonlocal(input, dim_in, dim_out, prefix, dim_inner):
'''
add_space_nonlocal:
Non-local Neural Networks: see https://arxiv.org/abs/1711.07971
'''
conv = space_nonlocal(input, dim_in, dim_out, prefix, dim_inner)
output = fluid.layers.elementwise_add(input, conv, name=prefix + '_sum')
return output
......@@ -101,7 +101,7 @@ def postprocess(paths,
handle_id,
visualization=True):
"""
postprocess the lod_tensor produced by fluid.Executor.run
postprocess the lod_tensor produced by Executor.run
Args:
paths (list[str]): The paths of images.
......@@ -126,9 +126,8 @@ def postprocess(paths,
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor = data_out[0]
lod = lod_tensor.lod[0]
results = lod_tensor.as_ndarray()
lod = data_out.lod()[0]
results = data_out.copy_to_cpu()
check_dir(output_dir)
......
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from collections import OrderedDict
from numbers import Integral
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant
from .nonlocal_helper import add_space_nonlocal
from .name_adapter import NameAdapter
__all__ = ['ResNet', 'ResNetC5']
class ResNet(object):
"""
Residual Network, see https://arxiv.org/abs/1512.03385
Args:
depth (int): ResNet depth, should be 34, 50.
freeze_at (int): freeze the backbone at which stage
norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel'
freeze_norm (bool): freeze normalization layers
norm_decay (float): weight decay for normalization layer weights
variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2
nonlocal_stages (list): index of stages who select nonlocal networks
"""
__shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name']
def __init__(self,
depth=50,
freeze_at=0,
norm_type='sync_bn',
freeze_norm=False,
norm_decay=0.,
variant='b',
feature_maps=[3, 4, 5],
dcn_v2_stages=[],
weight_prefix_name='',
nonlocal_stages=[],
get_prediction=False,
class_dim=1000):
super(ResNet, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
assert depth in [34, 50], \
"depth {} not in [34, 50]"
assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant"
assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
assert len(feature_maps) > 0, "need one or more feature maps"
assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert not (len(nonlocal_stages)>0 and depth<50), \
"non-local is not supported for resnet18 or resnet34"
self.depth = depth
self.freeze_at = freeze_at
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.variant = variant
self._model_type = 'ResNet'
self.feature_maps = feature_maps
self.dcn_v2_stages = dcn_v2_stages
self.depth_cfg = {
34: ([3, 4, 6, 3], self.basicblock),
50: ([3, 4, 6, 3], self.bottleneck),
}
self.stage_filters = [64, 128, 256, 512]
self._c1_out_chan_num = 64
self.na = NameAdapter(self)
self.prefix_name = weight_prefix_name
self.nonlocal_stages = nonlocal_stages
self.nonlocal_mod_cfg = {
50: 2,
101: 5,
152: 8,
200: 12,
}
self.get_prediction = get_prediction
self.class_dim = class_dim
def _conv_offset(self,
input,
filter_size,
stride,
padding,
act=None,
name=None):
out_channel = filter_size * filter_size * 3
out = fluid.layers.conv2d(
input,
num_filters=out_channel,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=ParamAttr(initializer=Constant(0.0), name=name + ".w_0"),
bias_attr=ParamAttr(initializer=Constant(0.0), name=name + ".b_0"),
act=act,
name=name)
return out
def _conv_norm(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None,
dcn_v2=False):
_name = self.prefix_name + name if self.prefix_name != '' else name
if not dcn_v2:
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=_name + "_weights"),
bias_attr=False,
name=_name + '.conv2d.output.1')
else:
# select deformable conv"
offset_mask = self._conv_offset(
input=input,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
act=None,
name=_name + "_conv_offset")
offset_channel = filter_size**2 * 2
mask_channel = filter_size**2
offset, mask = fluid.layers.split(
input=offset_mask,
num_or_sections=[offset_channel, mask_channel],
dim=1)
mask = fluid.layers.sigmoid(mask)
conv = fluid.layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
deformable_groups=1,
im2col_step=1,
param_attr=ParamAttr(name=_name + "_weights"),
bias_attr=False,
name=_name + ".conv2d.output.1")
bn_name = self.na.fix_conv_norm_name(name)
bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name
norm_lr = 0. if self.freeze_norm else 1.
norm_decay = self.norm_decay
pattr = ParamAttr(
name=bn_name + '_scale',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
battr = ParamAttr(
name=bn_name + '_offset',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
if self.norm_type in ['bn', 'sync_bn']:
global_stats = True if self.freeze_norm else False
out = fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=pattr,
bias_attr=battr,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif self.norm_type == 'affine_channel':
scale = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=pattr,
default_initializer=fluid.initializer.Constant(1.))
bias = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=battr,
default_initializer=fluid.initializer.Constant(0.))
out = fluid.layers.affine_channel(
x=conv, scale=scale, bias=bias, act=act)
if self.freeze_norm:
scale.stop_gradient = True
bias.stop_gradient = True
return out
def _shortcut(self, input, ch_out, stride, is_first, name):
max_pooling_in_short_cut = self.variant == 'd'
ch_in = input.shape[1]
# the naming rule is same as pretrained weight
name = self.na.fix_shortcut_name(name)
std_senet = getattr(self, 'std_senet', False)
if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first):
if std_senet:
if is_first:
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return self._conv_norm(input, ch_out, 3, stride, name=name)
if max_pooling_in_short_cut and not is_first:
input = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
ceil_mode=True,
pool_type='avg')
return self._conv_norm(input, ch_out, 1, 1, name=name)
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck(self,
input,
num_filters,
stride,
is_first,
name,
dcn_v2=False):
if self.variant == 'a':
stride1, stride2 = stride, 1
else:
stride1, stride2 = 1, stride
# ResNeXt
groups = getattr(self, 'groups', 1)
group_width = getattr(self, 'group_width', -1)
if groups == 1:
expand = 4
elif (groups * group_width) == 256:
expand = 1
else: # FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
num_filters = num_filters // 2
expand = 2
conv_name1, conv_name2, conv_name3, \
shortcut_name = self.na.fix_bottleneck_name(name)
std_senet = getattr(self, 'std_senet', False)
if std_senet:
conv_def = [[
int(num_filters / 2), 1, stride1, 'relu', 1, conv_name1
], [num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
else:
conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1],
[num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
residual = input
for i, (c, k, s, act, g, _name) in enumerate(conv_def):
residual = self._conv_norm(
input=residual,
num_filters=c,
filter_size=k,
stride=s,
act=act,
groups=g,
name=_name,
dcn_v2=(i == 1 and dcn_v2))
short = self._shortcut(
input,
num_filters * expand,
stride,
is_first=is_first,
name=shortcut_name)
# Squeeze-and-Excitation
if callable(getattr(self, '_squeeze_excitation', None)):
residual = self._squeeze_excitation(
input=residual, num_channels=num_filters, name='fc' + name)
return fluid.layers.elementwise_add(
x=short, y=residual, act='relu', name=name + ".add.output.5")
def basicblock(self,
input,
num_filters,
stride,
is_first,
name,
dcn_v2=False):
assert dcn_v2 is False, "Not implemented yet."
conv0 = self._conv_norm(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
conv1 = self._conv_norm(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
short = self._shortcut(
input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def layer_warp(self, input, stage_num):
"""
Args:
input (Variable): input variable.
stage_num (int): the stage number, should be 2, 3, 4, 5
Returns:
The last variable in endpoint-th stage.
"""
assert stage_num in [2, 3, 4, 5]
stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2]
ch_out = self.stage_filters[stage_num - 2]
is_first = False if stage_num != 2 else True
dcn_v2 = True if stage_num in self.dcn_v2_stages else False
nonlocal_mod = 1000
if stage_num in self.nonlocal_stages:
nonlocal_mod = self.nonlocal_mod_cfg[
self.depth] if stage_num == 4 else 2
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv = input
for i in range(count):
conv_name = self.na.fix_layer_warp_name(stage_num, count, i)
if self.depth < 50:
is_first = True if i == 0 and stage_num == 2 else False
conv = block_func(
input=conv,
num_filters=ch_out,
stride=2 if i == 0 and stage_num != 2 else 1,
is_first=is_first,
name=conv_name,
dcn_v2=dcn_v2)
# add non local model
dim_in = conv.shape[1]
nonlocal_name = "nonlocal_conv{}".format(stage_num)
if i % nonlocal_mod == nonlocal_mod - 1:
conv = add_space_nonlocal(conv, dim_in, dim_in,
nonlocal_name + '_{}'.format(i),
int(dim_in / 2))
return conv
def c1_stage(self, input):
out_chan = self._c1_out_chan_num
conv1_name = self.na.fix_c1_stage_name()
if self.variant in ['c', 'd']:
conv_def = [
[out_chan // 2, 3, 2, "conv1_1"],
[out_chan // 2, 3, 1, "conv1_2"],
[out_chan, 3, 1, "conv1_3"],
]
else:
conv_def = [[out_chan, 7, 2, conv1_name]]
for (c, k, s, _name) in conv_def:
input = self._conv_norm(
input=input,
num_filters=c,
filter_size=k,
stride=s,
act='relu',
name=_name)
output = fluid.layers.pool2d(
input=input,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
return output
def __call__(self, input):
assert isinstance(input, Variable)
assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
"feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
res_endpoints = []
res = input
feature_maps = self.feature_maps
severed_head = getattr(self, 'severed_head', False)
if not severed_head:
res = self.c1_stage(res)
feature_maps = range(2, max(self.feature_maps) + 1)
for i in feature_maps:
res = self.layer_warp(res, i)
if i in self.feature_maps:
res_endpoints.append(res)
if self.freeze_at >= i:
res.stop_gradient = True
if self.get_prediction:
pool = fluid.layers.pool2d(
input=res, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=self.class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
out = fluid.layers.softmax(out)
return out
return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
for idx, feat in enumerate(res_endpoints)])
class ResNetC5(ResNet):
def __init__(self,
depth=50,
freeze_at=2,
norm_type='affine_channel',
freeze_norm=True,
norm_decay=0.,
variant='b',
feature_maps=[5],
weight_prefix_name=''):
super(ResNetC5, self).__init__(depth, freeze_at, norm_type, freeze_norm,
norm_decay, variant, feature_maps)
self.severed_head = True
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="yolov3_resnet50_vd_coco2017")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('detection_result')
def test_object_detection1(self):
results = self.module.object_detection(
paths=['tests/test.jpg']
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'cat')
self.assertTrue(confidence > 0.5)
self.assertTrue(0 < left < 1000)
self.assertTrue(1000 < right < 3500)
self.assertTrue(500 < top < 1500)
self.assertTrue(1000 < bottom < 4500)
def test_object_detection2(self):
results = self.module.object_detection(
images=[cv2.imread('tests/test.jpg')]
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'cat')
self.assertTrue(confidence > 0.5)
self.assertTrue(0 < left < 1000)
self.assertTrue(1000 < right < 3500)
self.assertTrue(500 < top < 1500)
self.assertTrue(1000 < bottom < 4500)
def test_object_detection3(self):
results = self.module.object_detection(
images=[cv2.imread('tests/test.jpg')],
visualization=False
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'cat')
self.assertTrue(confidence > 0.5)
self.assertTrue(0 < left < 1000)
self.assertTrue(1000 < right < 3500)
self.assertTrue(500 < top < 1500)
self.assertTrue(1000 < bottom < 4500)
def test_object_detection4(self):
self.assertRaises(
AssertionError,
self.module.object_detection,
paths=['no.jpg']
)
def test_object_detection5(self):
self.assertRaises(
AttributeError,
self.module.object_detection,
images=['test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
\ No newline at end of file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
__all__ = ['MultiClassNMS', 'YOLOv3Head']
class MultiClassNMS(object):
# __op__ = fluid.layers.multiclass_nms
def __init__(self, background_label, keep_top_k, nms_threshold, nms_top_k,
normalized, score_threshold):
super(MultiClassNMS, self).__init__()
self.background_label = background_label
self.keep_top_k = keep_top_k
self.nms_threshold = nms_threshold
self.nms_top_k = nms_top_k
self.normalized = normalized
self.score_threshold = score_threshold
class YOLOv3Head(object):
"""Head block for YOLOv3 network
Args:
norm_decay (float): weight decay for normalization layer weights
num_classes (int): number of output classes
ignore_thresh (float): threshold to ignore confidence loss
label_smooth (bool): whether to use label smoothing
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
"""
def __init__(self,
norm_decay=0.,
num_classes=80,
ignore_thresh=0.7,
label_smooth=True,
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
nms=MultiClassNMS(
background_label=-1,
keep_top_k=100,
nms_threshold=0.45,
nms_top_k=1000,
normalized=True,
score_threshold=0.01),
weight_prefix_name=''):
self.norm_decay = norm_decay
self.num_classes = num_classes
self.ignore_thresh = ignore_thresh
self.label_smooth = label_smooth
self.anchor_masks = anchor_masks
self._parse_anchors(anchors)
self.nms = nms
self.prefix_name = weight_prefix_name
def _conv_bn(self,
input,
ch_out,
filter_size,
stride,
padding,
act='leaky',
is_test=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + ".conv.weights"),
bias_attr=False)
bn_name = name + ".bn"
bn_param_attr = ParamAttr(
regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
bn_bias_attr = ParamAttr(
regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
out = fluid.layers.batch_norm(
input=conv,
act=None,
is_test=is_test,
param_attr=bn_param_attr,
bias_attr=bn_bias_attr,
moving_mean_name=bn_name + '.mean',
moving_variance_name=bn_name + '.var')
if act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
def _detection_block(self, input, channel, is_test=True, name=None):
assert channel % 2 == 0, \
"channel {} cannot be divided by 2 in detection block {}" \
.format(channel, name)
conv = input
for j in range(2):
conv = self._conv_bn(
conv,
channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test,
name='{}.{}.0'.format(name, j))
conv = self._conv_bn(
conv,
channel * 2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test,
name='{}.{}.1'.format(name, j))
route = self._conv_bn(
conv,
channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test,
name='{}.2'.format(name))
tip = self._conv_bn(
route,
channel * 2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test,
name='{}.tip'.format(name))
return route, tip
def _upsample(self, input, scale=2, name=None):
out = fluid.layers.resize_nearest(
input=input, scale=float(scale), name=name)
return out
def _parse_anchors(self, anchors):
"""
Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors
"""
self.anchors = []
self.mask_anchors = []
assert len(anchors) > 0, "ANCHORS not set."
assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
for anchor in anchors:
assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
self.anchors.extend(anchor)
anchor_num = len(anchors)
for masks in self.anchor_masks:
self.mask_anchors.append([])
for mask in masks:
assert mask < anchor_num, "anchor mask index overflow"
self.mask_anchors[-1].extend(anchors[mask])
def _get_outputs(self, input, is_train=True):
"""
Get YOLOv3 head output
Args:
input (list): List of Variables, output of backbone stages
is_train (bool): whether in train or test mode
Returns:
outputs (list): Variables of each output layer
"""
outputs = []
# get last out_layer_num blocks in reverse order
out_layer_num = len(self.anchor_masks)
if isinstance(input, OrderedDict):
blocks = list(input.values())[-1:-out_layer_num - 1:-1]
else:
blocks = input[-1:-out_layer_num - 1:-1]
route = None
for i, block in enumerate(blocks):
if i > 0: # perform concat in first 2 detection_block
block = fluid.layers.concat(input=[route, block], axis=1)
route, tip = self._detection_block(
block,
channel=512 // (2**i),
is_test=(not is_train),
name=self.prefix_name + "yolo_block.{}".format(i))
# out channel number = mask_num * (5 + class_num)
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
block_out = fluid.layers.conv2d(
input=tip,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name=self.prefix_name +
"yolo_output.{}.conv.weights".format(i)),
bias_attr=ParamAttr(
regularizer=L2Decay(0.),
name=self.prefix_name +
"yolo_output.{}.conv.bias".format(i)))
outputs.append(block_out)
if i < len(blocks) - 1:
# do not perform upsample in the last detection_block
route = self._conv_bn(
input=route,
ch_out=256 // (2**i),
filter_size=1,
stride=1,
padding=0,
is_test=(not is_train),
name=self.prefix_name + "yolo_transition.{}".format(i))
# upsample
route = self._upsample(route)
return outputs, blocks
def get_prediction(self, outputs, im_size):
"""
Get prediction result of YOLOv3 network
Args:
outputs (list): list of Variables, return from _get_outputs
im_size (Variable): Variable of size([h, w]) of each image
Returns:
pred (Variable): The prediction result after non-max suppress.
"""
boxes = []
scores = []
downsample = 32
for i, output in enumerate(outputs):
box, score = fluid.layers.yolo_box(
x=output,
img_size=im_size,
anchors=self.mask_anchors[i],
class_num=self.num_classes,
conf_thresh=self.nms.score_threshold,
downsample_ratio=downsample,
name=self.prefix_name + "yolo_box" + str(i))
boxes.append(box)
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
downsample //= 2
yolo_boxes = fluid.layers.concat(boxes, axis=1)
yolo_scores = fluid.layers.concat(scores, axis=2)
pred = fluid.layers.multiclass_nms(
bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=self.nms.score_threshold,
nms_top_k=self.nms.nms_top_k,
keep_top_k=self.nms.keep_top_k,
nms_threshold=self.nms.nms_threshold,
background_label=self.nms.background_label,
normalized=self.nms.normalized,
name="multiclass_nms")
return pred
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册