diff --git a/yolov3/README.md b/yolov3/README.md index e899daaf4b655a02e1ee5404375bd9c5b93b139b..7165206387b7271f33341fbf3fd472ea699b1e28 100644 --- a/yolov3/README.md +++ b/yolov3/README.md @@ -75,29 +75,27 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层 ### 数据准备 -在[MS-COCO数据集](http://cocodataset.org/#download)上进行训练,通过如下方式下载数据集。 +模型目前支持COCO数据集格式的数据读入和精度评估,我们同时提供了已转换为COCO数据集格式的Pascal VOC数据集,可通过如下命令下载。 ```bash - python dataset/coco/download.py + python dataset/voc/download.py ``` 数据目录结构如下: ``` - dataset/coco/ + dataset/voc/ ├── annotations - │   ├── instances_train2014.json │   ├── instances_train2017.json - │   ├── instances_val2014.json │   ├── instances_val2017.json | ... ├── train2017 - │   ├── 000000000009.jpg - │   ├── 000000580008.jpg + │   ├── 1013.jpg + │   ├── 1014.jpg | ... ├── val2017 - │   ├── 000000000139.jpg - │   ├── 000000000285.jpg + │   ├── 2551.jpg + │   ├── 2552.jpg | ... ``` @@ -140,15 +138,17 @@ YOLOv3模型输出为LoDTensor,只支持使用batch_size为1进行评估,可 1. 自动下载Paddle发布的[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams)权重评估 ```bash -python main.py --data= --eval_only +python main.py --data=dataset/voc --eval_only ``` 2. 加载checkpoint进行精度评估 ```bash -python main.py --data= --eval_only --weights=yolo_checkpoint/final +python main.py --data=dataset/voc --eval_only --weights=yolo_checkpoint/no_mixup/final ``` +同样可以通过指定`-d`参数进行动态图模式的评估。 + #### 评估精度 在10类小数据集下训练模型权重见[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams),评估精度如下: @@ -168,6 +168,33 @@ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.506 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.670 ``` +### 模型推断及可视化 + +可通过如下两种方式进行模型推断。 + +1. 自动下载Paddle发布的[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams)权重进行推断 + +```bash +python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg +``` + +2. 
加载checkpoint进行推断 + +```bash +python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg --weights=yolo_checkpoint/no_mixup/final +``` + +推断结果可视化图像会保存于`--output_dir`指定的文件夹下,默认保存于`./output`目录。 + +模型推断会输出如下检测结果日志: + +```text +2020-04-02 08:26:47,268-INFO: detect bicycle at [116.14993, 127.278336, 579.7716, 438.44214] score: 0.97 +2020-04-02 08:26:47,273-INFO: detect dog at [127.44086, 215.71997, 316.04276, 539.7584] score: 0.99 +2020-04-02 08:26:47,274-INFO: detect car at [475.42343, 80.007484, 687.16095, 171.27374] score: 0.98 +2020-04-02 08:26:47,274-INFO: Detection bbox results save in output/dog.jpg +``` + ## 参考论文 - [You Only Look Once: Unified, Real-Time Object Detection](https://arxiv.org/abs/1506.02640v5), Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi. diff --git a/yolov3/colormap.py b/yolov3/colormap.py new file mode 100644 index 0000000000000000000000000000000000000000..af20f9348e2b06affa89c7cd32e87fc8f25ec706 --- /dev/null +++ b/yolov3/colormap.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import numpy as np

# Flat palette table: every three consecutive values form one color with
# channels in (R, G, B) order, each channel expressed in the [0, 1] range.
_COLOR_VALUES = (
    0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494,
    0.184, 0.556, 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078,
    0.184, 0.300, 0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000,
    1.000, 0.500, 0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000,
    0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, 0.333, 0.667,
    0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000,
    0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000,
    1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000,
    0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500,
    0.333, 1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500, 0.667,
    0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333,
    0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000,
    0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333,
    0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000,
    1.000, 0.667, 0.333, 1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000,
    1.000, 0.000, 1.000, 1.000, 0.333, 1.000, 1.000, 0.667, 1.000, 0.167,
    0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000,
    0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000,
    0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000,
    0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000,
    0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833,
    0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.286,
    0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, 0.714, 0.714,
    0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000,
)


def colormap(rgb=False):
    """
    Get colormap

    Args:
        rgb (bool): when True the columns are ordered (R, G, B); when False
            (the default) the column order is reversed to (B, G, R).

    Returns:
        numpy.ndarray: float32 array with one row per color and three
        columns, channel values scaled to the 0-255 range.
    """
    palette = np.array(_COLOR_VALUES).astype(np.float32)
    palette = palette.reshape((-1, 3)) * 255
    # Reversing the last axis swaps the red and blue channels.
    return palette if rgb else palette[:, ::-1]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import os.path as osp
import sys
import tarfile

from paddle.dataset.common import download

# The original script used `logger`, `osp` and `sys` without defining or
# importing them, so it failed with NameError before downloading anything.
logger = logging.getLogger(__name__)

# Dataset name -> list of (archive url, md5 checksum) pairs.
DATASETS = {
    'voc': [
        ('https://paddlemodels.bj.bcebos.com/hapi/voc.tar',
         '9faeb7fd997aeea843092fd608d5bcb4', ),
    ],
}


def download_decompress_file(data_dir, url, md5):
    """
    Download a tar archive, unpack it into ``data_dir`` and delete the archive.

    Args:
        data_dir (str): directory passed to paddle's ``download`` helper and
            used as the extraction target.
        url (str): location of the tar archive.
        md5 (str): checksum handed to ``download``
            (presumably used to verify the file -- confirm against paddle).
    """
    logger.info("Downloading from {}".format(url))
    tar_file = download(url, data_dir, md5)
    logger.info("Decompressing {}".format(tar_file))
    # NOTE(review): extractall() trusts member paths inside the archive; only
    # use this with archives from a trusted source.
    with tarfile.open(tar_file) as tf:
        tf.extractall(path=data_dir)
    # The archive is no longer needed once its contents are on disk.
    os.remove(tar_file)


if __name__ == "__main__":
    # Without a configured handler the INFO progress messages are dropped.
    logging.basicConfig(level=logging.INFO)
    # Unpack next to this script, regardless of the current working directory.
    data_dir = osp.split(osp.realpath(sys.argv[0]))[0]
    for name, infos in DATASETS.items():
        for info in infos:
            download_decompress_file(data_dir, *info)
differ diff --git a/yolov3/image/dog.jpg b/yolov3/image/dog.jpg new file mode 100644 index 0000000000000000000000000000000000000000..77b0381222eaed50867643f4166092c781e56d5b Binary files /dev/null and b/yolov3/image/dog.jpg differ diff --git a/yolov3/infer.py b/yolov3/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..612e5175181106f575bf2929ea1d1a7cf71e5e3b --- /dev/null +++ b/yolov3/infer.py @@ -0,0 +1,125 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import os +import argparse +import numpy as np +from PIL import Image + +from paddle import fluid +from paddle.fluid.optimizer import Momentum +from paddle.fluid.io import DataLoader + +from model import Model, Input, set_device +from modeling import yolov3_darknet53, YoloLoss +from coco import COCODataset +from transforms import * +from visualizer import draw_bbox + +import logging +logger = logging.getLogger(__name__) + +IMAGE_MEAN = [0.485, 0.456, 0.406] +IMAGE_STD = [0.229, 0.224, 0.225] + + +def get_save_image_name(output_dir, image_path): + """ + Get save image name from source image path. 
+ """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + image_name = os.path.split(image_path)[-1] + name, ext = os.path.splitext(image_name) + return os.path.join(output_dir, "{}".format(name)) + ext + + +def load_labels(label_list, with_background=True): + idx = int(with_background) + cat2name = {} + with open(label_list) as f: + for line in f.readlines(): + line = line.strip() + if line: + cat2name[idx] = line + idx += 1 + return cat2name + + +def main(): + device = set_device(FLAGS.device) + fluid.enable_dygraph(device) if FLAGS.dynamic else None + + inputs = [Input([None, 3], 'int32', name='img_info'), + Input([None, 3, None, None], 'float32', name='image')] + + cat2name = load_labels(FLAGS.label_list, with_background=False) + + model = yolov3_darknet53(num_classes=len(cat2name), + model_mode='test', + pretrained=FLAGS.weights is None) + + model.prepare(inputs=inputs, device=FLAGS.device) + + if FLAGS.weights is not None: + model.load(FLAGS.weights, reset_optimizer=True) + + # image preprocess + orig_img = Image.open(FLAGS.infer_image).convert('RGB') + w, h = orig_img.size + img = orig_img.resize((608, 608), Image.BICUBIC) + img = np.array(img).astype('float32') / 255.0 + img -= np.array(IMAGE_MEAN) + img /= np.array(IMAGE_STD) + img = img.transpose((2, 0, 1))[np.newaxis, :] + img_info = np.array([0, h, w]).astype('int32')[np.newaxis, :] + + _, bboxes = model.test([img_info, img]) + + vis_img = draw_bbox(orig_img, cat2name, bboxes, FLAGS.draw_threshold) + save_name = get_save_image_name(FLAGS.output_dir, FLAGS.infer_image) + logger.info("Detection bbox results save in {}".format(save_name)) + vis_img.save(save_name, quality=95) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Yolov3 Training on VOC") + parser.add_argument( + "--device", type=str, default='gpu', help="device to use, gpu or cpu") + parser.add_argument( + "-d", "--dynamic", action='store_true', help="enable dygraph mode") + parser.add_argument( + "--label_list", 
type=str, default=None, + help="path to category label list file") + parser.add_argument( + "-t", "--draw_threshold", type=float, default=0.5, + help="threshold to reserve the result for visualization") + parser.add_argument( + "-i", "--infer_image", type=str, default=None, + help="image path for inference") + parser.add_argument( + "-o", "--output_dir", type=str, default='output', + help="directory to save inference result if --visualize is set") + parser.add_argument( + "-w", "--weights", default=None, type=str, + help="path to weights for inference") + FLAGS = parser.parse_args() + assert os.path.isfile(FLAGS.infer_image), \ + "infer_image {} not a file".format(FLAGS.infer_image) + assert os.path.isfile(FLAGS.label_list), \ + "label_list {} not a file".format(FLAGS.label_list) + main() diff --git a/yolov3/main.py b/yolov3/main.py index 3dc77993c158b8dcef4d15d1861696f4db055034..8ea40eaa9c2ddc45d957ac8245d4552f4245a044 100644 --- a/yolov3/main.py +++ b/yolov3/main.py @@ -195,8 +195,6 @@ if __name__ == '__main__': help='initial learning rate') parser.add_argument( "-b", "--batch_size", default=8, type=int, help="batch size") - parser.add_argument( - "-n", "--num_devices", default=1, type=int, help="number of devices") parser.add_argument( "-j", "--num_workers", default=4, type=int, help="reader worker number") parser.add_argument( diff --git a/yolov3/modeling.py b/yolov3/modeling.py index b2a52026bbb7c781dbd340862f2c6e760c8cac1b..699fadb10a2cb6654f139f1014cfd11e4fb178b9 100644 --- a/yolov3/modeling.py +++ b/yolov3/modeling.py @@ -91,8 +91,8 @@ class YOLOv3(Model): def __init__(self, num_classes=80, model_mode='train'): super(YOLOv3, self).__init__() self.num_classes = num_classes - assert str.lower(model_mode) in ['train', 'eval'], \ - "model_mode should be 'train' or 'val', but got " \ + assert str.lower(model_mode) in ['train', 'eval', 'test'], \ + "model_mode should be 'train' 'eval' or 'test', but got " \ "{}".format(model_mode) self.model_mode = 
str.lower(model_mode) self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, @@ -157,7 +157,7 @@ class YOLOv3(Model): route = self.route_blocks[idx](route) route = fluid.layers.resize_nearest(route, scale=2) - if self.model_mode == 'eval': + if self.model_mode != 'train': anchor_mask = self.anchor_masks[idx] mask_anchors = [] for m in anchor_mask: @@ -181,16 +181,21 @@ class YOLOv3(Model): if self.model_mode == 'train': return outputs - return outputs + [img_id[0, :], fluid.layers.multiclass_nms( - bboxes=fluid.layers.concat(boxes, axis=1), - scores=fluid.layers.concat(scores, axis=2), - score_threshold=self.valid_thresh, - nms_top_k=self.nms_topk, - keep_top_k=self.nms_posk, - nms_threshold=self.nms_thresh, - background_label=-1) -] - + preds = [img_id[0, :], + fluid.layers.multiclass_nms( + bboxes=fluid.layers.concat(boxes, axis=1), + scores=fluid.layers.concat(scores, axis=2), + score_threshold=self.valid_thresh, + nms_top_k=self.nms_topk, + keep_top_k=self.nms_posk, + nms_threshold=self.nms_thresh, + background_label=-1)] + + if self.model_mode == 'test': + return preds + + # model_mode == "eval" + return outputs + preds class YoloLoss(Loss): def __init__(self, num_classes=80, num_max_boxes=50): diff --git a/yolov3/visualizer.py b/yolov3/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1eb04f2dcf4f67604a9370ad4c6324e09c96f92f --- /dev/null +++ b/yolov3/visualizer.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import numpy as np +from PIL import Image, ImageDraw + +from colormap import colormap + +import logging +logger = logging.getLogger(__name__) + +__all__ = ['draw_bbox'] + + +def draw_bbox(image, catid2name, bboxes, threshold): + """ + Draw bbox on image + """ + bboxes = np.array(bboxes) + if bboxes.shape[1] != 6: + logger.info("No bbox detect") + return image + + draw = ImageDraw.Draw(image) + + catid2color = {} + color_list = colormap(rgb=True)[:40] + for bbox in bboxes: + catid, score, xmin, ymin, xmax, ymax = bbox + + if score < threshold: + continue + + if catid not in catid2color: + idx = np.random.randint(len(color_list)) + catid2color[catid] = color_list[idx] + color = tuple(catid2color[catid]) + + # draw bbox + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=2, + fill=color) + logger.info("detect {} at {} score: {:.2f}".format( + catid2name[int(catid)], [xmin, ymin, xmax, ymax], score)) + + # draw label + text = "{} {:.2f}".format(catid2name[catid], score) + tw, th = draw.textsize(text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + + return image