diff --git a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/README.md b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/README.md index 456de66bac3b6b2c59466eecb53ce86a09fa783b..2e5032d0b5d5744d8498d9fe767efb9034d0e025 100644 --- a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/README.md +++ b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/README.md @@ -101,19 +101,13 @@ - save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在) - ```python - def save_inference_model(dirname, - model_filename=None, - params_filename=None, - combined=True) + def save_inference_model(dirname) ``` - 将模型保存到指定路径。 - **参数** - - dirname: 存在模型的目录名称;
- - model\_filename: 模型文件名称,默认为\_\_model\_\_;
- - params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效);
- - combined: 是否将参数保存到统一的一个文件中。 + - dirname: 模型保存路径
## 四、服务部署 @@ -167,6 +161,10 @@ 修复numpy数据读取问题 +* 1.1.0 + + 移除 fluid api + - ```shell - $ hub install yolov3_mobilenet_v1_coco2017==1.0.2 + $ hub install yolov3_mobilenet_v1_coco2017==1.1.0 ``` diff --git a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/README_en.md b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/README_en.md index f80472bfa12c931152e617a965d8023b079c02da..08ecd92a907f9b7b7ed54b50884083db15d2f1fe 100644 --- a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/README_en.md +++ b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/README_en.md @@ -100,19 +100,13 @@ - save\_path (str, optional): output path for saving results - ```python - def save_inference_model(dirname, - model_filename=None, - params_filename=None, - combined=True) + def save_inference_model(dirname) ``` - Save model to specific path - **Parameters** - - dirname: output dir for saving model - - model\_filename: filename for saving model - - params\_filename: filename for saving parameters - - combined: whether save parameters into one file + - dirname: model save path ## IV.Server Deployment @@ -166,6 +160,10 @@ Fix the problem of reading numpy +* 1.1.0 + + Remove fluid api + - ```shell - $ hub install yolov3_mobilenet_v1_coco2017==1.0.2 + $ hub install yolov3_mobilenet_v1_coco2017==1.1.0 ``` diff --git a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/mobilenet_v1.py b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/mobilenet_v1.py deleted file mode 100644 index 05f64c9382b8630e41bac0546f67dcb83d7d4822..0000000000000000000000000000000000000000 --- a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/mobilenet_v1.py +++ /dev/null @@ -1,194 +0,0 @@ -# coding=utf-8 -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from paddle import fluid -from paddle.fluid.param_attr import ParamAttr -from paddle.fluid.regularizer import L2Decay - -__all__ = ['MobileNet'] - - -class MobileNet(object): - """ - MobileNet v1, see https://arxiv.org/abs/1704.04861 - - Args: - norm_type (str): normalization type, 'bn' and 'sync_bn' are supported - norm_decay (float): weight decay for normalization layer weights - conv_group_scale (int): scaling factor for convolution groups - with_extra_blocks (bool): if extra blocks should be added - extra_block_filters (list): number of filter for each extra block - """ - __shared__ = ['norm_type', 'weight_prefix_name'] - - def __init__(self, - norm_type='bn', - norm_decay=0., - conv_group_scale=1, - conv_learning_rate=1.0, - with_extra_blocks=False, - extra_block_filters=[[256, 512], [128, 256], [128, 256], - [64, 128]], - weight_prefix_name=''): - self.norm_type = norm_type - self.norm_decay = norm_decay - self.conv_group_scale = conv_group_scale - self.conv_learning_rate = conv_learning_rate - self.with_extra_blocks = with_extra_blocks - self.extra_block_filters = extra_block_filters - self.prefix_name = weight_prefix_name - - def _conv_norm(self, - input, - filter_size, - num_filters, - stride, - padding, - num_groups=1, - act='relu', - use_cudnn=True, - name=None): - parameter_attr = ParamAttr( - learning_rate=self.conv_learning_rate, - initializer=fluid.initializer.MSRA(), - name=name + "_weights") - conv = fluid.layers.conv2d( - input=input, - num_filters=num_filters, - filter_size=filter_size, - stride=stride, - padding=padding, - groups=num_groups, - act=None, - use_cudnn=use_cudnn, - param_attr=parameter_attr, - bias_attr=False) - - bn_name = name + "_bn" - norm_decay = self.norm_decay - bn_param_attr = ParamAttr( - regularizer=L2Decay(norm_decay), name=bn_name + '_scale') - bn_bias_attr = ParamAttr( - regularizer=L2Decay(norm_decay), name=bn_name + '_offset') - return fluid.layers.batch_norm( - input=conv, - act=act, - param_attr=bn_param_attr, - bias_attr=bn_bias_attr, - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance') - - def depthwise_separable(self, - input, - num_filters1, - num_filters2, - num_groups, - stride, - scale, - name=None): - depthwise_conv = self._conv_norm( - input=input, - filter_size=3, - num_filters=int(num_filters1 * scale), - stride=stride, - padding=1, - num_groups=int(num_groups * scale), - use_cudnn=False, - name=name + "_dw") - - pointwise_conv = self._conv_norm( - input=depthwise_conv, - filter_size=1, - num_filters=int(num_filters2 * scale), - stride=1, - padding=0, - name=name + "_sep") - return pointwise_conv - - def _extra_block(self, - input, - num_filters1, - num_filters2, - num_groups, - stride, - name=None): - pointwise_conv = self._conv_norm( - input=input, - filter_size=1, - num_filters=int(num_filters1), - stride=1, - num_groups=int(num_groups), - padding=0, - name=name + "_extra1") - normal_conv = self._conv_norm( - input=pointwise_conv, - filter_size=3, - num_filters=int(num_filters2), - stride=2, - num_groups=int(num_groups), - padding=1, - name=name + "_extra2") - return normal_conv - - def __call__(self, input): - scale = self.conv_group_scale - - blocks = [] - # input 1/1 - out = self._conv_norm( - input, 3, int(32 * scale), 2, 1, name=self.prefix_name + "conv1") - # 1/2 - out = self.depthwise_separable( - out, 32, 64, 32, 1, scale, name=self.prefix_name + "conv2_1") - out = self.depthwise_separable( - out, 64, 128, 64, 2, scale, name=self.prefix_name + "conv2_2") - # 1/4 - out = self.depthwise_separable( - out, 128, 128, 128, 1, scale, name=self.prefix_name + "conv3_1") - out = self.depthwise_separable( - out, 128, 256, 128, 2, scale, name=self.prefix_name + "conv3_2") - # 1/8 - blocks.append(out) - out = self.depthwise_separable( - out, 256, 256, 256, 1, scale, name=self.prefix_name + "conv4_1") - out = self.depthwise_separable( - out, 256, 512, 256, 2, scale, name=self.prefix_name + "conv4_2") - # 1/16 - blocks.append(out) - for i in range(5): - out = self.depthwise_separable( - out, - 512, - 512, - 512, - 1, - scale, - name=self.prefix_name + "conv5_" + str(i + 1)) - module11 = out - - out = self.depthwise_separable( - out, 512, 1024, 512, 2, scale, name=self.prefix_name + "conv5_6") - # 1/32 - out = self.depthwise_separable( - out, 1024, 1024, 1024, 1, scale, name=self.prefix_name + "conv6") - module13 = out - blocks.append(out) - if not self.with_extra_blocks: - return blocks - - num_filters = self.extra_block_filters - module14 = self._extra_block(module13, num_filters[0][0], - num_filters[0][1], 1, 2, - self.prefix_name + "conv7_1") - module15 = self._extra_block(module14, num_filters[1][0], - num_filters[1][1], 1, 2, - self.prefix_name + "conv7_2") - module16 = self._extra_block(module15, num_filters[2][0], - num_filters[2][1], 1, 2, - self.prefix_name + "conv7_3") - module17 = self._extra_block(module16, num_filters[3][0], - num_filters[3][1], 1, 2, - self.prefix_name + "conv7_4") - return module11, module13, module14, module15, module16, module17 diff --git a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/module.py b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/module.py index 98e1110a0bba80f6559d3653612cb49754179fb0..0a642907e1dc458787a6b80f709e5c053ed20fab 100644 --- a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/module.py +++ b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/module.py @@ -6,31 +6,29 @@ import argparse import os from functools import partial +import paddle import numpy as np -import paddle.fluid as fluid -import paddlehub as hub -from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor +import paddle.jit +import paddle.static +from paddle.inference import Config, create_predictor from paddlehub.module.module import moduleinfo, runnable, serving -from paddlehub.common.paddle_helper import add_vars_prefix -from yolov3_mobilenet_v1_coco2017.mobilenet_v1 import MobileNet -from yolov3_mobilenet_v1_coco2017.processor import load_label_info, postprocess, base64_to_cv2 -from yolov3_mobilenet_v1_coco2017.data_feed import reader -from yolov3_mobilenet_v1_coco2017.yolo_head import MultiClassNMS, YOLOv3Head +from .processor import load_label_info, postprocess, base64_to_cv2 +from .data_feed import reader @moduleinfo( name="yolov3_mobilenet_v1_coco2017", - version="1.0.2", + version="1.1.0", type="CV/object_detection", summary= "Baidu's YOLOv3 model for object detection with backbone MobileNet_V1, trained with dataset COCO2017.", author="paddlepaddle", author_email="paddle-dev@baidu.com") -class YOLOv3MobileNetV1Coco2017(hub.Module): - def _initialize(self): +class YOLOv3MobileNetV1Coco2017: + def __init__(self): self.default_pretrained_model_path = os.path.join( - self.directory, "yolov3_mobilenet_v1_model") + self.directory, "yolov3_mobilenet_v1_model", "model") self.label_names = load_label_info( os.path.join(self.directory, "label_file.txt")) self._set_config() @@ -39,11 +37,13 @@ class YOLOv3MobileNetV1Coco2017(hub.Module): """ predictor config setting. """ - cpu_config = AnalysisConfig(self.default_pretrained_model_path) + model = self.default_pretrained_model_path+'.pdmodel' + params = self.default_pretrained_model_path+'.pdiparams' + cpu_config = Config(model, params) cpu_config.disable_glog_info() cpu_config.disable_gpu() cpu_config.switch_ir_optim(False) - self.cpu_predictor = create_paddle_predictor(cpu_config) + self.cpu_predictor = create_predictor(cpu_config) try: _places = os.environ["CUDA_VISIBLE_DEVICES"] @@ -52,106 +52,10 @@ class YOLOv3MobileNetV1Coco2017(hub.Module): except: use_gpu = False if use_gpu: - gpu_config = AnalysisConfig(self.default_pretrained_model_path) + gpu_config = Config(model, params) gpu_config.disable_glog_info() gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0) - self.gpu_predictor = create_paddle_predictor(gpu_config) - - def context(self, trainable=True, pretrained=True, get_prediction=False): - """ - Distill the Head Features, so as to perform transfer learning. - - Args: - trainable (bool): whether to set parameters trainable. - pretrained (bool): whether to load default pretrained model. - get_prediction (bool): whether to get prediction. - - Returns: - inputs(dict): the input variables. - outputs(dict): the output variables. - context_prog (Program): the program to execute transfer learning. - """ - context_prog = fluid.Program() - startup_program = fluid.Program() - with fluid.program_guard(context_prog, startup_program): - with fluid.unique_name.guard(): - # image - image = fluid.layers.data( - name='image', shape=[3, 608, 608], dtype='float32') - # backbone - backbone = MobileNet( - norm_type='sync_bn', - norm_decay=0., - conv_group_scale=1, - with_extra_blocks=False) - # body_feats - body_feats = backbone(image) - # im_size - im_size = fluid.layers.data( - name='im_size', shape=[2], dtype='int32') - # yolo_head - yolo_head = YOLOv3Head(num_classes=80) - # head_features - head_features, body_features = yolo_head._get_outputs( - body_feats, is_train=trainable) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - exe.run(startup_program) - - # var_prefix - var_prefix = '@HUB_{}@'.format(self.name) - # name of inputs - inputs = { - 'image': var_prefix + image.name, - 'im_size': var_prefix + im_size.name - } - # name of outputs - if get_prediction: - bbox_out = yolo_head.get_prediction(head_features, im_size) - outputs = {'bbox_out': [var_prefix + bbox_out.name]} - else: - outputs = { - 'head_features': - [var_prefix + var.name for var in head_features], - 'body_features': - [var_prefix + var.name for var in body_features] - } - # add_vars_prefix - add_vars_prefix(context_prog, var_prefix) - add_vars_prefix(startup_program, var_prefix) - # inputs - inputs = { - key: context_prog.global_block().vars[value] - for key, value in inputs.items() - } - # outputs - outputs = { - key: [ - context_prog.global_block().vars[varname] - for varname in value - ] - for key, value in outputs.items() - } - # trainable - for param in context_prog.global_block().iter_parameters(): - param.trainable = trainable - # pretrained - if pretrained: - - def _if_exist(var): - return os.path.exists( - os.path.join(self.default_pretrained_model_path, - var.name)) - - fluid.io.load_vars( - exe, - self.default_pretrained_model_path, - predicate=_if_exist) - else: - exe.run(startup_program) - - return inputs, outputs, context_prog + self.gpu_predictor = create_predictor(gpu_config) def object_detection(self, paths=None, @@ -194,54 +98,33 @@ class YOLOv3MobileNetV1Coco2017(hub.Module): paths = paths if paths else list() data_reader = partial(reader, paths, images) - batch_reader = fluid.io.batch(data_reader, batch_size=batch_size) + batch_reader = paddle.batch(data_reader, batch_size=batch_size) res = [] for iter_id, feed_data in enumerate(batch_reader()): feed_data = np.array(feed_data) - image_tensor = PaddleTensor(np.array(list(feed_data[:, 0]))) - im_size_tensor = PaddleTensor(np.array(list(feed_data[:, 1]))) - if use_gpu: - data_out = self.gpu_predictor.run( - [image_tensor, im_size_tensor]) - else: - data_out = self.cpu_predictor.run( - [image_tensor, im_size_tensor]) - output = postprocess( - paths=paths, - images=images, - data_out=data_out, - score_thresh=score_thresh, - label_names=self.label_names, - output_dir=output_dir, - handle_id=iter_id * batch_size, - visualization=visualization) + predictor = self.gpu_predictor if use_gpu else self.cpu_predictor + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + input_handle.copy_from_cpu(np.array(list(feed_data[:, 0]))) + input_handle = predictor.get_input_handle(input_names[1]) + input_handle.copy_from_cpu(np.array(list(feed_data[:, 1]))) + + predictor.run() + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + + output = postprocess(paths=paths, + images=images, + data_out=output_handle, + score_thresh=score_thresh, + label_names=self.label_names, + output_dir=output_dir, + handle_id=iter_id * batch_size, + visualization=visualization) res.extend(output) return res - def save_inference_model(self, - dirname, - model_filename=None, - params_filename=None, - combined=True): - if combined: - model_filename = "__model__" if not model_filename else model_filename - params_filename = "__params__" if not params_filename else params_filename - place = fluid.CPUPlace() - exe = fluid.Executor(place) - - program, feeded_var_names, target_vars = fluid.io.load_inference_model( - dirname=self.default_pretrained_model_path, executor=exe) - - fluid.io.save_inference_model( - dirname=dirname, - main_program=program, - executor=exe, - feeded_var_names=feeded_var_names, - target_vars=target_vars, - model_filename=model_filename, - params_filename=params_filename) - @serving def serving_method(self, images, **kwargs): """ diff --git a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/processor.py b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/processor.py index 2f9a42d9c0ce6fc2d819349580d850b908ccfb51..aa9a61bd0c2afa1d36a13af5db7a5d17123b6d34 100644 --- a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/processor.py +++ b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/processor.py @@ -101,7 +101,7 @@ def postprocess(paths, handle_id, visualization=True): """ - postprocess the lod_tensor produced by fluid.Executor.run + postprocess the lod_tensor produced by Executor.run Args: paths (list[str]): The paths of images. @@ -126,9 +126,8 @@ def postprocess(paths, confidence (float): The confidence of detection result. save_path (str): The path to save output images. """ - lod_tensor = data_out[0] - lod = lod_tensor.lod[0] - results = lod_tensor.as_ndarray() + lod = data_out.lod()[0] + results = data_out.copy_to_cpu() check_dir(output_dir) diff --git a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/test.py b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/test.py new file mode 100644 index 0000000000000000000000000000000000000000..ed99b62891f3e964fc0506c9f9159f6bfcd3779b --- /dev/null +++ b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/test.py @@ -0,0 +1,108 @@ +import os +import shutil +import unittest + +import cv2 +import requests +import paddlehub as hub + + +class TestHubModule(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a' + if not os.path.exists('tests'): + os.makedirs('tests') + response = requests.get(img_url) + assert response.status_code == 200, 'Network Error.' + with open('tests/test.jpg', 'wb') as f: + f.write(response.content) + cls.module = hub.Module(name="yolov3_mobilenet_v1_coco2017") + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree('tests') + shutil.rmtree('inference') + shutil.rmtree('detection_result') + + def test_object_detection1(self): + results = self.module.object_detection( + paths=['tests/test.jpg'] + ) + bbox = results[0]['data'][0] + label = bbox['label'] + confidence = bbox['confidence'] + left = bbox['left'] + right = bbox['right'] + top = bbox['top'] + bottom = bbox['bottom'] + + self.assertEqual(label, 'cat') + self.assertTrue(confidence > 0.5) + self.assertTrue(0 < left < 1000) + self.assertTrue(1000 < right < 3500) + self.assertTrue(500 < top < 1500) + self.assertTrue(1000 < bottom < 4500) + + def test_object_detection2(self): + results = self.module.object_detection( + images=[cv2.imread('tests/test.jpg')] + ) + bbox = results[0]['data'][0] + label = bbox['label'] + confidence = bbox['confidence'] + left = bbox['left'] + right = bbox['right'] + top = bbox['top'] + bottom = bbox['bottom'] + + self.assertEqual(label, 'cat') + self.assertTrue(confidence > 0.5) + self.assertTrue(0 < left < 1000) + self.assertTrue(1000 < right < 3500) + self.assertTrue(500 < top < 1500) + self.assertTrue(1000 < bottom < 4500) + + def test_object_detection3(self): + results = self.module.object_detection( + images=[cv2.imread('tests/test.jpg')], + visualization=False + ) + bbox = results[0]['data'][0] + label = bbox['label'] + confidence = bbox['confidence'] + left = bbox['left'] + right = bbox['right'] + top = bbox['top'] + bottom = bbox['bottom'] + + self.assertEqual(label, 'cat') + self.assertTrue(confidence > 0.5) + self.assertTrue(0 < left < 1000) + self.assertTrue(1000 < right < 3500) + self.assertTrue(500 < top < 1500) + self.assertTrue(1000 < bottom < 4500) + + def test_object_detection4(self): + self.assertRaises( + AssertionError, + self.module.object_detection, + paths=['no.jpg'] + ) + + def test_object_detection5(self): + self.assertRaises( + AttributeError, + self.module.object_detection, + images=['test.jpg'] + ) + + def test_save_inference_model(self): + self.module.save_inference_model('./inference/model') + + self.assertTrue(os.path.exists('./inference/model.pdmodel')) + self.assertTrue(os.path.exists('./inference/model.pdiparams')) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/yolo_head.py b/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/yolo_head.py deleted file mode 100644 index 7428fb4c281507c30918e12a04753d559346cf7b..0000000000000000000000000000000000000000 --- a/modules/image/object_detection/yolov3_mobilenet_v1_coco2017/yolo_head.py +++ /dev/null @@ -1,273 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import OrderedDict - -from paddle import fluid -from paddle.fluid.param_attr import ParamAttr -from paddle.fluid.regularizer import L2Decay - -__all__ = ['MultiClassNMS', 'YOLOv3Head'] - - -class MultiClassNMS(object): - # __op__ = fluid.layers.multiclass_nms - def __init__(self, background_label, keep_top_k, nms_threshold, nms_top_k, - normalized, score_threshold): - super(MultiClassNMS, self).__init__() - self.background_label = background_label - self.keep_top_k = keep_top_k - self.nms_threshold = nms_threshold - self.nms_top_k = nms_top_k - self.normalized = normalized - self.score_threshold = score_threshold - - -class YOLOv3Head(object): - """Head block for YOLOv3 network - - Args: - norm_decay (float): weight decay for normalization layer weights - num_classes (int): number of output classes - ignore_thresh (float): threshold to ignore confidence loss - label_smooth (bool): whether to use label smoothing - anchors (list): anchors - anchor_masks (list): anchor masks - nms (object): an instance of `MultiClassNMS` - """ - - def __init__(self, - norm_decay=0., - num_classes=80, - ignore_thresh=0.7, - label_smooth=True, - anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], - [59, 119], [116, 90], [156, 198], [373, 326]], - anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]], - nms=MultiClassNMS( - background_label=-1, - keep_top_k=100, - nms_threshold=0.45, - nms_top_k=1000, - normalized=True, - score_threshold=0.01), - weight_prefix_name=''): - self.norm_decay = norm_decay - self.num_classes = num_classes - self.ignore_thresh = ignore_thresh - self.label_smooth = label_smooth - self.anchor_masks = anchor_masks - self._parse_anchors(anchors) - self.nms = nms - self.prefix_name = weight_prefix_name - - def _conv_bn(self, - input, - ch_out, - filter_size, - stride, - padding, - act='leaky', - is_test=True, - name=None): - conv = fluid.layers.conv2d( - input=input, - num_filters=ch_out, - filter_size=filter_size, - stride=stride, - padding=padding, - act=None, - param_attr=ParamAttr(name=name + ".conv.weights"), - bias_attr=False) - - bn_name = name + ".bn" - bn_param_attr = ParamAttr( - regularizer=L2Decay(self.norm_decay), name=bn_name + '.scale') - bn_bias_attr = ParamAttr( - regularizer=L2Decay(self.norm_decay), name=bn_name + '.offset') - out = fluid.layers.batch_norm( - input=conv, - act=None, - is_test=is_test, - param_attr=bn_param_attr, - bias_attr=bn_bias_attr, - moving_mean_name=bn_name + '.mean', - moving_variance_name=bn_name + '.var') - - if act == 'leaky': - out = fluid.layers.leaky_relu(x=out, alpha=0.1) - return out - - def _detection_block(self, input, channel, is_test=True, name=None): - assert channel % 2 == 0, \ - "channel {} cannot be divided by 2 in detection block {}" \ - .format(channel, name) - - conv = input - for j in range(2): - conv = self._conv_bn( - conv, - channel, - filter_size=1, - stride=1, - padding=0, - is_test=is_test, - name='{}.{}.0'.format(name, j)) - conv = self._conv_bn( - conv, - channel * 2, - filter_size=3, - stride=1, - padding=1, - is_test=is_test, - name='{}.{}.1'.format(name, j)) - route = self._conv_bn( - conv, - channel, - filter_size=1, - stride=1, - padding=0, - is_test=is_test, - name='{}.2'.format(name)) - tip = self._conv_bn( - route, - channel * 2, - filter_size=3, - stride=1, - padding=1, - is_test=is_test, - name='{}.tip'.format(name)) - return route, tip - - def _upsample(self, input, scale=2, name=None): - out = fluid.layers.resize_nearest( - input=input, scale=float(scale), name=name) - return out - - def _parse_anchors(self, anchors): - """ - Check ANCHORS/ANCHOR_MASKS in config and parse mask_anchors - - """ - self.anchors = [] - self.mask_anchors = [] - - assert len(anchors) > 0, "ANCHORS not set." - assert len(self.anchor_masks) > 0, "ANCHOR_MASKS not set." - - for anchor in anchors: - assert len(anchor) == 2, "anchor {} len should be 2".format(anchor) - self.anchors.extend(anchor) - - anchor_num = len(anchors) - for masks in self.anchor_masks: - self.mask_anchors.append([]) - for mask in masks: - assert mask < anchor_num, "anchor mask index overflow" - self.mask_anchors[-1].extend(anchors[mask]) - - def _get_outputs(self, input, is_train=True): - """ - Get YOLOv3 head output - - Args: - input (list): List of Variables, output of backbone stages - is_train (bool): whether in train or test mode - - Returns: - outputs (list): Variables of each output layer - """ - - outputs = [] - - # get last out_layer_num blocks in reverse order - out_layer_num = len(self.anchor_masks) - if isinstance(input, OrderedDict): - blocks = list(input.values())[-1:-out_layer_num - 1:-1] - else: - blocks = input[-1:-out_layer_num - 1:-1] - route = None - for i, block in enumerate(blocks): - if i > 0: # perform concat in first 2 detection_block - block = fluid.layers.concat(input=[route, block], axis=1) - route, tip = self._detection_block( - block, - channel=512 // (2**i), - is_test=(not is_train), - name=self.prefix_name + "yolo_block.{}".format(i)) - - # out channel number = mask_num * (5 + class_num) - num_filters = len(self.anchor_masks[i]) * (self.num_classes + 5) - block_out = fluid.layers.conv2d( - input=tip, - num_filters=num_filters, - filter_size=1, - stride=1, - padding=0, - act=None, - param_attr=ParamAttr(name=self.prefix_name + - "yolo_output.{}.conv.weights".format(i)), - bias_attr=ParamAttr( - regularizer=L2Decay(0.), - name=self.prefix_name + - "yolo_output.{}.conv.bias".format(i))) - outputs.append(block_out) - - if i < len(blocks) - 1: - # do not perform upsample in the last detection_block - route = self._conv_bn( - input=route, - ch_out=256 // (2**i), - filter_size=1, - stride=1, - padding=0, - is_test=(not is_train), - name=self.prefix_name + "yolo_transition.{}".format(i)) - # upsample - route = self._upsample(route) - - return outputs, blocks - - def get_prediction(self, outputs, im_size): - """ - Get prediction result of YOLOv3 network - - Args: - outputs (list): list of Variables, return from _get_outputs - im_size (Variable): Variable of size([h, w]) of each image - - Returns: - pred (Variable): The prediction result after non-max suppress. - - """ - boxes = [] - scores = [] - downsample = 32 - for i, output in enumerate(outputs): - box, score = fluid.layers.yolo_box( - x=output, - img_size=im_size, - anchors=self.mask_anchors[i], - class_num=self.num_classes, - conf_thresh=self.nms.score_threshold, - downsample_ratio=downsample, - name=self.prefix_name + "yolo_box" + str(i)) - boxes.append(box) - scores.append(fluid.layers.transpose(score, perm=[0, 2, 1])) - - downsample //= 2 - - yolo_boxes = fluid.layers.concat(boxes, axis=1) - yolo_scores = fluid.layers.concat(scores, axis=2) - pred = fluid.layers.multiclass_nms( - bboxes=yolo_boxes, - scores=yolo_scores, - score_threshold=self.nms.score_threshold, - nms_top_k=self.nms.nms_top_k, - keep_top_k=self.nms.keep_top_k, - nms_threshold=self.nms.nms_threshold, - background_label=self.nms.background_label, - normalized=self.nms.normalized, - name="multiclass_nms") - return pred