未验证 提交 9d700dd6 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update yolov3_resnet50_vd_coco2017 (#1954)

* update yolov3_resnet50_vd_coco2017

* update unittest

* update gpu config

* update

* add clean func

* update save inference model
上级 12a6bf92
...@@ -100,20 +100,13 @@ ...@@ -100,20 +100,13 @@
- save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在) - save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在)
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- 将模型保存到指定路径。 - 将模型保存到指定路径。
- **参数** - **参数**
- dirname: 存在模型的目录名称; <br/> - dirname: 模型保存路径 <br/>
- model\_filename: 模型文件名称,默认为\_\_model\_\_; <br/>
- params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效);<br/>
- combined: 是否将参数保存到统一的一个文件中。
## 四、服务部署 ## 四、服务部署
...@@ -166,6 +159,10 @@ ...@@ -166,6 +159,10 @@
修复numpy数据读取问题 修复numpy数据读取问题
* 1.1.0
移除 fluid api
- ```shell - ```shell
$ hub install yolov3_resnet50_vd_coco2017==1.0.2 $ hub install yolov3_resnet50_vd_coco2017==1.1.0
``` ```
...@@ -99,19 +99,13 @@ ...@@ -99,19 +99,13 @@
- save\_path (str, optional): output path for saving results - save\_path (str, optional): output path for saving results
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- Save model to specific path - Save model to specific path
- **Parameters** - **Parameters**
- dirname: output dir for saving model - dirname: save model path
- model\_filename: filename for saving model
- params\_filename: filename for saving parameters
- combined: whether save parameters into one file
## IV.Server Deployment ## IV.Server Deployment
...@@ -165,6 +159,10 @@ ...@@ -165,6 +159,10 @@
Fix the problem of reading numpy Fix the problem of reading numpy
* 1.1.0
Remove fluid api
- ```shell - ```shell
$ hub install yolov3_resnet50_vd_coco2017==1.0.2 $ hub install yolov3_resnet50_vd_coco2017==1.1.0
``` ```
...@@ -6,31 +6,28 @@ import argparse ...@@ -6,31 +6,28 @@ import argparse
import os import os
from functools import partial from functools import partial
import paddle
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.static
import paddlehub as hub from paddle.inference import Config, create_predictor
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix
from yolov3_resnet50_vd_coco2017.resnet import ResNet from .processor import load_label_info, postprocess, base64_to_cv2
from yolov3_resnet50_vd_coco2017.processor import load_label_info, postprocess, base64_to_cv2 from .data_feed import reader
from yolov3_resnet50_vd_coco2017.data_feed import reader
from yolov3_resnet50_vd_coco2017.yolo_head import MultiClassNMS, YOLOv3Head
@moduleinfo( @moduleinfo(
name="yolov3_resnet50_vd_coco2017", name="yolov3_resnet50_vd_coco2017",
version="1.0.2", version="1.1.0",
type="CV/object_detection", type="CV/object_detection",
summary= summary=
"Baidu's YOLOv3 model for object detection with backbone ResNet50, trained with dataset coco2017.", "Baidu's YOLOv3 model for object detection with backbone ResNet50, trained with dataset coco2017.",
author="paddlepaddle", author="paddlepaddle",
author_email="paddle-dev@baidu.com") author_email="paddle-dev@baidu.com")
class YOLOv3ResNet50Coco2017(hub.Module): class YOLOv3ResNet50Coco2017:
def _initialize(self): def __init__(self):
self.default_pretrained_model_path = os.path.join( self.default_pretrained_model_path = os.path.join(
self.directory, "yolov3_resnet50_model") self.directory, "yolov3_resnet50_model", "model")
self.label_names = load_label_info( self.label_names = load_label_info(
os.path.join(self.directory, "label_file.txt")) os.path.join(self.directory, "label_file.txt"))
self._set_config() self._set_config()
...@@ -39,11 +36,13 @@ class YOLOv3ResNet50Coco2017(hub.Module): ...@@ -39,11 +36,13 @@ class YOLOv3ResNet50Coco2017(hub.Module):
""" """
predictor config setting. predictor config setting.
""" """
cpu_config = AnalysisConfig(self.default_pretrained_model_path) model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
cpu_config.switch_ir_optim(False) cpu_config.switch_ir_optim(False)
self.cpu_predictor = create_paddle_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
try: try:
_places = os.environ["CUDA_VISIBLE_DEVICES"] _places = os.environ["CUDA_VISIBLE_DEVICES"]
...@@ -52,110 +51,10 @@ class YOLOv3ResNet50Coco2017(hub.Module): ...@@ -52,110 +51,10 @@ class YOLOv3ResNet50Coco2017(hub.Module):
except: except:
use_gpu = False use_gpu = False
if use_gpu: if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0) gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
def context(self, trainable=True, pretrained=True, get_prediction=False):
"""
Distill the Head Features, so as to perform transfer learning.
Args:
trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model.
get_prediction (bool): whether to get prediction.
Returns:
inputs(dict): the input variables.
outputs(dict): the output variables.
context_prog (Program): the program to execute transfer learning.
"""
context_prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
# image
image = fluid.layers.data(
name='image', shape=[3, 608, 608], dtype='float32')
# backbone
backbone = ResNet(
norm_type='sync_bn',
freeze_at=0,
freeze_norm=False,
norm_decay=0.,
dcn_v2_stages=[5],
depth=50,
variant='d',
feature_maps=[3, 4, 5])
# body_feats
body_feats = backbone(image)
# im_size
im_size = fluid.layers.data(
name='im_size', shape=[2], dtype='int32')
# yolo_head
yolo_head = YOLOv3Head(num_classes=80)
# head_features
head_features, body_features = yolo_head._get_outputs(
body_feats, is_train=trainable)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# var_prefix
var_prefix = '@HUB_{}@'.format(self.name)
# name of inputs
inputs = {
'image': var_prefix + image.name,
'im_size': var_prefix + im_size.name
}
# name of outputs
if get_prediction:
bbox_out = yolo_head.get_prediction(head_features, im_size)
outputs = {'bbox_out': [var_prefix + bbox_out.name]}
else:
outputs = {
'head_features':
[var_prefix + var.name for var in head_features],
'body_features':
[var_prefix + var.name for var in body_features]
}
# add_vars_prefix
add_vars_prefix(context_prog, var_prefix)
add_vars_prefix(fluid.default_startup_program(), var_prefix)
# inputs
inputs = {
key: context_prog.global_block().vars[value]
for key, value in inputs.items()
}
# outputs
outputs = {
key: [
context_prog.global_block().vars[varname]
for varname in value
]
for key, value in outputs.items()
}
# trainable
for param in context_prog.global_block().iter_parameters():
param.trainable = trainable
# pretrained
if pretrained:
def _if_exist(var):
return os.path.exists(
os.path.join(self.default_pretrained_model_path,
var.name))
fluid.io.load_vars(
exe,
self.default_pretrained_model_path,
predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog
def object_detection(self, def object_detection(self,
paths=None, paths=None,
...@@ -198,23 +97,25 @@ class YOLOv3ResNet50Coco2017(hub.Module): ...@@ -198,23 +97,25 @@ class YOLOv3ResNet50Coco2017(hub.Module):
paths = paths if paths else list() paths = paths if paths else list()
data_reader = partial(reader, paths, images) data_reader = partial(reader, paths, images)
batch_reader = fluid.io.batch(data_reader, batch_size=batch_size) batch_reader = paddle.batch(data_reader, batch_size=batch_size)
res = [] res = []
for iter_id, feed_data in enumerate(batch_reader()): for iter_id, feed_data in enumerate(batch_reader()):
feed_data = np.array(feed_data) feed_data = np.array(feed_data)
image_tensor = PaddleTensor(np.array(list(feed_data[:, 0])))
im_size_tensor = PaddleTensor(np.array(list(feed_data[:, 1])))
if use_gpu:
data_out = self.gpu_predictor.run(
[image_tensor, im_size_tensor])
else:
data_out = self.cpu_predictor.run(
[image_tensor, im_size_tensor])
output = postprocess( predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
paths=paths, input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(np.array(list(feed_data[:, 0])))
input_handle = predictor.get_input_handle(input_names[1])
input_handle.copy_from_cpu(np.array(list(feed_data[:, 1])))
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output = postprocess(paths=paths,
images=images, images=images,
data_out=data_out, data_out=output_handle,
score_thresh=score_thresh, score_thresh=score_thresh,
label_names=self.label_names, label_names=self.label_names,
output_dir=output_dir, output_dir=output_dir,
...@@ -223,29 +124,6 @@ class YOLOv3ResNet50Coco2017(hub.Module): ...@@ -223,29 +124,6 @@ class YOLOv3ResNet50Coco2017(hub.Module):
res.extend(output) res.extend(output)
return res return res
def save_inference_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
""" """
......
# coding=utf-8
class NameAdapter(object):
"""Fix the backbones variable names for pretrained weight"""
def __init__(self, model):
super(NameAdapter, self).__init__()
self.model = model
@property
def model_type(self):
return getattr(self.model, '_model_type', '')
@property
def variant(self):
return getattr(self.model, 'variant', '')
def fix_conv_norm_name(self, name):
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
# the naming rule is same as pretrained weight
if self.model_type == 'SEResNeXt':
bn_name = name + "_bn"
return bn_name
def fix_shortcut_name(self, name):
if self.model_type == 'SEResNeXt':
name = 'conv' + name + '_prj'
return name
def fix_bottleneck_name(self, name):
if self.model_type == 'SEResNeXt':
conv_name1 = 'conv' + name + '_x1'
conv_name2 = 'conv' + name + '_x2'
conv_name3 = 'conv' + name + '_x3'
shortcut_name = name
else:
conv_name1 = name + "_branch2a"
conv_name2 = name + "_branch2b"
conv_name3 = name + "_branch2c"
shortcut_name = name + "_branch1"
return conv_name1, conv_name2, conv_name3, shortcut_name
def fix_layer_warp_name(self, stage_num, count, i):
name = 'res' + str(stage_num)
if count > 10 and stage_num == 4:
if i == 0:
conv_name = name + "a"
else:
conv_name = name + "b" + str(i)
else:
conv_name = name + chr(ord("a") + i)
if self.model_type == 'SEResNeXt':
conv_name = str(stage_num + 2) + '_' + str(i + 1)
return conv_name
def fix_c1_stage_name(self):
return "res_conv1" if self.model_type == 'ResNeXt' else "conv1"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
nonlocal_params = {
"use_zero_init_conv": False,
"conv_init_std": 0.01,
"no_bias": True,
"use_maxpool": False,
"use_softmax": True,
"use_bn": False,
"use_scale": True, # vital for the model prformance!!!
"use_affine": False,
"bn_momentum": 0.9,
"bn_epsilon": 1.0000001e-5,
"bn_init_gamma": 0.9,
"weight_decay_bn": 1.e-4,
}
def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner,
max_pool_stride=2):
cur = input
theta = fluid.layers.conv2d(input = cur, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr=ParamAttr(name = prefix + '_theta' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_theta' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if not nonlocal_params["no_bias"] else False, \
name = prefix + '_theta')
theta_shape = theta.shape
theta_shape_op = fluid.layers.shape(theta)
theta_shape_op.stop_gradient = True
if nonlocal_params["use_maxpool"]:
max_pool = fluid.layers.pool2d(input = cur, \
pool_size = [max_pool_stride, max_pool_stride], \
pool_type = 'max', \
pool_stride = [max_pool_stride, max_pool_stride], \
pool_padding = [0, 0], \
name = prefix + '_pool')
else:
max_pool = cur
phi = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_phi' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_phi' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_phi')
phi_shape = phi.shape
g = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_g' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0, scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_g' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_g')
g_shape = g.shape
# we have to use explicit batch size (to support arbitrary spacetime size)
# e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784)
theta = fluid.layers.reshape(theta, shape=(0, 0, -1))
theta = fluid.layers.transpose(theta, [0, 2, 1])
phi = fluid.layers.reshape(phi, [0, 0, -1])
theta_phi = fluid.layers.matmul(theta, phi, name=prefix + '_affinity')
g = fluid.layers.reshape(g, [0, 0, -1])
if nonlocal_params["use_softmax"]:
if nonlocal_params["use_scale"]:
theta_phi_sc = fluid.layers.scale(theta_phi, scale=dim_inner**-.5)
else:
theta_phi_sc = theta_phi
p = fluid.layers.softmax(
theta_phi_sc, name=prefix + '_affinity' + '_prob')
else:
# not clear about what is doing in xlw's code
p = None # not implemented
raise "Not implemented when not use softmax"
# note g's axis[2] corresponds to p's axis[2]
# e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
p = fluid.layers.transpose(p, [0, 2, 1])
t = fluid.layers.matmul(g, p, name=prefix + '_y')
# reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
t_shape = t.shape
t_re = fluid.layers.reshape(
t, shape=list(theta_shape), actual_shape=theta_shape_op)
blob_out = t_re
blob_out = fluid.layers.conv2d(input = blob_out, num_filters = dim_out, \
filter_size = [1, 1], stride = [1, 1], padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_out' + "_w", \
initializer = fluid.initializer.Constant(value = 0.) \
if nonlocal_params["use_zero_init_conv"] \
else fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_out' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_out')
blob_out_shape = blob_out.shape
if nonlocal_params["use_bn"]:
bn_name = prefix + "_bn"
blob_out = fluid.layers.batch_norm(blob_out, \
# is_test = test_mode, \
momentum = nonlocal_params["bn_momentum"], \
epsilon = nonlocal_params["bn_epsilon"], \
name = bn_name, \
param_attr = ParamAttr(name = bn_name + "_s", \
initializer = fluid.initializer.Constant(value = nonlocal_params["bn_init_gamma"]), \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
bias_attr = ParamAttr(name = bn_name + "_b", \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
moving_mean_name = bn_name + "_rm", \
moving_variance_name = bn_name + "_riv") # add bn
if nonlocal_params["use_affine"]:
affine_scale = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_s'), \
default_initializer = fluid.initializer.Constant(value = 1.))
affine_bias = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_b'), \
default_initializer = fluid.initializer.Constant(value = 0.))
blob_out = fluid.layers.affine_channel(blob_out, scale = affine_scale, \
bias = affine_bias, name = prefix + '_affine') # add affine
return blob_out
def add_space_nonlocal(input, dim_in, dim_out, prefix, dim_inner):
'''
add_space_nonlocal:
Non-local Neural Networks: see https://arxiv.org/abs/1711.07971
'''
conv = space_nonlocal(input, dim_in, dim_out, prefix, dim_inner)
output = fluid.layers.elementwise_add(input, conv, name=prefix + '_sum')
return output
...@@ -101,7 +101,7 @@ def postprocess(paths, ...@@ -101,7 +101,7 @@ def postprocess(paths,
handle_id, handle_id,
visualization=True): visualization=True):
""" """
postprocess the lod_tensor produced by fluid.Executor.run postprocess the lod_tensor produced by Executor.run
Args: Args:
paths (list[str]): The paths of images. paths (list[str]): The paths of images.
...@@ -126,9 +126,8 @@ def postprocess(paths, ...@@ -126,9 +126,8 @@ def postprocess(paths,
confidence (float): The confidence of detection result. confidence (float): The confidence of detection result.
save_path (str): The path to save output images. save_path (str): The path to save output images.
""" """
lod_tensor = data_out[0] lod = data_out.lod()[0]
lod = lod_tensor.lod[0] results = data_out.copy_to_cpu()
results = lod_tensor.as_ndarray()
check_dir(output_dir) check_dir(output_dir)
......
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from collections import OrderedDict
from numbers import Integral
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant
from .nonlocal_helper import add_space_nonlocal
from .name_adapter import NameAdapter
__all__ = ['ResNet', 'ResNetC5']
class ResNet(object):
"""
Residual Network, see https://arxiv.org/abs/1512.03385
Args:
depth (int): ResNet depth, should be 34, 50.
freeze_at (int): freeze the backbone at which stage
norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel'
freeze_norm (bool): freeze normalization layers
norm_decay (float): weight decay for normalization layer weights
variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2
nonlocal_stages (list): index of stages who select nonlocal networks
"""
__shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name']
def __init__(self,
depth=50,
freeze_at=0,
norm_type='sync_bn',
freeze_norm=False,
norm_decay=0.,
variant='b',
feature_maps=[3, 4, 5],
dcn_v2_stages=[],
weight_prefix_name='',
nonlocal_stages=[],
get_prediction=False,
class_dim=1000):
super(ResNet, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
assert depth in [34, 50], \
"depth {} not in [34, 50]"
assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant"
assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
assert len(feature_maps) > 0, "need one or more feature maps"
assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert not (len(nonlocal_stages)>0 and depth<50), \
"non-local is not supported for resnet18 or resnet34"
self.depth = depth
self.freeze_at = freeze_at
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.variant = variant
self._model_type = 'ResNet'
self.feature_maps = feature_maps
self.dcn_v2_stages = dcn_v2_stages
self.depth_cfg = {
34: ([3, 4, 6, 3], self.basicblock),
50: ([3, 4, 6, 3], self.bottleneck),
}
self.stage_filters = [64, 128, 256, 512]
self._c1_out_chan_num = 64
self.na = NameAdapter(self)
self.prefix_name = weight_prefix_name
self.nonlocal_stages = nonlocal_stages
self.nonlocal_mod_cfg = {
50: 2,
101: 5,
152: 8,
200: 12,
}
self.get_prediction = get_prediction
self.class_dim = class_dim
def _conv_offset(self,
input,
filter_size,
stride,
padding,
act=None,
name=None):
out_channel = filter_size * filter_size * 3
out = fluid.layers.conv2d(
input,
num_filters=out_channel,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=ParamAttr(initializer=Constant(0.0), name=name + ".w_0"),
bias_attr=ParamAttr(initializer=Constant(0.0), name=name + ".b_0"),
act=act,
name=name)
return out
def _conv_norm(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None,
dcn_v2=False):
_name = self.prefix_name + name if self.prefix_name != '' else name
if not dcn_v2:
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=_name + "_weights"),
bias_attr=False,
name=_name + '.conv2d.output.1')
else:
# select deformable conv"
offset_mask = self._conv_offset(
input=input,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
act=None,
name=_name + "_conv_offset")
offset_channel = filter_size**2 * 2
mask_channel = filter_size**2
offset, mask = fluid.layers.split(
input=offset_mask,
num_or_sections=[offset_channel, mask_channel],
dim=1)
mask = fluid.layers.sigmoid(mask)
conv = fluid.layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
deformable_groups=1,
im2col_step=1,
param_attr=ParamAttr(name=_name + "_weights"),
bias_attr=False,
name=_name + ".conv2d.output.1")
bn_name = self.na.fix_conv_norm_name(name)
bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name
norm_lr = 0. if self.freeze_norm else 1.
norm_decay = self.norm_decay
pattr = ParamAttr(
name=bn_name + '_scale',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
battr = ParamAttr(
name=bn_name + '_offset',
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
if self.norm_type in ['bn', 'sync_bn']:
global_stats = True if self.freeze_norm else False
out = fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=pattr,
bias_attr=battr,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif self.norm_type == 'affine_channel':
scale = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=pattr,
default_initializer=fluid.initializer.Constant(1.))
bias = fluid.layers.create_parameter(
shape=[conv.shape[1]],
dtype=conv.dtype,
attr=battr,
default_initializer=fluid.initializer.Constant(0.))
out = fluid.layers.affine_channel(
x=conv, scale=scale, bias=bias, act=act)
if self.freeze_norm:
scale.stop_gradient = True
bias.stop_gradient = True
return out
def _shortcut(self, input, ch_out, stride, is_first, name):
max_pooling_in_short_cut = self.variant == 'd'
ch_in = input.shape[1]
# the naming rule is same as pretrained weight
name = self.na.fix_shortcut_name(name)
std_senet = getattr(self, 'std_senet', False)
if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first):
if std_senet:
if is_first:
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return self._conv_norm(input, ch_out, 3, stride, name=name)
if max_pooling_in_short_cut and not is_first:
input = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
ceil_mode=True,
pool_type='avg')
return self._conv_norm(input, ch_out, 1, 1, name=name)
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck(self,
input,
num_filters,
stride,
is_first,
name,
dcn_v2=False):
if self.variant == 'a':
stride1, stride2 = stride, 1
else:
stride1, stride2 = 1, stride
# ResNeXt
groups = getattr(self, 'groups', 1)
group_width = getattr(self, 'group_width', -1)
if groups == 1:
expand = 4
elif (groups * group_width) == 256:
expand = 1
else: # FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
num_filters = num_filters // 2
expand = 2
conv_name1, conv_name2, conv_name3, \
shortcut_name = self.na.fix_bottleneck_name(name)
std_senet = getattr(self, 'std_senet', False)
if std_senet:
conv_def = [[
int(num_filters / 2), 1, stride1, 'relu', 1, conv_name1
], [num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
else:
conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1],
[num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
residual = input
for i, (c, k, s, act, g, _name) in enumerate(conv_def):
residual = self._conv_norm(
input=residual,
num_filters=c,
filter_size=k,
stride=s,
act=act,
groups=g,
name=_name,
dcn_v2=(i == 1 and dcn_v2))
short = self._shortcut(
input,
num_filters * expand,
stride,
is_first=is_first,
name=shortcut_name)
# Squeeze-and-Excitation
if callable(getattr(self, '_squeeze_excitation', None)):
residual = self._squeeze_excitation(
input=residual, num_channels=num_filters, name='fc' + name)
return fluid.layers.elementwise_add(
x=short, y=residual, act='relu', name=name + ".add.output.5")
def basicblock(self,
input,
num_filters,
stride,
is_first,
name,
dcn_v2=False):
assert dcn_v2 is False, "Not implemented yet."
conv0 = self._conv_norm(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
conv1 = self._conv_norm(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
short = self._shortcut(
input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def layer_warp(self, input, stage_num):
"""
Args:
input (Variable): input variable.
stage_num (int): the stage number, should be 2, 3, 4, 5
Returns:
The last variable in endpoint-th stage.
"""
assert stage_num in [2, 3, 4, 5]
stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2]
ch_out = self.stage_filters[stage_num - 2]
is_first = False if stage_num != 2 else True
dcn_v2 = True if stage_num in self.dcn_v2_stages else False
nonlocal_mod = 1000
if stage_num in self.nonlocal_stages:
nonlocal_mod = self.nonlocal_mod_cfg[
self.depth] if stage_num == 4 else 2
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv = input
for i in range(count):
conv_name = self.na.fix_layer_warp_name(stage_num, count, i)
if self.depth < 50:
is_first = True if i == 0 and stage_num == 2 else False
conv = block_func(
input=conv,
num_filters=ch_out,
stride=2 if i == 0 and stage_num != 2 else 1,
is_first=is_first,
name=conv_name,
dcn_v2=dcn_v2)
# add non local model
dim_in = conv.shape[1]
nonlocal_name = "nonlocal_conv{}".format(stage_num)
if i % nonlocal_mod == nonlocal_mod - 1:
conv = add_space_nonlocal(conv, dim_in, dim_in,
nonlocal_name + '_{}'.format(i),
int(dim_in / 2))
return conv
def c1_stage(self, input):
out_chan = self._c1_out_chan_num
conv1_name = self.na.fix_c1_stage_name()
if self.variant in ['c', 'd']:
conv_def = [
[out_chan // 2, 3, 2, "conv1_1"],
[out_chan // 2, 3, 1, "conv1_2"],
[out_chan, 3, 1, "conv1_3"],
]
else:
conv_def = [[out_chan, 7, 2, conv1_name]]
for (c, k, s, _name) in conv_def:
input = self._conv_norm(
input=input,
num_filters=c,
filter_size=k,
stride=s,
act='relu',
name=_name)
output = fluid.layers.pool2d(
input=input,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
return output
def __call__(self, input):
assert isinstance(input, Variable)
assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
"feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
res_endpoints = []
res = input
feature_maps = self.feature_maps
severed_head = getattr(self, 'severed_head', False)
if not severed_head:
res = self.c1_stage(res)
feature_maps = range(2, max(self.feature_maps) + 1)
for i in feature_maps:
res = self.layer_warp(res, i)
if i in self.feature_maps:
res_endpoints.append(res)
if self.freeze_at >= i:
res.stop_gradient = True
if self.get_prediction:
pool = fluid.layers.pool2d(
input=res, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=self.class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
out = fluid.layers.softmax(out)
return out
return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
for idx, feat in enumerate(res_endpoints)])
class ResNetC5(ResNet):
def __init__(self,
depth=50,
freeze_at=2,
norm_type='affine_channel',
freeze_norm=True,
norm_decay=0.,
variant='b',
feature_maps=[5],
weight_prefix_name=''):
super(ResNetC5, self).__init__(depth, freeze_at, norm_type, freeze_norm,
norm_decay, variant, feature_maps)
self.severed_head = True
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="yolov3_resnet50_vd_coco2017")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('detection_result')
def test_object_detection1(self):
results = self.module.object_detection(
paths=['tests/test.jpg']
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'cat')
self.assertTrue(confidence > 0.5)
self.assertTrue(0 < left < 1000)
self.assertTrue(1000 < right < 3500)
self.assertTrue(500 < top < 1500)
self.assertTrue(1000 < bottom < 4500)
def test_object_detection2(self):
results = self.module.object_detection(
images=[cv2.imread('tests/test.jpg')]
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'cat')
self.assertTrue(confidence > 0.5)
self.assertTrue(0 < left < 1000)
self.assertTrue(1000 < right < 3500)
self.assertTrue(500 < top < 1500)
self.assertTrue(1000 < bottom < 4500)
def test_object_detection3(self):
results = self.module.object_detection(
images=[cv2.imread('tests/test.jpg')],
visualization=False
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'cat')
self.assertTrue(confidence > 0.5)
self.assertTrue(0 < left < 1000)
self.assertTrue(1000 < right < 3500)
self.assertTrue(500 < top < 1500)
self.assertTrue(1000 < bottom < 4500)
def test_object_detection4(self):
self.assertRaises(
AssertionError,
self.module.object_detection,
paths=['no.jpg']
)
def test_object_detection5(self):
self.assertRaises(
AttributeError,
self.module.object_detection,
images=['test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
\ No newline at end of file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
__all__ = ['MultiClassNMS', 'YOLOv3Head']
class MultiClassNMS(object):
# __op__ = fluid.layers.multiclass_nms
def __init__(self, background_label, keep_top_k, nms_threshold, nms_top_k,
normalized, score_threshold):
super(MultiClassNMS, self).__init__()
self.background_label = background_label
self.keep_top_k = keep_top_k
self.nms_threshold = nms_threshold
self.nms_top_k = nms_top_k
self.normalized = normalized
self.score_threshold = score_threshold
class YOLOv3Head(object):
"""Head block for YOLOv3 network
Args:
norm_decay (float): weight decay for normalization layer weights
num_classes (int): number of output classes
ignore_thresh (float): threshold to ignore confidence loss
label_smooth (bool): whether to use label smoothing
anchors (list): anchors
anchor_masks (list): anchor masks
nms (object): an instance of `MultiClassNMS`
"""
def __init__(self,
norm_decay=0.,
num_classes=80,
ignore_thresh=0.7,
label_smooth=True,
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
nms=MultiClassNMS(
background_label=-1,
keep_top_k=100,
nms_threshold=0.45,
nms_top_k=1000,
normalized=True,
score_threshold=0.01),
weight_prefix_name=''):
self.norm_decay = norm_decay
self.num_classes = num_classes
self.ignore_thresh = ignore_thresh
self.label_smooth = label_smooth
self.anchor_masks = anchor_masks
self._parse_anchors(anchors)
self.nms = nms
self.prefix_name = weight_prefix_name
def _conv_bn(self,
input,
ch_out,
filter_size,
stride,
padding,
act='leaky',
is_test=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(name=name + ".conv.weights"),
bias_attr=False)
bn_name = name + ".bn"
bn_param_attr = ParamAttr(
regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale')
bn_bias_attr = ParamAttr(
regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset')
out = fluid.layers.batch_norm(
input=conv,
act=None,
is_test=is_test,
param_attr=bn_param_attr,
bias_attr=bn_bias_attr,
moving_mean_name=bn_name + '.mean',
moving_variance_name=bn_name + '.var')
if act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
def _detection_block(self, input, channel, is_test=True, name=None):
assert channel % 2 == 0, \
"channel {} cannot be divided by 2 in detection block {}" \
.format(channel, name)
conv = input
for j in range(2):
conv = self._conv_bn(
conv,
channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test,
name='{}.{}.0'.format(name, j))
conv = self._conv_bn(
conv,
channel * 2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test,
name='{}.{}.1'.format(name, j))
route = self._conv_bn(
conv,
channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test,
name='{}.2'.format(name))
tip = self._conv_bn(
route,
channel * 2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test,
name='{}.tip'.format(name))
return route, tip
def _upsample(self, input, scale=2, name=None):
out = fluid.layers.resize_nearest(
input=input, scale=float(scale), name=name)
return out
def _parse_anchors(self, anchors):
"""
Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors
"""
self.anchors = []
self.mask_anchors = []
assert len(anchors) > 0, "ANCHORS not set."
assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."
for anchor in anchors:
assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
self.anchors.extend(anchor)
anchor_num = len(anchors)
for masks in self.anchor_masks:
self.mask_anchors.append([])
for mask in masks:
assert mask < anchor_num, "anchor mask index overflow"
self.mask_anchors[-1].extend(anchors[mask])
def _get_outputs(self, input, is_train=True):
"""
Get YOLOv3 head output
Args:
input (list): List of Variables, output of backbone stages
is_train (bool): whether in train or test mode
Returns:
outputs (list): Variables of each output layer
"""
outputs = []
# get last out_layer_num blocks in reverse order
out_layer_num = len(self.anchor_masks)
if isinstance(input, OrderedDict):
blocks = list(input.values())[-1:-out_layer_num - 1:-1]
else:
blocks = input[-1:-out_layer_num - 1:-1]
route = None
for i, block in enumerate(blocks):
if i > 0: # perform concat in first 2 detection_block
block = fluid.layers.concat(input=[route, block], axis=1)
route, tip = self._detection_block(
block,
channel=512 // (2**i),
is_test=(not is_train),
name=self.prefix_name + "yolo_block.{}".format(i))
# out channel number = mask_num * (5 + class_num)
num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
block_out = fluid.layers.conv2d(
input=tip,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name=self.prefix_name +
"yolo_output.{}.conv.weights".format(i)),
bias_attr=ParamAttr(
regularizer=L2Decay(0.),
name=self.prefix_name +
"yolo_output.{}.conv.bias".format(i)))
outputs.append(block_out)
if i < len(blocks) - 1:
# do not perform upsample in the last detection_block
route = self._conv_bn(
input=route,
ch_out=256 // (2**i),
filter_size=1,
stride=1,
padding=0,
is_test=(not is_train),
name=self.prefix_name + "yolo_transition.{}".format(i))
# upsample
route = self._upsample(route)
return outputs, blocks
def get_prediction(self, outputs, im_size):
"""
Get prediction result of YOLOv3 network
Args:
outputs (list): list of Variables, return from _get_outputs
im_size (Variable): Variable of size([h, w]) of each image
Returns:
pred (Variable): The prediction result after non-max suppress.
"""
boxes = []
scores = []
downsample = 32
for i, output in enumerate(outputs):
box, score = fluid.layers.yolo_box(
x=output,
img_size=im_size,
anchors=self.mask_anchors[i],
class_num=self.num_classes,
conf_thresh=self.nms.score_threshold,
downsample_ratio=downsample,
name=self.prefix_name + "yolo_box" + str(i))
boxes.append(box)
scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))
downsample //= 2
yolo_boxes = fluid.layers.concat(boxes, axis=1)
yolo_scores = fluid.layers.concat(scores, axis=2)
pred = fluid.layers.multiclass_nms(
bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=self.nms.score_threshold,
nms_top_k=self.nms.nms_top_k,
keep_top_k=self.nms.keep_top_k,
nms_threshold=self.nms.nms_threshold,
background_label=self.nms.background_label,
normalized=self.nms.normalized,
name="multiclass_nms")
return pred
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册