未验证 提交 89b6f5fe 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #23 from heavengate/yolov3

add YOLOv3
......@@ -1084,7 +1084,7 @@ class Model(fluid.dygraph.Layer):
return eval_result
def predict(self, test_data, batch_size=1, num_workers=0):
def predict(self, test_data, batch_size=1, num_workers=0, stack_outputs=True):
"""
FIXME: add more comments and usage
Args:
......@@ -1097,6 +1097,12 @@ class Model(fluid.dygraph.Layer):
num_workers (int): the number of subprocess to load data, 0 for no subprocess
used and loading data in main process. When train_data and eval_data are
both the instance of Dataloader, this parameter will be ignored.
stack_output (bool): whether stack output field like a batch, as for an output
filed of a sample is in shape [X, Y], test_data contains N samples, predict
output field will be in shape [N, X, Y] if stack_output is True, and will
be a length N list in shape [[X, Y], [X, Y], ....[X, Y]] if stack_outputs
is False. stack_outputs as False is used for LoDTensor output situation,
it is recommended set as True if outputs contains no LoDTensor. Default False
"""
if fluid.in_dygraph_mode():
......@@ -1123,19 +1129,16 @@ class Model(fluid.dygraph.Layer):
if not isinstance(test_loader, Iterable):
loader = test_loader()
outputs = None
outputs = []
for data in tqdm.tqdm(loader):
if not fluid.in_dygraph_mode():
data = data[0]
outs = self.test(*data)
data = flatten(data)
outputs.append(self.test(data[:len(self._inputs)]))
if outputs is None:
outputs = outs
else:
outputs = [
np.vstack([x, outs[i]]) for i, x in enumerate(outputs)
]
# NOTE: for lod tensor output, we should not stack outputs
# for stacking may loss its detail info
outputs = list(zip(*outputs))
if stack_outputs:
outputs = [np.stack(outs, axis=0) for outs in outputs]
self._test_dataloader = None
if test_loader is not None and self._adapter._nranks > 1 \
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from . import resnet
from . import darknet
from . import yolov3
from .resnet import *
from .darknet import *
from .yolov3 import *
__all__ = resnet.__all__ \
+ darknet.__all__ \
+ yolov3.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
from model import Model
from .download import get_weights_path
__all__ = ['DarkNet53', 'ConvBNLayer', 'darknet53']
# {num_layers: (url, md5)}
pretrain_infos = {
53: ('https://paddlemodels.bj.bcebos.com/hapi/darknet53.pdparams',
'2506357a5c31e865785112fc614a487d')
}
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=1,
groups=1,
padding=0,
act="leaky"):
super(ConvBNLayer, self).__init__()
self.conv = Conv2D(
num_channels=ch_in,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
bias_attr=False,
act=None)
self.batch_norm = BatchNorm(
num_channels=ch_out,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02),
regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.)))
self.act = act
def forward(self, inputs):
out = self.conv(inputs)
out = self.batch_norm(out)
if self.act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
class DownSample(fluid.dygraph.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=2,
padding=1):
super(DownSample, self).__init__()
self.conv_bn_layer = ConvBNLayer(
ch_in=ch_in,
ch_out=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding)
self.ch_out = ch_out
def forward(self, inputs):
out = self.conv_bn_layer(inputs)
return out
class BasicBlock(fluid.dygraph.Layer):
def __init__(self, ch_in, ch_out):
super(BasicBlock, self).__init__()
self.conv1 = ConvBNLayer(
ch_in=ch_in,
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0)
self.conv2 = ConvBNLayer(
ch_in=ch_out,
ch_out=ch_out*2,
filter_size=3,
stride=1,
padding=1)
def forward(self, inputs):
conv1 = self.conv1(inputs)
conv2 = self.conv2(conv1)
out = fluid.layers.elementwise_add(x=inputs, y=conv2, act=None)
return out
class LayerWarp(fluid.dygraph.Layer):
def __init__(self, ch_in, ch_out, count):
super(LayerWarp,self).__init__()
self.basicblock0 = BasicBlock(ch_in, ch_out)
self.res_out_list = []
for i in range(1,count):
res_out = self.add_sublayer("basic_block_%d" % (i),
BasicBlock(
ch_out*2,
ch_out))
self.res_out_list.append(res_out)
self.ch_out = ch_out
def forward(self,inputs):
y = self.basicblock0(inputs)
for basic_block_i in self.res_out_list:
y = basic_block_i(y)
return y
DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
class DarkNet53(Model):
def __init__(self, num_layers=53, ch_in=3):
super(DarkNet53, self).__init__()
assert num_layers in DarkNet_cfg.keys(), \
"only support num_layers in {} currently" \
.format(DarkNet_cfg.keys())
self.stages = DarkNet_cfg[num_layers]
self.stages = self.stages[0:5]
self.conv0 = ConvBNLayer(
ch_in=ch_in,
ch_out=32,
filter_size=3,
stride=1,
padding=1)
self.downsample0 = DownSample(
ch_in=32,
ch_out=32 * 2)
self.darknet53_conv_block_list = []
self.downsample_list = []
ch_in = [64,128,256,512,1024]
for i, stage in enumerate(self.stages):
conv_block = self.add_sublayer(
"stage_%d" % (i),
LayerWarp(
int(ch_in[i]),
32*(2**i),
stage))
self.darknet53_conv_block_list.append(conv_block)
for i in range(len(self.stages) - 1):
downsample = self.add_sublayer(
"stage_%d_downsample" % i,
DownSample(
ch_in = 32*(2**(i+1)),
ch_out = 32*(2**(i+2))))
self.downsample_list.append(downsample)
def forward(self,inputs):
out = self.conv0(inputs)
out = self.downsample0(out)
blocks = []
for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
out = conv_block_i(out)
blocks.append(out)
if i < len(self.stages) - 1:
out = self.downsample_list[i](out)
return blocks[-1:-4:-1]
def _darknet(num_layers=53, input_channels=3, pretrained=True):
model = DarkNet53(num_layers, input_channels)
if pretrained:
assert num_layers in pretrain_infos.keys(), \
"DarkNet{} do not have pretrained weights now, " \
"pretrained should be set as False".format(num_layers)
weight_path = get_weights_path(*(pretrain_infos[num_layers]))
assert weight_path.endswith('.pdparams'), \
"suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def darknet53(input_channels=3, pretrained=True):
return _darknet(53, input_channels, pretrained)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from model import Model, Loss
from .darknet import darknet53, ConvBNLayer
from .download import get_weights_path
__all__ = ['YoloLoss', 'YOLOv3', 'yolov3_darknet53']
# {num_layers: (url, md5)}
pretrain_infos = {
53: ('https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams',
'aed7dd45124ff2e844ae3bd5ba6c91d2')
}
class YoloDetectionBlock(fluid.dygraph.Layer):
def __init__(self, ch_in, channel):
super(YoloDetectionBlock, self).__init__()
assert channel % 2 == 0, \
"channel {} cannot be divided by 2".format(channel)
self.conv0 = ConvBNLayer(
ch_in=ch_in,
ch_out=channel,
filter_size=1,
stride=1,
padding=0)
self.conv1 = ConvBNLayer(
ch_in=channel,
ch_out=channel*2,
filter_size=3,
stride=1,
padding=1)
self.conv2 = ConvBNLayer(
ch_in=channel*2,
ch_out=channel,
filter_size=1,
stride=1,
padding=0)
self.conv3 = ConvBNLayer(
ch_in=channel,
ch_out=channel*2,
filter_size=3,
stride=1,
padding=1)
self.route = ConvBNLayer(
ch_in=channel*2,
ch_out=channel,
filter_size=1,
stride=1,
padding=0)
self.tip = ConvBNLayer(
ch_in=channel,
ch_out=channel*2,
filter_size=3,
stride=1,
padding=1)
def forward(self, inputs):
out = self.conv0(inputs)
out = self.conv1(out)
out = self.conv2(out)
out = self.conv3(out)
route = self.route(out)
tip = self.tip(route)
return route, tip
class YOLOv3(Model):
def __init__(self, num_classes=80, model_mode='train'):
super(YOLOv3, self).__init__()
self.num_classes = num_classes
assert str.lower(model_mode) in ['train', 'eval', 'test'], \
"model_mode should be 'train' 'eval' or 'test', but got " \
"{}".format(model_mode)
self.model_mode = str.lower(model_mode)
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326]
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
self.valid_thresh = 0.005
self.nms_thresh = 0.45
self.nms_topk = 400
self.nms_posk = 100
self.draw_thresh = 0.5
self.backbone = darknet53(pretrained=(model_mode=='train'))
self.block_outputs = []
self.yolo_blocks = []
self.route_blocks = []
for idx, num_chan in enumerate([1024, 768, 384]):
yolo_block = self.add_sublayer(
"yolo_detecton_block_{}".format(idx),
YoloDetectionBlock(num_chan, 512 // (2**idx)))
self.yolo_blocks.append(yolo_block)
num_filters = len(self.anchor_masks[idx]) * (self.num_classes + 5)
block_out = self.add_sublayer(
"block_out_{}".format(idx),
Conv2D(num_channels=1024 // (2**idx),
num_filters=num_filters,
filter_size=1,
act=None,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.))))
self.block_outputs.append(block_out)
if idx < 2:
route = self.add_sublayer(
"route2_{}".format(idx),
ConvBNLayer(ch_in=512 // (2**idx),
ch_out=256 // (2**idx),
filter_size=1,
act='leaky_relu'))
self.route_blocks.append(route)
def forward(self, img_info, inputs):
outputs = []
boxes = []
scores = []
downsample = 32
feats = self.backbone(inputs)
route = None
for idx, feat in enumerate(feats):
if idx > 0:
feat = fluid.layers.concat(input=[route, feat], axis=1)
route, tip = self.yolo_blocks[idx](feat)
block_out = self.block_outputs[idx](tip)
outputs.append(block_out)
if idx < 2:
route = self.route_blocks[idx](route)
route = fluid.layers.resize_nearest(route, scale=2)
if self.model_mode != 'train':
anchor_mask = self.anchor_masks[idx]
mask_anchors = []
for m in anchor_mask:
mask_anchors.append(self.anchors[2 * m])
mask_anchors.append(self.anchors[2 * m + 1])
img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
b, s = fluid.layers.yolo_box(
x=block_out,
img_size=img_shape,
anchors=mask_anchors,
class_num=self.num_classes,
conf_thresh=self.valid_thresh,
downsample_ratio=downsample)
boxes.append(b)
scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
downsample //= 2
if self.model_mode == 'train':
return outputs
preds = [img_id[0, :],
fluid.layers.multiclass_nms(
bboxes=fluid.layers.concat(boxes, axis=1),
scores=fluid.layers.concat(scores, axis=2),
score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh,
background_label=-1)]
if self.model_mode == 'test':
return preds
# model_mode == "eval"
return outputs + preds
class YoloLoss(Loss):
def __init__(self, num_classes=80, num_max_boxes=50):
super(YoloLoss, self).__init__()
self.num_classes = num_classes
self.num_max_boxes = num_max_boxes
self.ignore_thresh = 0.7
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326]
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
def forward(self, outputs, labels):
downsample = 32
gt_box, gt_label, gt_score = labels
losses = []
for idx, out in enumerate(outputs):
if idx == 3: break # debug
anchor_mask = self.anchor_masks[idx]
loss = fluid.layers.yolov3_loss(
x=out,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchor_mask=anchor_mask,
downsample_ratio=downsample,
anchors=self.anchors,
class_num=self.num_classes,
ignore_thresh=self.ignore_thresh,
use_label_smooth=True)
loss = fluid.layers.reduce_mean(loss)
losses.append(loss)
downsample //= 2
return losses
def _yolov3_darknet(num_layers=53, num_classes=80,
model_mode='train', pretrained=True):
model = YOLOv3(num_classes, model_mode)
if pretrained:
assert num_layers in pretrain_infos.keys(), \
"YOLOv3-DarkNet{} do not have pretrained weights now, " \
"pretrained should be set as False".format(num_layers)
weight_path = get_weights_path(*(pretrain_infos[num_layers]))
assert weight_path.endswith('.pdparams'), \
"suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
def yolov3_darknet53(num_classes=80, model_mode='train', pretrained=True):
return _yolov3_darknet(53, num_classes, model_mode, pretrained)
此差异已折叠。
# YOLOv3 目标检测模型
---
## 内容
- [模型简介](#模型简介)
- [快速开始](#快速开始)
- [参考论文](#参考论文)
## 模型简介
[YOLOv3](https://arxiv.org/abs/1804.02767) 是由 [Joseph Redmon](https://arxiv.org/search/cs?searchtype=author&query=Redmon%2C+J)[Ali Farhadi](https://arxiv.org/search/cs?searchtype=author&query=Farhadi%2C+A) 提出的单阶段检测器, 该检测器与达到同样精度的传统目标检测方法相比,推断速度能达到接近两倍.
传统目标检测方法通过两阶段检测,第一阶段生成预选框,第二阶段对预选框进行分类和位置坐标的调整,而YOLO将目标检测看做是对框位置和类别概率的一个单阶段回归问题,使得YOLO能达到近两倍的检测速度。而YOLOv3在YOLO的基础上引入的多尺度预测,使得YOLOv3网络对于小物体的检测精度大幅提高。
[YOLOv3](https://arxiv.org/abs/1804.02767) 是一阶段End2End的目标检测器。其目标检测原理如下图所示:
<p align="center">
<img src="image/YOLOv3.jpg" height=400 width=600 hspace='10'/> <br />
YOLOv3检测原理
</p>
YOLOv3将输入图像分成S\*S个格子,每个格子预测B个bounding box,每个bounding box预测内容包括: Location(x, y, w, h)、Confidence Score和C个类别的概率,因此YOLOv3输出层的channel数为B\*(5 + C)。YOLOv3的loss函数也有三部分组成:Location误差,Confidence误差和分类误差。
YOLOv3的网络结构如下图所示:
<p align="center">
<img src="image/YOLOv3_structure.jpg" height=400 width=400 hspace='10'/> <br />
YOLOv3网络结构
</p>
YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层和输出层组成。
1. 特征提取网络。YOLOv3使用 [DarkNet53](https://arxiv.org/abs/1612.08242)作为特征提取网络:DarkNet53 基本采用了全卷积网络,用步长为2的卷积操作替代了池化层,同时添加了 Residual 单元,避免在网络层数过深时发生梯度弥散。
2. 特征融合层。为了解决之前YOLO版本对小目标不敏感的问题,YOLOv3采用了3个不同尺度的特征图来进行目标检测,分别为13\*13,26\*26,52\*52,用来检测大、中、小三种目标。特征融合层选取 DarkNet 产出的三种尺度特征图作为输入,借鉴了FPN(feature pyramid networks)的思想,通过一系列的卷积层和上采样对各尺度的特征图进行融合。
3. 输出层。同样使用了全卷积结构,其中最后一个卷积层的卷积核个数是255:3\*(80+4+1)=255,3表示一个grid cell包含3个bounding box,4表示框的4个坐标信息,1表示Confidence Score,80表示COCO数据集中80个类别的概率。
## 快速开始
### 安装说明
#### paddle安装
本项目依赖于 PaddlePaddle 1.7及以上版本或适当的develop版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
#### 代码下载及环境变量设置
克隆代码库到本地,并设置`PYTHONPATH`环境变量
```bash
git clone https://github.com/PaddlePaddle/hapi
cd hapi
export PYTHONPATH=$PYTHONPATH:`pwd`
cd tsm
```
#### 安装COCO-API
训练前需要首先下载[COCO-API](https://github.com/cocodataset/cocoapi):
```bash
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
# if cython is not installed
pip install Cython
# Install into global site-packages
make install
# Alternatively, if you do not have permissions or prefer
# not to install the COCO API into global site-packages
python setup.py install --user
```
### 数据准备
模型目前支持COCO数据集格式的数据读入和精度评估,我们同时提供了将转换为COCO数据集的格式的Pascal VOC数据集下载,可通过如下命令下载。
```bash
python dataset/download_voc.py
```
数据目录结构如下:
```
dataset/voc/
├── annotations
│   ├── instances_train2017.json
│   ├── instances_val2017.json
| ...
├── train2017
│   ├── 1013.jpg
│   ├── 1014.jpg
| ...
├── val2017
│   ├── 2551.jpg
│   ├── 2552.jpg
| ...
```
### 模型训练
数据准备完毕后,可使用`main.py`脚本启动训练和评估,如下脚本会自动每epoch交替进行训练和模型评估,并将checkpoint默认保存在`yolo_checkpoint`目录下。
YOLOv3模型训练总batch_size为64训练,以下以使用4卡Tesla P40每卡batch_size为16训练介绍训练方式。对于静态图和动态图,多卡训练中`--batch_size`为每卡上的batch_size,即总batch_size为`--batch_size`乘以卡数。
`main.py`脚本参数可通过如下命令查询
```bash
python main.py --help
```
#### 静态图训练
使用如下方式进行多卡训练:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --data=<path/to/dataset> --batch_size=16
```
#### 动态图训练
动态图训练只需要在运行脚本时添加`-d`参数即可。
使用如下方式进行多卡训练:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --data=<path/to/dataset> --batch_size=16 -d
```
### 模型评估
YOLOv3模型输出为LoDTensor,只支持使用batch_size为1进行评估,可通过如下两种方式进行模型评估。
1. 自动下载Paddle发布的[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams)权重评估
```bash
python main.py --data=dataset/voc --eval_only
```
2. 加载checkpoint进行精度评估
```bash
python main.py --data=dataset/voc --eval_only --weights=yolo_checkpoint/no_mixup/final
```
同样可以通过指定`-d`参数进行动态图模式的评估。
#### 评估精度
在10类小数据集下训练模型权重见[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams),评估精度如下:
```bash
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.503
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.779
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.562
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.190
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.390
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.578
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.405
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.591
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.599
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.294
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.506
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.670
```
### 模型推断及可视化
可通过如下两种方式进行模型推断。
1. 自动下载Paddle发布的[YOLOv3-DarkNet53](https://paddlemodels.bj.bcebos.com/hapi/yolov3_darknet53.pdparams)权重评估
```bash
python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg
```
2. 加载checkpoint进行精度评估
```bash
python infer.py --label_list=dataset/voc/label_list.txt --infer_image=image/dog.jpg --weights=yolo_checkpoint/mo_mixup/final
```
推断结果可视化图像会保存于`--output`指定的文件夹下,默认保存于`./output`目录。
模型推断会输出如下检测结果日志:
```text
2020-04-02 08:26:47,268-INFO: detect bicycle at [116.14993, 127.278336, 579.7716, 438.44214] score: 0.97
2020-04-02 08:26:47,273-INFO: detect dog at [127.44086, 215.71997, 316.04276, 539.7584] score: 0.99
2020-04-02 08:26:47,274-INFO: detect car at [475.42343, 80.007484, 687.16095, 171.27374] score: 0.98
2020-04-02 08:26:47,274-INFO: Detection bbox results save in output/dog.jpg
```
## 参考论文
- [You Only Look Once: Unified, Real-Time Object Detection](https://arxiv.org/abs/1506.02640v5), Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi.
- [YOLOv3: An Incremental Improvement](https://arxiv.org/abs/1804.02767v1), Joseph Redmon, Ali Farhadi.
- [Bag of Freebies for Training Object Detection Neural Networks](https://arxiv.org/abs/1902.04103v3), Zhi Zhang, Tong He, Hang Zhang, Zhongyue Zhang, Junyuan Xie, Mu Li.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import cv2
import numpy as np
from pycocotools.coco import COCO
from paddle.fluid.io import Dataset
import logging
logger = logging.getLogger(__name__)
__all__ = ['COCODataset']
class COCODataset(Dataset):
"""
Load dataset with MS-COCO format.
Args:
dataset_dir (str): root directory for dataset.
image_dir (str): directory for images.
anno_path (str): voc annotation file path.
sample_num (int): number of samples to load, -1 means all.
use_default_label (bool): whether use the default mapping of
label to integer index. Default True.
with_background (bool): whether load background as a class,
default True.
transform (callable): callable transform to perform on samples,
default None.
mixup (bool): whether return image mixup samples, default False.
alpha (float): alpha factor of beta distribution to generate
mixup score, used only when mixup is True, default 1.5
beta (float): beta factor of beta distribution to generate
mixup score, used only when mixup is True, default 1.5
"""
def __init__(self,
dataset_dir='',
image_dir='',
anno_path='',
sample_num=-1,
with_background=True,
transform=None,
mixup=False,
alpha=1.5,
beta=1.5):
# roidbs is list of dict whose structure is:
# {
# 'im_file': im_fname, # image file name
# 'im_id': im_id, # image id
# 'h': im_h, # height of image
# 'w': im_w, # width
# 'is_crowd': is_crowd,
# 'gt_class': gt_class,
# 'gt_bbox': gt_bbox,
# 'gt_score': gt_score,
# 'difficult': difficult
# }
self._anno_path = os.path.join(dataset_dir, anno_path)
self._image_dir = os.path.join(dataset_dir, image_dir)
assert os.path.exists(self._anno_path), \
"anno_path {} not exists".format(anno_path)
assert os.path.exists(self._image_dir), \
"image_dir {} not exists".format(image_dir)
self._sample_num = sample_num
self._with_background = with_background
self._transform = transform
self._mixup = mixup
self._alpha = alpha
self._beta = beta
# load in dataset roidbs
self._load_roidb_and_cname2cid()
def _load_roidb_and_cname2cid(self):
assert self._anno_path.endswith('.json'), \
'invalid coco annotation file: ' + anno_path
coco = COCO(self._anno_path)
img_ids = coco.getImgIds()
cat_ids = coco.getCatIds()
records = []
ct = 0
# when with_background = True, mapping category to classid, like:
# background:0, first_class:1, second_class:2, ...
catid2clsid = dict({
catid: i + int(self._with_background)
for i, catid in enumerate(cat_ids)
})
cname2cid = dict({
coco.loadCats(catid)[0]['name']: clsid
for catid, clsid in catid2clsid.items()
})
for img_id in img_ids:
img_anno = coco.loadImgs(img_id)[0]
im_fname = img_anno['file_name']
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
instances = coco.loadAnns(ins_anno_ids)
bboxes = []
for inst in instances:
x, y, box_w, box_h = inst['bbox']
x1 = max(0, x)
y1 = max(0, y)
x2 = min(im_w - 1, x1 + max(0, box_w - 1))
y2 = min(im_h - 1, y1 + max(0, box_h - 1))
if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
inst['clean_bbox'] = [x1, y1, x2, y2]
bboxes.append(inst)
else:
logger.warn(
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes)
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
gt_score = np.ones((num_bbox, 1), dtype=np.float32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = [None] * num_bbox
for i, box in enumerate(bboxes):
catid = box['category_id']
gt_class[i][0] = catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd']
if 'segmentation' in box:
gt_poly[i] = box['segmentation']
im_fname = os.path.join(self._image_dir,
im_fname) if self._image_dir else im_fname
coco_rec = {
'im_file': im_fname,
'im_id': np.array([img_id]),
'h': im_h,
'w': im_w,
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_score': gt_score,
'gt_poly': gt_poly,
}
records.append(coco_rec)
ct += 1
if self._sample_num > 0 and ct >= self._sample_num:
break
assert len(records) > 0, 'not found any coco record in %s' % (self._anno_path)
logger.info('{} samples in file {}'.format(ct, self._anno_path))
self._roidbs, self._cname2cid = records, cname2cid
@property
def num_classes(self):
return len(self._cname2cid)
def __len__(self):
return len(self._roidbs)
def _getitem_by_index(self, idx):
roidb = self._roidbs[idx]
with open(roidb['im_file'], 'rb') as f:
data = np.frombuffer(f.read(), dtype='uint8')
im = cv2.imdecode(data, 1)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_info = np.array([roidb['im_id'][0], roidb['h'], roidb['w']], dtype='int32')
gt_bbox = roidb['gt_bbox']
gt_class = roidb['gt_class']
gt_score = roidb['gt_score']
return im_info, im, gt_bbox, gt_class, gt_score
def __getitem__(self, idx):
im_info, im, gt_bbox, gt_class, gt_score = self._getitem_by_index(idx)
if self._mixup:
mixup_idx = idx + np.random.randint(1, self.__len__())
mixup_idx %= self.__len__()
_, mixup_im, mixup_bbox, mixup_class, _ = \
self._getitem_by_index(mixup_idx)
im, gt_bbox, gt_class, gt_score = \
self._mixup_image(im, gt_bbox, gt_class, mixup_im,
mixup_bbox, mixup_class)
if self._transform:
im_info, im, gt_bbox, gt_class, gt_score = \
self._transform(im_info, im, gt_bbox, gt_class, gt_score)
return [im_info, im, gt_bbox, gt_class, gt_score]
def _mixup_image(self, img1, bbox1, class1, img2, bbox2, class2):
factor = np.random.beta(self._alpha, self._beta)
factor = max(0.0, min(1.0, factor))
if factor >= 1.0:
return img1, bbox1, class1, np.ones_like(class1, dtype="float32")
if factor <= 0.0:
return img2, bbox2, class2, np.ones_like(class2, dtype="float32")
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
img = np.zeros((h, w, img1.shape[2]), 'float32')
img[:img1.shape[0], :img1.shape[1], :] = \
img1.astype('float32') * factor
img[:img2.shape[0], :img2.shape[1], :] += \
img2.astype('float32') * (1.0 - factor)
gt_bbox = np.concatenate((bbox1, bbox2), axis=0)
gt_class = np.concatenate((class1, class2), axis=0)
score1 = np.ones_like(class1, dtype="float32") * factor
score2 = np.ones_like(class2, dtype="float32") * (1.0 - factor)
gt_score = np.concatenate((score1, score2), axis=0)
return img, gt_bbox, gt_class, gt_score
@property
def mixup(self):
return self._mixup
@mixup.setter
def mixup(self, value):
if not isinstance(value, bool):
raise ValueError("mixup should be a boolean number")
logger.info("{} set mixup to {}".format(self, value))
self._mixup = value
def pascalvoc_label(with_background=True):
labels_map = {
'aeroplane': 1,
'bicycle': 2,
'bird': 3,
'boat': 4,
'bottle': 5,
'bus': 6,
'car': 7,
'cat': 8,
'chair': 9,
'cow': 10,
'diningtable': 11,
'dog': 12,
'horse': 13,
'motorbike': 14,
'person': 15,
'pottedplant': 16,
'sheep': 17,
'sofa': 18,
'train': 19,
'tvmonitor': 20
}
if not with_background:
labels_map = {k: v - 1 for k, v in labels_map.items()}
return labels_map
......@@ -17,8 +17,6 @@ import json
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
from metrics import Metric
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
......@@ -26,12 +24,13 @@ logger = logging.getLogger(__name__)
__all__ = ['COCOMetric']
OUTFILE = './bbox.json'
# considered to change to a callback later
class COCOMetric(Metric):
# COCOMetric behavior is different from Metric defined in high
# level API, COCOMetric will and con only accumulate on the epoch
# end, so we impliment COCOMetric as not a high level API Metric
class COCOMetric():
"""
Metrci for MS-COCO dataset, only support update with batch
size as 1.
......@@ -43,7 +42,6 @@ class COCOMetric(Metric):
"""
def __init__(self, anno_path, with_background=True, **kwargs):
super(COCOMetric, self).__init__(**kwargs)
self.anno_path = anno_path
self.with_background = with_background
self.bbox_results = []
......@@ -54,15 +52,14 @@ class COCOMetric(Metric):
{i + int(with_background): catid
for i, catid in enumerate(cat_ids)})
def update(self, preds, *args, **kwargs):
im_ids, bboxes = preds
assert im_ids.shape[0] == 1, \
def update(self, img_id, bboxes):
assert img_id.shape[0] == 1, \
"COCOMetric can only update with batch size = 1"
if bboxes.shape[1] != 6:
# no bbox detected in this batch
return
im_id = int(im_ids)
img_id = int(img_id)
for i in range(bboxes.shape[0]):
dt = bboxes[i, :]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
......@@ -72,7 +69,7 @@ class COCOMetric(Metric):
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
'image_id': im_id,
'image_id': img_id,
'category_id': catid,
'bbox': bbox,
'score': score
......@@ -95,7 +92,7 @@ class COCOMetric(Metric):
# flush coco evaluation result
sys.stdout.flush()
self.result = map_stats[0]
return self.result
return [self.result]
def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
assert coco_gt != None or anno_file != None
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,3 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import sys
import tarfile
from download import _download
import logging
logger = logging.getLogger(__name__)
DATASETS = {
'voc': [
('https://paddlemodels.bj.bcebos.com/hapi/voc.tar',
'9faeb7fd997aeea843092fd608d5bcb4', ),
],
}
def download_decompress_file(data_dir, url, md5):
logger.info("Downloading from {}".format(url))
tar_file = _download(url, data_dir, md5)
logger.info("Decompressing {}".format(tar_file))
with tarfile.open(tar_file) as tf:
tf.extractall(path=data_dir)
os.remove(tar_file)
if __name__ == "__main__":
data_dir = osp.split(osp.realpath(sys.argv[0]))[0]
for name, infos in DATASETS.items():
for info in infos:
download_decompress_file(data_dir, *info)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import os
import argparse
import numpy as np
from PIL import Image
from paddle import fluid
from paddle.fluid.optimizer import Momentum
from paddle.fluid.io import DataLoader
from model import Model, Input, set_device
from models import yolov3_darknet53, YoloLoss
from coco import COCODataset
from transforms import *
from visualizer import draw_bbox
import logging
logger = logging.getLogger(__name__)
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
def get_save_image_name(output_dir, image_path):
"""
Get save image name from source image path.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
image_name = os.path.split(image_path)[-1]
name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext
def load_labels(label_list, with_background=True):
idx = int(with_background)
cat2name = {}
with open(label_list) as f:
for line in f.readlines():
line = line.strip()
if line:
cat2name[idx] = line
idx += 1
return cat2name
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
inputs = [Input([None, 3], 'int32', name='img_info'),
Input([None, 3, None, None], 'float32', name='image')]
cat2name = load_labels(FLAGS.label_list, with_background=False)
model = yolov3_darknet53(num_classes=len(cat2name),
model_mode='test',
pretrained=FLAGS.weights is None)
model.prepare(inputs=inputs, device=FLAGS.device)
if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True)
# image preprocess
orig_img = Image.open(FLAGS.infer_image).convert('RGB')
w, h = orig_img.size
img = orig_img.resize((608, 608), Image.BICUBIC)
img = np.array(img).astype('float32') / 255.0
img -= np.array(IMAGE_MEAN)
img /= np.array(IMAGE_STD)
img = img.transpose((2, 0, 1))[np.newaxis, :]
img_info = np.array([0, h, w]).astype('int32')[np.newaxis, :]
_, bboxes = model.test([img_info, img])
vis_img = draw_bbox(orig_img, cat2name, bboxes, FLAGS.draw_threshold)
save_name = get_save_image_name(FLAGS.output_dir, FLAGS.infer_image)
logger.info("Detection bbox results save in {}".format(save_name))
vis_img.save(save_name, quality=95)
if __name__ == '__main__':
parser = argparse.ArgumentParser("Yolov3 Training on VOC")
parser.add_argument(
"--device", type=str, default='gpu', help="device to use, gpu or cpu")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"--label_list", type=str, default=None,
help="path to category label list file")
parser.add_argument(
"-t", "--draw_threshold", type=float, default=0.5,
help="threshold to reserve the result for visualization")
parser.add_argument(
"-i", "--infer_image", type=str, default=None,
help="image path for inference")
parser.add_argument(
"-o", "--output_dir", type=str, default='output',
help="directory to save inference result if --visualize is set")
parser.add_argument(
"-w", "--weights", default=None, type=str,
help="path to weights for inference")
FLAGS = parser.parse_args()
assert os.path.isfile(FLAGS.infer_image), \
"infer_image {} not a file".format(FLAGS.infer_image)
assert os.path.isfile(FLAGS.label_list), \
"label_list {} not a file".format(FLAGS.label_list)
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import os
import numpy as np
from paddle import fluid
from paddle.fluid.optimizer import Momentum
from paddle.fluid.io import DataLoader
from model import Model, Input, set_device
from distributed import DistributedBatchSampler
from models import yolov3_darknet53, YoloLoss
from coco_metric import COCOMetric
from coco import COCODataset
from transforms import *
NUM_MAX_BOXES = 50
def make_optimizer(step_per_epoch, parameter_list=None):
base_lr = FLAGS.lr
warm_up_iter = 1000
momentum = 0.9
weight_decay = 5e-4
boundaries = [step_per_epoch * e for e in [200, 250]]
values = [base_lr * (0.1 ** i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries,
values=values)
learning_rate = fluid.layers.linear_lr_warmup(
learning_rate=learning_rate,
warmup_steps=warm_up_iter,
start_lr=0.0,
end_lr=base_lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(weight_decay),
momentum=momentum,
parameter_list=parameter_list)
return optimizer
def main():
device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None
inputs = [Input([None, 3], 'int32', name='img_info'),
Input([None, 3, None, None], 'float32', name='image')]
labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
Input([None, NUM_MAX_BOXES], 'float32', name='gt_score')]
if not FLAGS.eval_only: # training mode
train_transform = Compose([ColorDistort(),
RandomExpand(),
RandomCrop(),
RandomFlip(),
NormalizeBox(),
PadBox(),
BboxXYXY2XYWH()])
train_collate_fn = BatchCompose([RandomShape(), NormalizeImage()])
dataset = COCODataset(dataset_dir=FLAGS.data,
anno_path='annotations/instances_train2017.json',
image_dir='train2017',
with_background=False,
mixup=True,
transform=train_transform)
batch_sampler = DistributedBatchSampler(dataset,
batch_size=FLAGS.batch_size,
shuffle=True,
drop_last=True)
loader = DataLoader(dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=FLAGS.num_workers,
return_list=True,
collate_fn=train_collate_fn)
else: # evaluation mode
eval_transform = Compose([ResizeImage(target_size=608),
NormalizeBox(),
PadBox(),
BboxXYXY2XYWH()])
eval_collate_fn = BatchCompose([NormalizeImage()])
dataset = COCODataset(dataset_dir=FLAGS.data,
anno_path='annotations/instances_val2017.json',
image_dir='val2017',
with_background=False,
transform=eval_transform)
# batch_size can only be 1 in evaluation for YOLOv3
# prediction bbox is a LoDTensor
batch_sampler = DistributedBatchSampler(dataset,
batch_size=1,
shuffle=False,
drop_last=False)
loader = DataLoader(dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=FLAGS.num_workers,
return_list=True,
collate_fn=eval_collate_fn)
pretrained = FLAGS.eval_only and FLAGS.weights is None
model = yolov3_darknet53(num_classes=dataset.num_classes,
model_mode='eval' if FLAGS.eval_only else 'train',
pretrained=pretrained)
if FLAGS.pretrain_weights is not None:
model.load(FLAGS.pretrain_weights, skip_mismatch=True, reset_optimizer=True)
optim = make_optimizer(len(batch_sampler), parameter_list=model.parameters())
model.prepare(optim,
YoloLoss(num_classes=dataset.num_classes),
inputs=inputs, labels=labels,
device=FLAGS.device)
# NOTE: we implement COCO metric of YOLOv3 model here, separately
# from 'prepare' and 'fit' framework for follwing reason:
# 1. YOLOv3 network structure is different between 'train' and
# 'eval' mode, in 'eval' mode, output prediction bbox is not the
# feature map used for YoloLoss calculating
# 2. COCO metric behavior is also different from defined Metric
# for COCO metric should not perform accumulate in each iteration
# but only accumulate at the end of an epoch
if FLAGS.eval_only:
if FLAGS.weights is not None:
model.load(FLAGS.weights, reset_optimizer=True)
preds = model.predict(loader, stack_outputs=False)
_, _, _, img_ids, bboxes = preds
anno_path = os.path.join(FLAGS.data, 'annotations/instances_val2017.json')
coco_metric = COCOMetric(anno_path=anno_path, with_background=False)
for img_id, bbox in zip(img_ids, bboxes):
coco_metric.update(img_id, bbox)
coco_metric.accumulate()
coco_metric.reset()
return
if FLAGS.resume is not None:
model.load(FLAGS.resume)
model.fit(train_data=loader,
epochs=FLAGS.epoch - FLAGS.no_mixup_epoch,
save_dir="yolo_checkpoint/mixup",
save_freq=10)
# do not use image mixup transfrom in laste FLAGS.no_mixup_epoch epoches
dataset.mixup = False
model.fit(train_data=loader,
epochs=FLAGS.no_mixup_epoch,
save_dir="yolo_checkpoint/no_mixup",
save_freq=5)
if __name__ == '__main__':
parser = argparse.ArgumentParser("Yolov3 Training on VOC")
parser.add_argument(
"--data", type=str, default='dataset/voc',
help="path to dataset directory")
parser.add_argument(
"--device", type=str, default='gpu', help="device to use, gpu or cpu")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
parser.add_argument(
"--eval_only", action='store_true', help="run evaluation only")
parser.add_argument(
"-e", "--epoch", default=300, type=int, help="number of epoch")
parser.add_argument(
"--no_mixup_epoch", default=30, type=int,
help="number of the last N epoch without image mixup")
parser.add_argument(
'--lr', '--learning-rate', default=0.001, type=float, metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=8, type=int, help="batch size")
parser.add_argument(
"-j", "--num_workers", default=4, type=int, help="reader worker number")
parser.add_argument(
"-p", "--pretrain_weights", default=None, type=str,
help="path to pretrained weights")
parser.add_argument(
"-r", "--resume", default=None, type=str,
help="path to model weights")
parser.add_argument(
"-w", "--weights", default=None, type=str,
help="path to weights for evaluation")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import numpy as np
from PIL import Image, ImageDraw
import logging
logger = logging.getLogger(__name__)
__all__ = ['draw_bbox']
def color_map(num_classes):
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = np.array(color_map).reshape(-1, 3)
return color_map
def draw_bbox(image, catid2name, bboxes, threshold):
"""
Draw bbox on image
"""
bboxes = np.array(bboxes)
if bboxes.shape[1] != 6:
logger.info("No bbox detect")
return image
draw = ImageDraw.Draw(image)
catid2color = {}
color_list = color_map(len(catid2name))
for bbox in bboxes:
catid, score, xmin, ymin, xmax, ymax = bbox
if score < threshold:
continue
if catid not in catid2color:
idx = np.random.randint(len(color_list))
catid2color[catid] = color_list[idx]
color = tuple(catid2color[catid])
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=2,
fill=color)
logger.info("detect {} at {} score: {:.2f}".format(
catid2name[int(catid)], [xmin, ymin, xmax, ymax], score))
# draw label
text = "{} {:.2f}".format(catid2name[catid], score)
tw, th = draw.textsize(text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return image
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册