未验证 提交 3d3ca657 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update yolov3_darknet53_coco2017 (#1951)

* update ssd_vgg16_512_coco2017

* update unittest

* update unittest

* update version

* update gpu config

* update

* add clean func

* update save inference model
Co-authored-by: Nchenjian <chenjian26@baidu.com>
上级 0ed21a48
...@@ -100,19 +100,13 @@ ...@@ -100,19 +100,13 @@
- save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在) - save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在)
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- 将模型保存到指定路径。 - 将模型保存到指定路径。
- **参数** - **参数**
- dirname: 存在模型的目录名称; <br/> - dirname: 模型保存路径 <br/>
- model\_filename: 模型文件名称,默认为\_\_model\_\_; <br/>
- params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效);<br/>
- combined: 是否将参数保存到统一的一个文件中。
## 四、服务部署 ## 四、服务部署
...@@ -165,6 +159,9 @@ ...@@ -165,6 +159,9 @@
* 1.1.1 * 1.1.1
修复numpy数据读取问题 修复numpy数据读取问题
* 1.2.0
移除 fluid api
- ```shell - ```shell
$ hub install yolov3_darknet53_coco2017==1.1.1 $ hub install yolov3_darknet53_coco2017==1.2.0
``` ```
...@@ -99,19 +99,13 @@ ...@@ -99,19 +99,13 @@
- save\_path (str, optional): output path for saving results - save\_path (str, optional): output path for saving results
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- Save model to specific path - Save model to specific path
- **Parameters** - **Parameters**
- dirname: output dir for saving model - dirname: model save path
- model\_filename: filename for saving model
- params\_filename: filename for saving parameters
- combined: whether save parameters into one file
## IV.Server Deployment ## IV.Server Deployment
...@@ -164,6 +158,9 @@ ...@@ -164,6 +158,9 @@
* 1.1.1 * 1.1.1
Fix the problem of reading numpy Fix the problem of reading numpy
* 1.2.0
Remove fluid api
- ```shell - ```shell
$ hub install yolov3_darknet53_coco2017==1.1.1 $ hub install yolov3_darknet53_coco2017==1.2.0
``` ```
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import math
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
__all__ = ['DarkNet']
class DarkNet(object):
    """DarkNet backbone assembled from ``fluid.layers`` operators.

    Reference: https://pjreddie.com/darknet/yolo/

    Args:
        depth (int): network depth; 53 is the only accepted value.
        norm_type (str): normalization type, 'bn' or 'sync_bn'. The value is
            stored on the instance; the methods of this class do not read it.
        norm_decay (float): L2 weight decay for the batch-norm scale/offset.
        weight_prefix_name (str): prefix for the conv/batch-norm parameter
            names created by ``__call__``.
        get_prediction (bool): when True, ``__call__`` returns class
            probabilities instead of the list of stage outputs.
        class_dim (int): number of classes of the classification output.
    """

    def __init__(self,
                 depth=53,
                 norm_type='sync_bn',
                 norm_decay=0.,
                 weight_prefix_name='',
                 get_prediction=False,
                 class_dim=1000):
        assert depth in [53], "unsupported depth value"
        self.depth = depth
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        # depth -> (number of residual blocks per stage, block builder)
        self.depth_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)}
        self.prefix_name = weight_prefix_name
        self.class_dim = class_dim
        self.get_prediction = get_prediction

    def _conv_norm(self, input, ch_out, filter_size, stride, padding, act='leaky', name=None):
        """Bias-free conv2d followed by batch norm and an optional leaky ReLU.

        ``name`` is the prefix of every parameter name created here, so it
        must be a string even though the default is None.
        """
        features = fluid.layers.conv2d(
            input=input,
            num_filters=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            act=None,
            param_attr=ParamAttr(name=name + ".conv.weights"),
            bias_attr=False)

        bn_prefix = name + ".bn"
        scale_attr = ParamAttr(regularizer=L2Decay(float(self.norm_decay)), name=bn_prefix + '.scale')
        offset_attr = ParamAttr(regularizer=L2Decay(float(self.norm_decay)), name=bn_prefix + '.offset')
        normed = fluid.layers.batch_norm(
            input=features,
            act=None,
            param_attr=scale_attr,
            bias_attr=offset_attr,
            moving_mean_name=bn_prefix + '.mean',
            moving_variance_name=bn_prefix + '.var')

        if act != 'leaky':
            return normed
        # The activation is applied as a separate op because the 0.1 slope
        # cannot be passed through the `act` argument of batch_norm.
        return fluid.layers.leaky_relu(x=normed, alpha=0.1)

    def _downsample(self, input, ch_out, filter_size=3, stride=2, padding=1, name=None):
        """Conv/norm unit whose defaults (3x3, stride 2) reduce resolution."""
        return self._conv_norm(
            input,
            ch_out=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            name=name)

    def basicblock(self, input, ch_out, name=None):
        """Residual block: 1x1 conv, 3x3 conv with ``2 * ch_out`` filters, skip add."""
        squeezed = self._conv_norm(input, ch_out=ch_out, filter_size=1, stride=1, padding=0, name=name + ".0")
        expanded = self._conv_norm(
            squeezed, ch_out=ch_out * 2, filter_size=3, stride=1, padding=1, name=name + ".1")
        return fluid.layers.elementwise_add(x=input, y=expanded, act=None)

    def layer_warp(self, block_func, input, ch_out, count, name=None):
        """Chain ``block_func`` ``count`` times (at least once).

        The blocks are named ``<name>.0``, ``<name>.1``, ... in call order.
        """
        result = block_func(input, ch_out=ch_out, name='{}.0'.format(name))
        for index in range(1, count):
            result = block_func(result, ch_out=ch_out, name='{}.{}'.format(name, index))
        return result

    def __call__(self, input):
        """
        Build the backbone on ``input``.

        Returns the list with the output of each of the 5 stages, or the
        softmax classification output when ``get_prediction`` is set.
        """
        stage_counts, block_func = self.depth_cfg[self.depth]
        stage_counts = stage_counts[0:5]
        last_stage = len(stage_counts) - 1

        stem = self._conv_norm(
            input=input, ch_out=32, filter_size=3, stride=1, padding=1, name=self.prefix_name + "yolo_input")
        stage_input = self._downsample(
            input=stem, ch_out=stem.shape[1] * 2, name=self.prefix_name + "yolo_input.downsample")

        blocks = []
        for i, count in enumerate(stage_counts):
            block = self.layer_warp(
                block_func=block_func,
                input=stage_input,
                ch_out=32 * 2**i,
                count=count,
                name=self.prefix_name + "stage.{}".format(i))
            blocks.append(block)
            # No downsampling after the last stage.
            if i < last_stage:
                stage_input = self._downsample(
                    input=block,
                    ch_out=block.shape[1] * 2,
                    name=self.prefix_name + "stage.{}.downsample".format(i))

        if not self.get_prediction:
            return blocks

        # Classification branch: global average pool of the last stage + fc.
        pool = fluid.layers.pool2d(input=blocks[-1], pool_type='avg', global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        logits = fluid.layers.fc(
            input=pool,
            size=self.class_dim,
            param_attr=ParamAttr(initializer=fluid.initializer.Uniform(-stdv, stdv), name='fc_weights'),
            bias_attr=ParamAttr(name='fc_offset'))
        return fluid.layers.softmax(logits)
...@@ -6,29 +6,27 @@ import argparse ...@@ -6,29 +6,27 @@ import argparse
import os import os
from functools import partial from functools import partial
import paddle
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.jit
import paddlehub as hub import paddle.static
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix
from yolov3_darknet53_coco2017.darknet import DarkNet from .processor import load_label_info, postprocess, base64_to_cv2
from yolov3_darknet53_coco2017.processor import load_label_info, postprocess, base64_to_cv2 from .data_feed import reader
from yolov3_darknet53_coco2017.data_feed import reader
from yolov3_darknet53_coco2017.yolo_head import MultiClassNMS, YOLOv3Head
@moduleinfo( @moduleinfo(
name="yolov3_darknet53_coco2017", name="yolov3_darknet53_coco2017",
version="1.1.1", version="1.2.0",
type="CV/object_detection", type="CV/object_detection",
summary="Baidu's YOLOv3 model for object detection, with backbone DarkNet53, trained with dataset coco2017.", summary="Baidu's YOLOv3 model for object detection, with backbone DarkNet53, trained with dataset coco2017.",
author="paddlepaddle", author="paddlepaddle",
author_email="paddle-dev@baidu.com") author_email="paddle-dev@baidu.com")
class YOLOv3DarkNet53Coco2017(hub.Module): class YOLOv3DarkNet53Coco2017:
def _initialize(self): def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "yolov3_darknet53_model") self.default_pretrained_model_path = os.path.join(self.directory, "yolov3_darknet53_model", "model")
self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt")) self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt"))
self._set_config() self._set_config()
...@@ -36,11 +34,13 @@ class YOLOv3DarkNet53Coco2017(hub.Module): ...@@ -36,11 +34,13 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
""" """
predictor config setting. predictor config setting.
""" """
cpu_config = AnalysisConfig(self.default_pretrained_model_path) model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
cpu_config.switch_ir_optim(False) cpu_config.switch_ir_optim(False)
self.cpu_predictor = create_paddle_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
try: try:
_places = os.environ["CUDA_VISIBLE_DEVICES"] _places = os.environ["CUDA_VISIBLE_DEVICES"]
...@@ -49,88 +49,14 @@ class YOLOv3DarkNet53Coco2017(hub.Module): ...@@ -49,88 +49,14 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
except: except:
use_gpu = False use_gpu = False
if use_gpu: if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0) gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
def context(self, trainable=True, pretrained=True, get_prediction=False):
    """
    Distill the Head Features, so as to perform transfer learning.

    A new program is built with a DarkNet backbone and a YOLOv3 head, every
    variable in it is renamed with the ``@HUB_<module name>@`` prefix, and the
    parameters are either loaded from the bundled pretrained model or
    initialized by running the startup program.

    Args:
        trainable (bool): whether to set parameters trainable. It is also
            forwarded to the head as ``is_train``.
        pretrained (bool): whether to load default pretrained model.
        get_prediction (bool): whether to get prediction.

    Returns:
        inputs (dict): 'image' and 'im_size' mapped to the input variables.
        outputs (dict): 'bbox_out' when ``get_prediction`` is True, otherwise
            'head_features' and 'body_features'; each maps to a list of
            variables.
        context_prog (Program): the program to execute transfer learning.
    """
    context_prog = fluid.Program()
    startup_program = fluid.Program()
    # NOTE(review): the nesting below was reconstructed; confirm that the
    # whole body is meant to run inside both guards.
    with fluid.program_guard(context_prog, startup_program):
        with fluid.unique_name.guard():
            # Input image variable.
            image = fluid.layers.data(name='image', shape=[3, 608, 608], dtype='float32')
            # Backbone and its per-stage feature maps.
            backbone = DarkNet(norm_type='bn', norm_decay=0., depth=53)
            body_feats = backbone(image)
            # Second input, consumed by the head when predictions are requested.
            im_size = fluid.layers.data(name='im_size', shape=[2], dtype='int32')
            # Detection head on top of the backbone features.
            yolo_head = YOLOv3Head(num_classes=80)
            head_features, body_features = yolo_head._get_outputs(body_feats, is_train=trainable)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            # Prefix applied to every variable name of this module.
            var_prefix = '@HUB_{}@'.format(self.name)
            # Collect the *names* first; the variables themselves are looked
            # up after the programs have been renamed below.
            inputs = {'image': var_prefix + image.name, 'im_size': var_prefix + im_size.name}
            if get_prediction:
                bbox_out = yolo_head.get_prediction(head_features, im_size)
                outputs = {'bbox_out': [var_prefix + bbox_out.name]}
            else:
                outputs = {
                    'head_features': [var_prefix + var.name for var in head_features],
                    'body_features': [var_prefix + var.name for var in body_features]
                }
            # Rename the variables of both programs with the prefix.
            add_vars_prefix(context_prog, var_prefix)
            add_vars_prefix(fluid.default_startup_program(), var_prefix)
            # Resolve the prefixed names to the variables of context_prog.
            inputs = {key: context_prog.global_block().vars[value] for key, value in inputs.items()}
            outputs = {
                key: [context_prog.global_block().vars[varname] for varname in value]
                for key, value in outputs.items()
            }
            # Apply the requested trainable flag to every parameter.
            for param in context_prog.global_block().iter_parameters():
                param.trainable = trainable
            if pretrained:
                # Load only the variables that have a file of the same name
                # in the pretrained model directory.
                def _if_exist(var):
                    return os.path.exists(os.path.join(self.default_pretrained_model_path, var.name))
                fluid.io.load_vars(exe, self.default_pretrained_model_path, predicate=_if_exist)
            else:
                exe.run(startup_program)
            return inputs, outputs, context_prog
def object_detection(self, def object_detection(self,
paths=None, paths=None,
images=None, images=None,
data=None,
batch_size=1, batch_size=1,
use_gpu=False, use_gpu=False,
output_dir='detection_result', output_dir='detection_result',
...@@ -168,52 +94,34 @@ class YOLOv3DarkNet53Coco2017(hub.Module): ...@@ -168,52 +94,34 @@ class YOLOv3DarkNet53Coco2017(hub.Module):
) )
paths = paths if paths else list() paths = paths if paths else list()
if data and 'image' in data:
paths += data['image']
data_reader = partial(reader, paths, images) data_reader = partial(reader, paths, images)
batch_reader = fluid.io.batch(data_reader, batch_size=batch_size) batch_reader = paddle.batch(data_reader, batch_size=batch_size)
res = [] res = []
for iter_id, feed_data in enumerate(batch_reader()): for iter_id, feed_data in enumerate(batch_reader()):
feed_data = np.array(feed_data) feed_data = np.array(feed_data)
image_tensor = PaddleTensor(np.array(list(feed_data[:, 0])))
im_size_tensor = PaddleTensor(np.array(list(feed_data[:, 1])))
if use_gpu:
data_out = self.gpu_predictor.run([image_tensor, im_size_tensor])
else:
data_out = self.cpu_predictor.run([image_tensor, im_size_tensor])
output = postprocess( predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
paths=paths, input_names = predictor.get_input_names()
images=images, input_handle = predictor.get_input_handle(input_names[0])
data_out=data_out, input_handle.copy_from_cpu(np.array(list(feed_data[:, 0])))
score_thresh=score_thresh, input_handle = predictor.get_input_handle(input_names[1])
label_names=self.label_names, input_handle.copy_from_cpu(np.array(list(feed_data[:, 1])))
output_dir=output_dir,
handle_id=iter_id * batch_size, predictor.run()
visualization=visualization) output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output = postprocess(paths=paths,
images=images,
data_out=output_handle,
score_thresh=score_thresh,
label_names=self.label_names,
output_dir=output_dir,
handle_id=iter_id * batch_size,
visualization=visualization)
res.extend(output) res.extend(output)
return res return res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
    """Re-export the bundled pretrained inference model into ``dirname``.

    Args:
        dirname (str): directory the model is written to.
        model_filename (str, optional): file name of the program; falls back
            to ``__model__`` when ``combined`` is True and no name is given.
        params_filename (str, optional): file name of the parameters; falls
            back to ``__params__`` when ``combined`` is True and no name is
            given.
        combined (bool): when True, the two default file names above are
            filled in for any name left empty.
    """
    if combined:
        model_filename = model_filename or "__model__"
        params_filename = params_filename or "__params__"

    executor = fluid.Executor(fluid.CPUPlace())
    loaded = fluid.io.load_inference_model(dirname=self.default_pretrained_model_path, executor=executor)
    program, feed_names, fetch_targets = loaded

    fluid.io.save_inference_model(
        dirname=dirname,
        main_program=program,
        executor=executor,
        feeded_var_names=feed_names,
        target_vars=fetch_targets,
        model_filename=model_filename,
        params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
""" """
......
...@@ -88,7 +88,7 @@ def load_label_info(file_path): ...@@ -88,7 +88,7 @@ def load_label_info(file_path):
def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, handle_id, visualization=True): def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, handle_id, visualization=True):
""" """
postprocess the lod_tensor produced by fluid.Executor.run postprocess the lod_tensor produced by Executor.run
Args: Args:
paths (list[str]): The paths of images. paths (list[str]): The paths of images.
...@@ -113,9 +113,8 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, ...@@ -113,9 +113,8 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir,
confidence (float): The confidence of detection result. confidence (float): The confidence of detection result.
save_path (str): The path to save output images. save_path (str): The path to save output images.
""" """
lod_tensor = data_out[0] lod = data_out.lod()[0]
lod = lod_tensor.lod[0] results = data_out.copy_to_cpu()
results = lod_tensor.as_ndarray()
check_dir(output_dir) check_dir(output_dir)
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
class TestHubModule(unittest.TestCase):
    """End-to-end checks for the ``yolov3_darknet53_coco2017`` hub module.

    The suite downloads one sample image into ``tests/`` and loads the module
    once; all working directories are removed again when the class finishes.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Fetch the sample image and load the module under test."""
        img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
        os.makedirs('tests', exist_ok=True)
        # A timeout keeps a stalled connection from hanging the whole run.
        response = requests.get(img_url, timeout=60)
        assert response.status_code == 200, 'Network Error.'
        with open('tests/test.jpg', 'wb') as f:
            f.write(response.content)
        cls.module = hub.Module(name="yolov3_darknet53_coco2017")

    @classmethod
    def tearDownClass(cls) -> None:
        """Remove every directory the tests may have created.

        Some directories only exist if the corresponding test ran far enough
        to create them, so missing ones are ignored instead of raising and
        hiding the real test outcome.
        """
        for path in ('tests', 'inference', 'detection_result'):
            shutil.rmtree(path, ignore_errors=True)

    def _assert_cat_detected(self, results):
        """Check that the first box of the first result is the expected cat."""
        bbox = results[0]['data'][0]

        self.assertEqual(bbox['label'], 'cat')
        self.assertTrue(bbox['confidence'] > 0.5)
        self.assertTrue(0 < bbox['left'] < 1000)
        self.assertTrue(2500 < bbox['right'] < 3500)
        self.assertTrue(500 < bbox['top'] < 1500)
        self.assertTrue(3500 < bbox['bottom'] < 4500)

    def test_object_detection1(self):
        """Detection from an image path."""
        results = self.module.object_detection(
            paths=['tests/test.jpg']
        )
        self._assert_cat_detected(results)

    def test_object_detection2(self):
        """Detection from an in-memory image."""
        results = self.module.object_detection(
            images=[cv2.imread('tests/test.jpg')]
        )
        self._assert_cat_detected(results)

    def test_object_detection3(self):
        """Detection from an in-memory image with visualization disabled."""
        results = self.module.object_detection(
            images=[cv2.imread('tests/test.jpg')],
            visualization=False
        )
        self._assert_cat_detected(results)

    def test_object_detection4(self):
        """A path that does not exist is rejected with AssertionError."""
        self.assertRaises(
            AssertionError,
            self.module.object_detection,
            paths=['no.jpg']
        )

    def test_object_detection5(self):
        """Passing strings as ``images`` is rejected with AttributeError."""
        self.assertRaises(
            AttributeError,
            self.module.object_detection,
            images=['test.jpg']
        )

    def test_save_inference_model(self):
        """Exporting writes both the program and the parameter file."""
        self.module.save_inference_model('./inference/model')

        self.assertTrue(os.path.exists('./inference/model.pdmodel'))
        self.assertTrue(os.path.exists('./inference/model.pdiparams'))
# Run the suite only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
\ No newline at end of file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
__all__ = ['MultiClassNMS', 'YOLOv3Head']
class MultiClassNMS(object):
    """Plain holder for the settings of a multi-class NMS step.

    Each constructor argument is stored unchanged under an attribute of the
    same name.
    """

    def __init__(self, background_label, keep_top_k, nms_threshold, nms_top_k, normalized, score_threshold):
        super(MultiClassNMS, self).__init__()
        settings = (
            ('background_label', background_label),
            ('keep_top_k', keep_top_k),
            ('nms_threshold', nms_threshold),
            ('nms_top_k', nms_top_k),
            ('normalized', normalized),
            ('score_threshold', score_threshold),
        )
        for attribute, value in settings:
            setattr(self, attribute, value)
class YOLOv3Head(object):
    """Head block for YOLOv3 network

    Args:
        norm_decay (float): weight decay for normalization layer weights
        num_classes (int): number of output classes
        ignore_thresh (float): threshold to ignore confidence loss
        label_smooth (bool): whether to use label smoothing
        anchors (list): list of 2-element anchors; ``None`` selects the nine
            built-in anchors
        anchor_masks (list): one list of anchor indices per output layer;
            ``None`` selects the built-in masks
        nms (object): an instance of `MultiClassNMS`; ``None`` builds one
            with the built-in settings
        weight_prefix_name (str): prefix of the parameter and op names
            created by this head
    """

    # Built-in defaults are kept immutable and copied per instance, so that
    # changing one head's attributes can never leak into another head.
    _DEFAULT_ANCHORS = ((10, 13), (16, 30), (33, 23), (30, 61), (62, 45), (59, 119), (116, 90), (156, 198),
                        (373, 326))
    _DEFAULT_ANCHOR_MASKS = ((6, 7, 8), (3, 4, 5), (0, 1, 2))

    def __init__(self,
                 norm_decay=0.,
                 num_classes=80,
                 ignore_thresh=0.7,
                 label_smooth=True,
                 anchors=None,
                 anchor_masks=None,
                 nms=None,
                 weight_prefix_name=''):
        if anchors is None:
            anchors = [list(anchor) for anchor in self._DEFAULT_ANCHORS]
        if anchor_masks is None:
            anchor_masks = [list(masks) for masks in self._DEFAULT_ANCHOR_MASKS]
        if nms is None:
            nms = MultiClassNMS(
                background_label=-1,
                keep_top_k=100,
                nms_threshold=0.45,
                nms_top_k=1000,
                normalized=True,
                score_threshold=0.01)

        self.norm_decay = norm_decay
        self.num_classes = num_classes
        self.ignore_thresh = ignore_thresh
        self.label_smooth = label_smooth
        # anchor_masks must be set before _parse_anchors, which reads it.
        self.anchor_masks = anchor_masks
        self._parse_anchors(anchors)
        self.nms = nms
        self.prefix_name = weight_prefix_name

    def _conv_bn(self, input, ch_out, filter_size, stride, padding, act='leaky', is_test=True, name=None):
        """Bias-free conv2d followed by batch norm and an optional leaky ReLU.

        ``name`` is the prefix of every parameter name created here, so it
        must be a string even though the default is None.
        """
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            act=None,
            param_attr=ParamAttr(name=name + ".conv.weights"),
            bias_attr=False)

        bn_name = name + ".bn"
        bn_param_attr = ParamAttr(regularizer=L2Decay(float(self.norm_decay)), name=bn_name + '.scale')
        bn_bias_attr = ParamAttr(regularizer=L2Decay(float(self.norm_decay)), name=bn_name + '.offset')
        out = fluid.layers.batch_norm(
            input=conv,
            act=None,
            is_test=is_test,
            param_attr=bn_param_attr,
            bias_attr=bn_bias_attr,
            moving_mean_name=bn_name + '.mean',
            moving_variance_name=bn_name + '.var')

        # The activation is a separate op with a fixed slope of 0.1.
        if act == 'leaky':
            out = fluid.layers.leaky_relu(x=out, alpha=0.1)
        return out

    def _detection_block(self, input, channel, is_test=True, name=None):
        """Build one detection block.

        Two (1x1, 3x3) conv pairs are followed by a 1x1 conv (``route``) and
        a 3x3 conv on top of it (``tip``).

        Returns:
            tuple: ``(route, tip)``.
        """
        assert channel % 2 == 0, \
            "channel {} cannot be divided by 2 in detection block {}" \
            .format(channel, name)

        conv = input
        for j in range(2):
            conv = self._conv_bn(
                conv, channel, filter_size=1, stride=1, padding=0, is_test=is_test, name='{}.{}.0'.format(name, j))
            conv = self._conv_bn(
                conv,
                channel * 2,
                filter_size=3,
                stride=1,
                padding=1,
                is_test=is_test,
                name='{}.{}.1'.format(name, j))
        route = self._conv_bn(
            conv, channel, filter_size=1, stride=1, padding=0, is_test=is_test, name='{}.2'.format(name))
        tip = self._conv_bn(
            route, channel * 2, filter_size=3, stride=1, padding=1, is_test=is_test, name='{}.tip'.format(name))
        return route, tip

    def _upsample(self, input, scale=2, name=None):
        """Nearest-neighbour resize of ``input`` by ``scale``."""
        out = fluid.layers.resize_nearest(input=input, scale=float(scale), name=name)
        return out

    def _parse_anchors(self, anchors):
        """
        Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors

        Sets ``self.anchors`` to the flattened anchor values and
        ``self.mask_anchors`` to one flattened anchor list per mask group.

        Raises:
            AssertionError: if anchors or masks are empty, an anchor does not
                have two elements, or a mask index is not below the number
                of anchors.
        """
        self.anchors = []
        self.mask_anchors = []

        assert len(anchors) > 0, "ANCHORS not set."
        assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set."

        for anchor in anchors:
            assert len(anchor) == 2, "anchor {} len should be 2".format(anchor)
            self.anchors.extend(anchor)

        anchor_num = len(anchors)
        for masks in self.anchor_masks:
            self.mask_anchors.append([])
            for mask in masks:
                assert mask < anchor_num, "anchor mask index overflow"
                self.mask_anchors[-1].extend(anchors[mask])

    def _get_outputs(self, input, is_train=True):
        """
        Get YOLOv3 head output

        Args:
            input (list): List of Variables, output of backbone stages
            is_train (bool): whether in train or test mode

        Returns:
            outputs (list): Variables of each output layer
            blocks (list): the backbone outputs that were consumed, last
                stage first
        """
        outputs = []

        # get last out_layer_num blocks in reverse order
        out_layer_num = len(self.anchor_masks)
        if isinstance(input, OrderedDict):
            blocks = list(input.values())[-1:-out_layer_num - 1:-1]
        else:
            blocks = input[-1:-out_layer_num - 1:-1]

        route = None
        for i, block in enumerate(blocks):
            if i > 0:
                # From the second block on, the upsampled route of the
                # previous block is concatenated with the backbone output.
                block = fluid.layers.concat(input=[route, block], axis=1)
            route, tip = self._detection_block(
                block,
                channel=512 // (2**i),
                is_test=(not is_train),
                name=self.prefix_name + "yolo_block.{}".format(i))

            # out channel number = mask_num * (5 + class_num)
            num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5)
            block_out = fluid.layers.conv2d(
                input=tip,
                num_filters=num_filters,
                filter_size=1,
                stride=1,
                padding=0,
                act=None,
                param_attr=ParamAttr(name=self.prefix_name + "yolo_output.{}.conv.weights".format(i)),
                bias_attr=ParamAttr(
                    regularizer=L2Decay(0.), name=self.prefix_name + "yolo_output.{}.conv.bias".format(i)))
            outputs.append(block_out)

            if i < len(blocks) - 1:
                # do not perform upsample in the last detection_block
                route = self._conv_bn(
                    input=route,
                    ch_out=256 // (2**i),
                    filter_size=1,
                    stride=1,
                    padding=0,
                    is_test=(not is_train),
                    name=self.prefix_name + "yolo_transition.{}".format(i))
                # upsample
                route = self._upsample(route)

        return outputs, blocks

    def get_prediction(self, outputs, im_size):
        """
        Get prediction result of YOLOv3 network

        Args:
            outputs (list): list of Variables, return from _get_outputs
            im_size (Variable): Variable of size([h, w]) of each image

        Returns:
            pred (Variable): The prediction result after non-max suppress.
        """
        boxes = []
        scores = []
        # The first output uses a downsample ratio of 32; the ratio is
        # halved for each following output.
        downsample = 32
        for i, output in enumerate(outputs):
            box, score = fluid.layers.yolo_box(
                x=output,
                img_size=im_size,
                anchors=self.mask_anchors[i],
                class_num=self.num_classes,
                conf_thresh=self.nms.score_threshold,
                downsample_ratio=downsample,
                name=self.prefix_name + "yolo_box" + str(i))
            boxes.append(box)
            scores.append(fluid.layers.transpose(score, perm=[0, 2, 1]))

            downsample //= 2

        yolo_boxes = fluid.layers.concat(boxes, axis=1)
        yolo_scores = fluid.layers.concat(scores, axis=2)
        pred = fluid.layers.multiclass_nms(
            bboxes=yolo_boxes,
            scores=yolo_scores,
            score_threshold=self.nms.score_threshold,
            nms_top_k=self.nms.nms_top_k,
            keep_top_k=self.nms.keep_top_k,
            nms_threshold=self.nms.nms_threshold,
            background_label=self.nms.background_label,
            normalized=self.nms.normalized,
            name="multiclass_nms")
        return pred
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册