提交 674b82a4 编写于 作者: D dengkaipeng

add infer and dataset

上级 d4e91159
...@@ -75,29 +75,27 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层 ...@@ -75,29 +75,27 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层
### 数据准备 ### 数据准备
[MS-COCO数据集](http://cocodataset.org/#download)上进行训练,通过如下方式下载数据集 模型目前支持COCO数据集格式的数据读入和精度评估,我们同时提供了将转换为COCO数据集的格式的Pascal VOC数据集下载,可通过如下命令下载
```bash ```bash
python dataset/coco/download.py python dataset/voc/download.py
``` ```
数据目录结构如下: 数据目录结构如下:
``` ```
dataset/coco/ dataset/voc/
├── annotations ├── annotations
│   ├── instances_train2014.json
│   ├── instances_train2017.json │   ├── instances_train2017.json
│   ├── instances_val2014.json
│   ├── instances_val2017.json │   ├── instances_val2017.json
| ... | ...
├── train2017 ├── train2017
│   ├── 000000000009.jpg │   ├── 1013.jpg
│   ├── 000000580008.jpg │   ├── 1014.jpg
| ... | ...
├── val2017 ├── val2017
│   ├── 000000000139.jpg │   ├── 2551.jpg
│   ├── 000000000285.jpg │   ├── 2552.jpg
| ... | ...
``` ```
...@@ -140,15 +138,17 @@ YOLOv3模型输出为LoDTensor,只支持使用batch_size为1进行评估,可 ...@@ -140,15 +138,17 @@ YOLOv3模型输出为LoDTensor,只支持使用batch_size为1进行评估,可
1. 自动下载Paddle发布的[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams)权重评估 1. 自动下载Paddle发布的[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams)权重评估
```bash ```bash
python main.py --data=<path/to/dataset> --eval_only python main.py --data=dataset/voc --eval_only
``` ```
2. 加载checkpoint进行精度评估 2. 加载checkpoint进行精度评估
```bash ```bash
python main.py --data=<path/to/dataset> --eval_only --weights=yolo_checkpoint/final python main.py --data=dataset/voc --eval_only --weights=yolo_checkpoint/no_mixup/final
``` ```
同样可以通过指定`-d`参数进行动态图模式的评估。
#### 评估精度 #### 评估精度
在10类小数据集下训练模型权重见[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams),评估精度如下: 在10类小数据集下训练模型权重见[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams),评估精度如下:
...@@ -168,6 +168,33 @@ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.506 ...@@ -168,6 +168,33 @@ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.506
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.670 Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.670
``` ```
### 模型推断及可视化
可通过如下两种方式进行模型推断。
1. 自动下载Paddle发布的[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams)权重评估
```bash
python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg
```
2. 加载checkpoint进行精度评估
```bash
python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg --weights=yolo_checkpoint/mo_mixup/final
```
推断结果可视化图像会保存于`--output`指定的文件夹下,默认保存于`./output`目录。
模型推断会输出如下检测结果日志:
```text
2020-04-02 08:26:47,268-INFO: detect bicycle at [116.14993, 127.278336, 579.7716, 438.44214] score: 0.97
2020-04-02 08:26:47,273-INFO: detect dog at [127.44086, 215.71997, 316.04276, 539.7584] score: 0.99
2020-04-02 08:26:47,274-INFO: detect car at [475.42343, 80.007484, 687.16095, 171.27374] score: 0.98
2020-04-02 08:26:47,274-INFO: Detection bbox results save in output/dog.jpg
```
## 参考论文 ## 参考论文
- [You Only Look Once: Unified, Real-Time Object Detection](https://arxiv.org/abs/1506.02640v5), Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi. - [You Only Look Once: Unified, Real-Time Object Detection](https://arxiv.org/abs/1506.02640v5), Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def colormap(rgb=False):
"""
Get colormap
"""
color_list = np.array([
0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494,
0.184, 0.556, 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078,
0.184, 0.300, 0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000,
1.000, 0.500, 0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000,
0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, 0.333, 0.667,
0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000,
0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000,
1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000,
0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500,
0.333, 1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500, 0.667,
0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333,
0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000,
0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333,
0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000,
1.000, 0.667, 0.333, 1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000,
1.000, 0.000, 1.000, 1.000, 0.333, 1.000, 1.000, 0.667, 1.000, 0.167,
0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000,
0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000,
0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000,
0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000,
0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833,
0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.286,
0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, 0.714, 0.714,
0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000
]).astype(np.float32)
color_list = color_list.reshape((-1, 3)) * 255
if not rgb:
color_list = color_list[:, ::-1]
return color_list
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tarfile
from paddle.dataset.common import download
DATASETS = {
'voc': [
('https://paddlemodels.bj.bcebos.com/hapi/voc.tar',
'9faeb7fd997aeea843092fd608d5bcb4', ),
],
}
def download_decompress_file(data_dir, url, md5):
logger.info("Downloading from {}".format(url))
tar_file = download(url, data_dir, md5)
logger.info("Decompressing {}".format(tar_file))
with tarfile.open(tar_file) as tf:
tf.extractall(path=data_dir)
os.remove(tar_file)
if __name__ == "__main__":
data_dir = osp.split(osp.realpath(sys.argv[0]))[0]
for name, infos in DATASETS.items():
for info in infos:
download_decompress_file(data_dir, *info)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np
from PIL import Image
from paddle import fluid
from paddle.fluid.optimizer import Momentum
from paddle.fluid.io import DataLoader
from model import Model, Input, set_device
from modeling import yolov3_darknet53, YoloLoss
from coco import COCODataset
from transforms import *
from visualizer import draw_bbox
import logging
logger = logging.getLogger(__name__)
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
def get_save_image_name(output_dir, image_path):
"""
Get save image name from source image path.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
image_name = os.path.split(image_path)[-1]
name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext
def load_labels(label_list, with_background=True):
idx = int(with_background)
cat2name = {}
with open(label_list) as f:
for line in f.readlines():
line = line.strip()
if line:
cat2name[idx] = line
idx += 1
return cat2name
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
inputs = [Input([None, 3], 'int32', name='img_info'),
Input([None, 3, None, None], 'float32', name='image')]
cat2name = load_labels(FLAGS.label_list, with_background=False)
model = yolov3_darknet53(num_classes=len(cat2name),
model_mode='test',
pretrained=FLAGS.weights is None)
model.prepare(inputs=inputs, device=FLAGS.device)
if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True)
# image preprocess
orig_img = Image.open(FLAGS.infer_image).convert('RGB')
w, h = orig_img.size
img = orig_img.resize((608, 608), Image.BICUBIC)
img = np.array(img).astype('float32') / 255.0
img -= np.array(IMAGE_MEAN)
img /= np.array(IMAGE_STD)
img = img.transpose((2, 0, 1))[np.newaxis, :]
img_info = np.array([0, h, w]).astype('int32')[np.newaxis, :]
_, bboxes = model.test([img_info, img])
vis_img = draw_bbox(orig_img, cat2name, bboxes, FLAGS.draw_threshold)
save_name = get_save_image_name(FLAGS.output_dir, FLAGS.infer_image)
logger.info("Detection bbox results save in {}".format(save_name))
vis_img.save(save_name, quality=95)
if __name__ == '__main__':
parser = argparse.ArgumentParser("Yolov3 Training on VOC")
parser.add_argument(
"--device", type=str, default='gpu', help="device to use, gpu or cpu")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"--label_list", type=str, default=None,
help="path to category label list file")
parser.add_argument(
"-t", "--draw_threshold", type=float, default=0.5,
help="threshold to reserve the result for visualization")
parser.add_argument(
"-i", "--infer_image", type=str, default=None,
help="image path for inference")
parser.add_argument(
"-o", "--output_dir", type=str, default='output',
help="directory to save inference result if --visualize is set")
parser.add_argument(
"-w", "--weights", default=None, type=str,
help="path to weights for inference")
FLAGS = parser.parse_args()
assert os.path.isfile(FLAGS.infer_image), \
"infer_image {} not a file".format(FLAGS.infer_image)
assert os.path.isfile(FLAGS.label_list), \
"label_list {} not a file".format(FLAGS.label_list)
main()
...@@ -195,8 +195,6 @@ if __name__ == '__main__': ...@@ -195,8 +195,6 @@ if __name__ == '__main__':
help='initial learning rate') help='initial learning rate')
parser.add_argument( parser.add_argument(
"-b", "--batch_size", default=8, type=int, help="batch size") "-b", "--batch_size", default=8, type=int, help="batch size")
parser.add_argument(
"-n", "--num_devices", default=1, type=int, help="number of devices")
parser.add_argument( parser.add_argument(
"-j", "--num_workers", default=4, type=int, help="reader worker number") "-j", "--num_workers", default=4, type=int, help="reader worker number")
parser.add_argument( parser.add_argument(
......
...@@ -91,8 +91,8 @@ class YOLOv3(Model): ...@@ -91,8 +91,8 @@ class YOLOv3(Model):
def __init__(self, num_classes=80, model_mode='train'): def __init__(self, num_classes=80, model_mode='train'):
super(YOLOv3, self).__init__() super(YOLOv3, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
assert str.lower(model_mode) in ['train', 'eval'], \ assert str.lower(model_mode) in ['train', 'eval', 'test'], \
"model_mode should be 'train' or 'val', but got " \ "model_mode should be 'train' 'eval' or 'test', but got " \
"{}".format(model_mode) "{}".format(model_mode)
self.model_mode = str.lower(model_mode) self.model_mode = str.lower(model_mode)
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
...@@ -157,7 +157,7 @@ class YOLOv3(Model): ...@@ -157,7 +157,7 @@ class YOLOv3(Model):
route = self.route_blocks[idx](route) route = self.route_blocks[idx](route)
route = fluid.layers.resize_nearest(route, scale=2) route = fluid.layers.resize_nearest(route, scale=2)
if self.model_mode == 'eval': if self.model_mode != 'train':
anchor_mask = self.anchor_masks[idx] anchor_mask = self.anchor_masks[idx]
mask_anchors = [] mask_anchors = []
for m in anchor_mask: for m in anchor_mask:
...@@ -181,16 +181,21 @@ class YOLOv3(Model): ...@@ -181,16 +181,21 @@ class YOLOv3(Model):
if self.model_mode == 'train': if self.model_mode == 'train':
return outputs return outputs
return outputs + [img_id[0, :], fluid.layers.multiclass_nms( preds = [img_id[0, :],
fluid.layers.multiclass_nms(
bboxes=fluid.layers.concat(boxes, axis=1), bboxes=fluid.layers.concat(boxes, axis=1),
scores=fluid.layers.concat(scores, axis=2), scores=fluid.layers.concat(scores, axis=2),
score_threshold=self.valid_thresh, score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk, nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk, keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh, nms_threshold=self.nms_thresh,
background_label=-1) background_label=-1)]
]
if self.model_mode == 'test':
return preds
# model_mode == "eval"
return outputs + preds
class YoloLoss(Loss): class YoloLoss(Loss):
def __init__(self, num_classes=80, num_max_boxes=50): def __init__(self, num_classes=80, num_max_boxes=50):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import numpy as np
from PIL import Image, ImageDraw
from colormap import colormap
import logging
logger = logging.getLogger(__name__)
__all__ = ['draw_bbox']
def draw_bbox(image, catid2name, bboxes, threshold):
"""
Draw bbox on image
"""
bboxes = np.array(bboxes)
if bboxes.shape[1] != 6:
logger.info("No bbox detect")
return image
draw = ImageDraw.Draw(image)
catid2color = {}
color_list = colormap(rgb=True)[:40]
for bbox in bboxes:
catid, score, xmin, ymin, xmax, ymax = bbox
if score < threshold:
continue
if catid not in catid2color:
idx = np.random.randint(len(color_list))
catid2color[catid] = color_list[idx]
color = tuple(catid2color[catid])
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=2,
fill=color)
logger.info("detect {} at {} score: {:.2f}".format(
catid2name[int(catid)], [xmin, ymin, xmax, ymax], score))
# draw label
text = "{} {:.2f}".format(catid2name[catid], score)
tw, th = draw.textsize(text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return image
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册