Commit 529c597c authored by xiaoting, committed by Kaipeng Deng

add yolov3 for dygraph (#4136)

* add yolov3 for dygraph
Parent 2fc9147a
*.log
*.json
*.jpg
*.png
output/
checkpoints/
weights/
!weights/*.sh
dataset/coco/
!dataset/coco/*.py
log*
output*
# YOLOv3 Object Detection
---
This model is the dynamic graph (dygraph) version of [paddle_yolov3](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/yolov3).
## Table of Contents
- [Introduction](#introduction)
- [Quick Start](#quick-start)
- [Advanced Usage](#advanced-usage)
- [FAQ](#faq)
- [References](#references)
- [Release Notes](#release-notes)
- [Contributing](#contributing)
- [Authors](#authors)
## Introduction
[YOLOv3](https://arxiv.org/abs/1804.02767) is a single-stage detector proposed by [Joseph Redmon](https://arxiv.org/search/cs?searchtype=author&query=Redmon%2C+J) and [Ali Farhadi](https://arxiv.org/search/cs?searchtype=author&query=Farhadi%2C+A). Compared with traditional detectors of the same accuracy, its inference speed is nearly twice as fast.
Our implementation adopts the image augmentation and label smoothing tricks proposed in [Bag of Freebies for Training Object Detection Neural Networks](https://arxiv.org/abs/1902.04103v3) and outperforms the darknet implementation: on the COCO-2017 dataset it reaches `mAP(0.50:0.95) = 38.9`, 5.9 points higher than the darknet implementation's 33.0.
In terms of inference speed, with the acceleration provided by the Paddle inference library, inference is 30% faster than darknet.
## Quick Start
### Installation
**Install [COCO-API](https://github.com/cocodataset/cocoapi):**
[COCO-API](https://github.com/cocodataset/cocoapi) must be downloaded and installed before training:
```bash
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
# if cython is not installed
pip install Cython
# Install into global site-packages
make install
# Alternatively, if you do not have permissions or prefer
# not to install the COCO API into global site-packages
python setup.py install --user
```
**Install [PaddlePaddle](https://github.com/PaddlePaddle/Paddle):**
Running the sample code in this directory requires PaddlePaddle Fluid v1.7 or higher. If the PaddlePaddle in your environment is older, please update it following the [installation guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/index_cn.html).
### Data Preparation
**COCO dataset:**
Training runs on the [MS-COCO dataset](http://cocodataset.org/#download), which can be downloaded as follows:
```bash
python dataset/coco/download.py
```
The data directory is organized as follows:
```
dataset/coco/
├── annotations
│   ├── instances_train2014.json
│   ├── instances_train2017.json
│   ├── instances_val2014.json
│   ├── instances_val2017.json
| ...
├── train2017
│   ├── 000000000009.jpg
│   ├── 000000580008.jpg
| ...
├── val2017
│   ├── 000000000139.jpg
│   ├── 000000000285.jpg
| ...
```
**Custom datasets:**
You can also train on your own dataset. We recommend annotating custom datasets in the COCO format; the dataset path can be specified with `--data_dir` or by modifying [reader.py](./reader.py#L39). When using COCO-format annotations, the directory layout can follow the COCO structure shown above.
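For reference, a COCO-format annotation file is a single JSON object with `images`, `annotations`, and `categories` lists. The sketch below (written as a Python dict, with made-up values) shows the minimal fields this reader relies on:
```python
# Minimal COCO-format annotation structure (illustrative values only)
coco_annotation = {
    "images": [
        {"id": 1, "file_name": "000001.jpg", "height": 480, "width": 640},
    ],
    "annotations": [
        # bbox is [x, y, w, h] in absolute pixels, origin at the top-left
        {"id": 1, "image_id": 1, "category_id": 1,
         "bbox": [100.0, 120.0, 200.0, 150.0], "area": 30000.0, "iscrowd": 0},
    ],
    "categories": [
        {"id": 1, "name": "person"},
    ],
}
```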
### Training
**Download the pre-trained model:** This example provides a DarkNet-53 pre-trained [model](https://paddlemodels.bj.bcebos.com/yolo/darknet53.pdparams), converted from the pre-trained weights published by the author at [pjreddie/darknet](https://pjreddie.com/media/files/darknet53.conv.74). Download it with:
```bash
sh ./weights/download.sh
```
**Note:** Windows users can download and unpack the weights directly via the links in `./weights/download.sh`.
Set `--pretrain` to load the pre-trained model. The same option is also used to load a trained model when fine-tuning.
Please make sure the pre-trained model is downloaded and loaded correctly before training, otherwise the loss may become NaN during training.
**Start training:** Once the data is ready, start training with:
```bash
python train.py \
    --model_save_dir=output/ \
    --pretrain=${path_to_pretrain_model} \
    --data_dir=${path_to_data} \
    --class_num=${category_num}
```
**Multi-GPU training:**
Dygraph supports multi-process multi-GPU training, launched as follows:
First make the training GPUs visible, e.g. `export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` on an 8-GPU machine.
`python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --started_port=9999 train.py --batch_size=16 --use_data_parallel=1`
You can also simply run the quick-start script `start_parall.sh`, which by default trains with 4 GPUs and a batch size of 16 per GPU.
When training starts, you will see output like the following; the number of log lines printed per iteration matches the number of GPUs:
```
Iter 2, loss 9056.620443, time 3.21156
Iter 3, loss 7720.641968, time 1.63363
Iter 4, loss 6736.150391, time 2.70573
```
**Note:** The total batch size for YOLOv3 is 64; here we train with 4 GPUs and a batch size of 16 per GPU.
**Model settings:**
* The model uses 9 anchor boxes clustered from the COCO dataset: 10x13, 16x30, 33x23, 30x61, 62x45, 59x119, 116x90, 156x198, 373x326.
* In YOLOv3, if a predicted box is not the best match at its grid point but its overlap with any ground-truth box exceeds `ignore_thresh=0.7`, its objectness loss is ignored, as sketched below.
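A minimal numpy sketch of this rule (an illustration only; the real loss is computed by `fluid.layers.yolov3_loss`, and `iou_xyxy`/`objectness_ignore_mask` are names introduced here):
```python
import numpy as np

def iou_xyxy(a, b):
    # pairwise IoU between a: [N, 4] and b: [M, 4] boxes in [x1, y1, x2, y2]
    tl = np.maximum(a[:, None, :2], b[None, :, :2])
    br = np.minimum(a[:, None, 2:], b[None, :, 2:])
    wh = np.clip(br - tl, 0, None)
    inter = wh[..., 0] * wh[..., 1]
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    return inter / (area_a[:, None] + area_b[None, :] - inter)

def objectness_ignore_mask(pred_boxes, gt_boxes, ignore_thresh=0.7):
    # a prediction whose best IoU with any ground truth exceeds the
    # threshold contributes no objectness (no-object) loss
    best_iou = iou_xyxy(pred_boxes, gt_boxes).max(axis=1)
    return best_iou > ignore_thresh
```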
**Training strategy:**
* YOLOv3 is trained with the momentum optimizer, with momentum=0.9.
* The learning rate is warmed up linearly from 0.0 to 0.001 over the first 4000 iterations, then decayed with multipliers 0.1 and 0.01 at iterations 400000 and 450000 respectively, for at most 500000 training iterations. The schedule is sketched below.
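A plain-Python sketch of this schedule (illustrative; the training code builds it with `fluid.dygraph.PiecewiseDecay` wrapped in `fluid.layers.linear_lr_warmup`):
```python
def yolov3_lr(it, base_lr=0.001, warm_up_iter=4000,
              boundaries=(400000, 450000), gamma=0.1):
    # linear warmup from 0 to base_lr over the first warm_up_iter iterations
    if it < warm_up_iter:
        return base_lr * it / warm_up_iter
    # piecewise decay: multiply by gamma at each boundary passed
    lr = base_lr
    for b in boundaries:
        if it >= b:
            lr *= gamma
    return lr

assert yolov3_lr(2000) == 0.0005       # halfway through warmup
assert yolov3_lr(100000) == 0.001      # before the first decay step
assert abs(yolov3_lr(460000) - 1e-5) < 1e-15
```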
The training loss curve is shown below:
<p align="center">
<img src="image/train_loss.png" height="400" width="550" hspace="10"/><br />
Train Loss
</p>
### Evaluation
Model evaluation measures the performance metrics of a trained model. This example uses the [COCO official evaluation](http://cocodataset.org/#detections-eval). Download the model weights first:
```bash
sh ./weights/download.sh
```
`eval.py` is the main entry point for evaluation and is invoked as follows:
```bash
python eval.py \
    --dataset=coco2017 \
    --weights=${path_to_weights} \
    --class_num=${category_num}
```
- Set `export CUDA_VISIBLE_DEVICES=0` to evaluate on a single GPU.
## Advanced Usage
### Background
Traditional detectors work in two stages: the first generates candidate boxes, and the second classifies them and refines their coordinates. YOLO instead treats detection as a single-stage regression over box locations and class probabilities, which makes it nearly twice as fast. YOLOv3 additionally introduces multi-scale prediction on top of YOLO, which greatly improves detection accuracy on small objects.
### Model Overview
[YOLOv3](https://arxiv.org/abs/1804.02767) is a one-stage, end-to-end object detector. Its detection principle is illustrated below:
<p align="center">
<img src="image/YOLOv3.jpg" height=400 width=600 hspace='10'/> <br />
YOLOv3 detection principle
</p>
### Model Architecture
YOLOv3 divides the input image into S\*S grid cells, and each cell predicts B bounding boxes. Each bounding box predicts a location (x, y, w, h), a confidence score, and probabilities for C classes, so the YOLOv3 output layer has B\*(5 + C) channels. The YOLOv3 loss likewise consists of three parts: location loss, confidence loss, and classification loss.
The YOLOv3 network architecture is shown below:
<p align="center">
<img src="image/YOLOv3_structure.jpg" height=400 width=400 hspace='10'/> <br />
YOLOv3 network architecture
</p>
The YOLOv3 network consists of a base feature-extraction network, a multi-scale feature-fusion stage, and the output layers.
1. Feature extraction. YOLOv3 uses [DarkNet-53](https://arxiv.org/abs/1804.02767) as its feature-extraction network. DarkNet-53 is almost fully convolutional, replaces pooling layers with stride-2 convolutions, and adds residual units to avoid vanishing gradients when the network grows very deep.
2. Feature fusion. To fix earlier YOLO versions' insensitivity to small objects, YOLOv3 detects on feature maps of three different scales, 13\*13, 26\*26 and 52\*52, used for large, medium and small objects respectively. The fusion stage takes the three DarkNet feature maps as input and, borrowing the idea of FPN (feature pyramid networks), fuses the scales through a series of convolutions and upsampling.
3. Output layers. These are also fully convolutional; the last convolution of each head has 255 filters: 3\*(80+4+1)=255, where 3 is the number of bounding boxes per grid cell, 4 the box coordinates, 1 the confidence score, and 80 the number of COCO categories. This arithmetic is checked in the snippet below.
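A one-line sanity check of that channel count:
```python
num_anchors = 3      # bounding boxes predicted per grid cell
num_classes = 80     # COCO categories
# 4 box coordinates + 1 confidence score + per-class probabilities
assert num_anchors * (4 + 1 + num_classes) == 255
```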
## FAQ
**Q:** I am training on a single GPU and get `loss=nan` during training. Why?
**A:** The default `learning_rate=0.001` in YOLOv3 assumes a total batch size of 64. If your batch size is smaller than that, lower the learning rate accordingly.
**Q:** My YOLOv3 training is slow. How can I speed it up?
**A:** YOLOv3's data augmentation is fairly heavy and therefore slow; you can speed up data loading by increasing the number of reader processes in [reader.py](./reader.py#L284). If you are fine-tuning, you can also set `--no_mixup_iter` to a value larger than `--max_iter` to disable mixup and save time.
**Q:** I am training YOLOv3 on a dataset with two classes and get `loss=nan` or unexpected inference results. Why?
**A:** `--label_smooth` sets the target of every positive example to `1-1/class_num` and of every negative example to `1/class_num`. When `class_num` is small this has an outsized effect and may cause `loss=nan` or wrong results, so set `--label_smooth=False` when the number of classes is small. With Paddle Fluid v1.5 and above, the C++ code guards against this case, so `--label_smooth=True` is safe there as well. The effect is sketched below.
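A two-line sketch of the smoothing rule described in this answer (`smoothed_targets` is a name introduced here for illustration):
```python
def smoothed_targets(class_num):
    # positive-example target and per-class negative target under smoothing
    return 1.0 - 1.0 / class_num, 1.0 / class_num

print(smoothed_targets(80))  # (0.9875, 0.0125): a mild adjustment
print(smoothed_targets(2))   # (0.5, 0.5): positive and negative collapse
```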
## References
- [You Only Look Once: Unified, Real-Time Object Detection](https://arxiv.org/abs/1506.02640v5), Joseph Redmon, Santosh Divvala, Ross Girshick, Ali Farhadi.
- [YOLOv3: An Incremental Improvement](https://arxiv.org/abs/1804.02767v1), Joseph Redmon, Ali Farhadi.
- [Bag of Freebies for Training Object Detection Neural Networks](https://arxiv.org/abs/1902.04103v3), Zhi Zhang, Tong He, Hang Zhang, Zhongyue Zhang, Junyuan Xie, Mu Li.
## Release Notes
- 12/2019: added the YOLOv3 dygraph model
## Contributing
If you can fix an issue or add a new feature, feel free to submit a PR. If the PR is accepted, we will score the contribution by quality and difficulty (0-5, higher is better). Once you accumulate 10 points, you may contact us for an interview opportunity or a recommendation letter.
## Authors
- [heavengate](https://github.com/heavengate)
- [tink2123](https://github.com/tink2123)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from PIL import Image
def coco_anno_box_to_center_relative(box, img_height, img_width):
"""
Convert COCO annotations box with format [x1, y1, w, h] to
center mode [center_x, center_y, w, h] and divide image width
and height to get relative value in range[0, 1]
"""
assert len(box) == 4, "box should be a len(4) list or tuple"
x, y, w, h = box
x1 = max(x, 0)
x2 = min(x + w - 1, img_width - 1)
y1 = max(y, 0)
y2 = min(y + h - 1, img_height - 1)
x = (x1 + x2) / 2 / img_width
y = (y1 + y2) / 2 / img_height
w = (x2 - x1) / img_width
h = (y2 - y1) / img_height
return np.array([x, y, w, h])
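# Example (illustrative): with box=[10, 20, 100, 50] in a 200x200 image,
# coco_anno_box_to_center_relative returns [0.2975, 0.2225, 0.495, 0.245].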
def clip_relative_box_in_image(x, y, w, h):
"""Clip relative box coordinates x, y, w, h to [0, 1]"""
x1 = max(x - w / 2, 0.)
x2 = min(x + w / 2, 1.)
y1 = max(y - h / 2, 0.)
y2 = min(y + h / 2, 1.)
x = (x1 + x2) / 2
y = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
return x, y, w, h
def box_xywh_to_xyxy(box):
shape = box.shape
assert shape[-1] == 4, "Box shape[-1] should be 4."
box = box.reshape((-1, 4))
box[:, 0], box[:, 2] = box[:, 0] - box[:, 2] / 2, box[:, 0] + box[:, 2] / 2
box[:, 1], box[:, 3] = box[:, 1] - box[:, 3] / 2, box[:, 1] + box[:, 3] / 2
box = box.reshape(shape)
return box
def box_iou_xywh(box1, box2):
assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
inter_x1 = np.maximum(b1_x1, b2_x1)
inter_x2 = np.minimum(b1_x2, b2_x2)
inter_y1 = np.maximum(b1_y1, b2_y1)
inter_y2 = np.minimum(b1_y2, b2_y2)
inter_w = inter_x2 - inter_x1
inter_h = inter_y2 - inter_y1
inter_w[inter_w < 0] = 0
inter_h[inter_h < 0] = 0
inter_area = inter_w * inter_h
b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
return inter_area / (b1_area + b2_area - inter_area)
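# Example (illustrative): two 1x1 boxes in xywh form with centers half a
# width apart, np.array([[0.5, 0.5, 1, 1]]) and np.array([[1.0, 0.5, 1, 1]]),
# give box_iou_xywh IoU = 0.5 / 1.5 = 1/3.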
def box_iou_xyxy(box1, box2):
assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
inter_x1 = np.maximum(b1_x1, b2_x1)
inter_x2 = np.minimum(b1_x2, b2_x2)
inter_y1 = np.maximum(b1_y1, b2_y1)
inter_y2 = np.minimum(b1_y2, b2_y2)
inter_w = inter_x2 - inter_x1
inter_h = inter_y2 - inter_y1
inter_w[inter_w < 0] = 0
inter_h[inter_h < 0] = 0
inter_area = inter_w * inter_h
b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
return inter_area / (b1_area + b2_area - inter_area)
def box_crop(boxes, labels, scores, crop, img_shape):
x, y, w, h = map(float, crop)
im_w, im_h = map(float, img_shape)
boxes = boxes.copy()
boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (
boxes[:, 0] + boxes[:, 2] / 2) * im_w
boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (
boxes[:, 1] + boxes[:, 3] / 2) * im_h
crop_box = np.array([x, y, x + w, y + h])
centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(
axis=1)
boxes[:, :2] = np.maximum(boxes[:, :2], crop_box[:2])
boxes[:, 2:] = np.minimum(boxes[:, 2:], crop_box[2:])
boxes[:, :2] -= crop_box[:2]
boxes[:, 2:] -= crop_box[:2]
mask = np.logical_and(mask, (boxes[:, :2] < boxes[:, 2:]).all(axis=1))
boxes = boxes * np.expand_dims(mask.astype('float32'), axis=1)
labels = labels * mask.astype('float32')
scores = scores * mask.astype('float32')
boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (
boxes[:, 2] - boxes[:, 0]) / w
boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (
boxes[:, 3] - boxes[:, 1]) / h
return boxes, labels, scores, mask.sum()
def draw_boxes_on_image(image_path,
boxes,
scores,
labels,
label_names,
score_thresh=0.5):
image = np.array(Image.open(image_path))
plt.figure()
_, ax = plt.subplots(1)
ax.imshow(image)
image_name = image_path.split('/')[-1]
print("Image {} detect: ".format(image_name))
colors = {}
for box, score, label in zip(boxes, scores, labels):
if score < score_thresh:
continue
if box[2] <= box[0] or box[3] <= box[1]:
continue
label = int(label)
if label not in colors:
colors[label] = plt.get_cmap('hsv')(label / len(label_names))
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
rect = plt.Rectangle(
(x1, y1),
x2 - x1,
y2 - y1,
fill=False,
linewidth=2.0,
edgecolor=colors[label])
ax.add_patch(rect)
ax.text(
x1,
y1,
'{} {:.4f}'.format(label_names[label], score),
verticalalignment='bottom',
horizontalalignment='left',
bbox={'facecolor': colors[label],
'alpha': 0.5,
'pad': 0},
fontsize=8,
color='white')
print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(
label)], str(list(map(int, list(box)))), score))
image_name = image_name.replace('jpg', 'png')
plt.axis('off')
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.savefig(
"./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
print("Detect result save at ./output/{}\n".format(image_name))
plt.cla()
plt.close('all')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from edict import AttrDict
import six
import numpy as np
_C = AttrDict()
cfg = _C
#
# Training options
#
# Snapshot period
_C.snapshot_iter = 2000
# min valid area for gt boxes
_C.gt_min_area = -1
# max target box number in an image
_C.max_box_num = 50
#
# Inference options
#
# valid score threshold to include boxes
_C.valid_thresh = 0.005
# threshold value for box non-maximum suppression (NMS)
_C.nms_thresh = 0.45
# the number of top k boxes to perform nms
_C.nms_topk = 400
# the number of output boxes after nms
_C.nms_posk = 100
# score threshold for draw box in debug mode
_C.draw_thresh = 0.5
#
# Model options
#
# pixel mean values
_C.pixel_means = [0.485, 0.456, 0.406]
# pixel std values
_C.pixel_stds = [0.229, 0.224, 0.225]
# anchor box width and height
_C.anchors = [
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326
]
# anchor mask of each yolo layer
_C.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
# IoU threshold to ignore objectness loss of pred box
_C.ignore_thresh = .7
#
# SOLVER options
#
# batch size
_C.batch_size = 8
# base learning rate used to derive the final learning rate
_C.learning_rate = 0.001
# maximum number of iterations
_C.max_iter = 500200
# learning rate warm-up settings
_C.warm_up_iter = 4000
_C.warm_up_factor = 0.
# lr steps_with_decay
_C.lr_steps = [400000, 450000]
_C.lr_gamma = 0.1
# L2 regularization hyperparameter
_C.weight_decay = 0.0005
# momentum with SGD
_C.momentum = 0.9
#
# ENV options
#
# support both CPU and GPU
_C.use_gpu = True
# Class number
_C.class_num = 80
# dataset path
_C.train_file_list = 'annotations/instances_train2017.json'
_C.train_data_dir = 'train2017'
_C.val_file_list = 'annotations/instances_val2017.json'
_C.val_data_dir = 'val2017'
def merge_cfg_from_args(args):
"""Merge config keys, values in args into the global config."""
for k, v in sorted(six.iteritems(vars(args))):
try:
value = eval(v)
except:
value = v
_C[k] = value
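# Example (illustrative): for argparse args with batch_size='16' and
# dataset='coco2017', merge_cfg_from_args(args) stores cfg.batch_size == 16
# (eval parses the int) and cfg.dataset == 'coco2017' (eval raises a
# NameError, so the raw string is kept).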
"""
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""
import os
import sys
import signal
import time
import numpy as np
import threading
import multiprocessing
try:
import queue
except ImportError:
import Queue as queue
# handle reader process termination without printing a stack frame
def _reader_quit(signum, frame):
print("Reader process exit.")
sys.exit()
def _term_group(sig_num, frame):
print('pid {} terminated, terminate group '
'{}...'.format(os.getpid(), os.getpgrp()))
os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)
signal.signal(signal.SIGTERM, _reader_quit)
signal.signal(signal.SIGINT, _term_group)
class GeneratorEnqueuer(object):
"""
Builds a queue out of a data generator.
Args:
generator: a generator function which endlessly yields data
use_multiprocessing (bool): use multiprocessing if True,
otherwise use threading.
wait_time (float): time to sleep in-between calls to `put()`.
random_seed (int): Initial seed for workers,
will be incremented by one for each worker.
"""
def __init__(self,
generator,
use_multiprocessing=False,
wait_time=0.05,
random_seed=None):
self.wait_time = wait_time
self._generator = generator
self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self.queue = None
self._manager = None
self.seed = random_seed
def start(self, workers=1, max_queue_size=10):
"""
Start worker threads which add data from the generator into the queue.
Args:
workers (int): number of worker threads
max_queue_size (int): queue size
(when full, threads could block on `put()`)
"""
def data_generator_task():
"""
Data generator task.
"""
def task():
if (self.queue is not None and
self.queue.qsize() < max_queue_size):
generator_output = next(self._generator)
self.queue.put((generator_output))
else:
time.sleep(self.wait_time)
if not self._use_multiprocessing:
while not self._stop_event.is_set():
with self.genlock:
try:
task()
except Exception:
self._stop_event.set()
break
else:
while not self._stop_event.is_set():
try:
task()
except Exception:
self._stop_event.set()
break
try:
if self._use_multiprocessing:
self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
self.genlock = threading.Lock()
self.queue = queue.Queue()
self._stop_event = threading.Event()
for _ in range(workers):
if self._use_multiprocessing:
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.seed)
thread = multiprocessing.Process(target=data_generator_task)
thread.daemon = True
if self.seed is not None:
self.seed += 1
else:
thread = threading.Thread(target=data_generator_task)
self._threads.append(thread)
thread.start()
except:
self.stop()
raise
def is_running(self):
"""
Returns:
bool: Whether the worker threads are running.
"""
return self._stop_event is not None and not self._stop_event.is_set()
def stop(self, timeout=None):
"""
Stops running threads and waits for them to exit, if necessary.
Should be called by the same thread which called `start()`.
Args:
timeout(int|None): maximum time to wait on `thread.join()`.
"""
if self.is_running():
self._stop_event.set()
for thread in self._threads:
if self._use_multiprocessing:
if thread.is_alive():
thread.join(timeout)
else:
thread.join(timeout)
if self._manager:
self._manager.shutdown()
self._threads = []
self._stop_event = None
self.queue = None
def get(self):
"""
Creates a generator to extract data from the queue.
Skip the data if it is `None`.
# Yields
tuple of data in the queue.
"""
while self.is_running():
if not self.queue.empty():
inputs = self.queue.get()
if inputs is not None:
yield inputs
else:
time.sleep(self.wait_time)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import sys
import zipfile
import logging
from paddle.dataset.common import download
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DATASETS = {
'coco': [
# coco2017
('http://images.cocodataset.org/zips/train2017.zip',
'cced6f7f71b7629ddf16f17bbcfab6b2', ),
('http://images.cocodataset.org/zips/val2017.zip',
'442b8da7639aecaf257c1dceb8ba8c80', ),
('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
'f4bbac642086de4f52a3fdda2de5fa2c', ),
# coco2014
('http://images.cocodataset.org/zips/train2014.zip',
'0da8c0bd3d6becc4dcb32757491aca88', ),
('http://images.cocodataset.org/zips/val2014.zip',
'a3d79f5ed8d289b7a7554ce06a5782b3', ),
('http://images.cocodataset.org/annotations/annotations_trainval2014.zip',
'0a379cfc70b0e71301e0f377548639bd', ),
],
}
def download_decompress_file(data_dir, url, md5):
logger.info("Downloading from {}".format(url))
zip_file = download(url, data_dir, md5)
logger.info("Decompressing {}".format(zip_file))
with zipfile.ZipFile(zip_file) as zf:
zf.extractall(path=data_dir)
os.remove(zip_file)
if __name__ == "__main__":
data_dir = osp.split(osp.realpath(sys.argv[0]))[0]
for name, infos in DATASETS.items():
for info in infos:
download_decompress_file(data_dir, info[0], info[1])
logger.info("Download dataset {} finished.".format(name))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
def nccl2_prepare(trainer_id, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id,
trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'),
current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'),
startup_program=startup_prog,
program=main_prog)
def prepare_for_multi_process(exe, build_strategy, train_prog):
# prepare for multi-process
trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers < 2: return
print("PADDLE_TRAINERS_NUM", num_trainers)
print("PADDLE_TRAINER_ID", trainer_id)
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id
# NOTE(zcd): use multi processes to train the model,
# and each process use one GPU card.
startup_prog = fluid.Program()
nccl2_prepare(trainer_id, startup_prog, train_prog)
# the startup_prog is run twice, but it doesn't matter.
exe.run(startup_prog)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
elif name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name, value):
if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
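# Example (illustrative):
#   cfg = AttrDict(batch_size=8)
#   cfg.learning_rate = 0.001   # stored as cfg['learning_rate']
#   assert cfg.batch_size == cfg['batch_size'] == 8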
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import json
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
import reader
from models.yolov3 import YOLOv3
from utility import print_arguments, parse_args, check_gpu
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config import cfg
def eval():
# check whether use_gpu=True is set with a CPU-only PaddlePaddle build
check_gpu(cfg.use_gpu)
if '2014' in cfg.dataset:
test_list = 'annotations/instances_val2014.json'
elif '2017' in cfg.dataset:
test_list = 'annotations/instances_val2017.json'
if cfg.debug:
if not os.path.exists('output'):
os.mkdir('output')
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
with fluid.dygraph.guard(place):
model = YOLOv3(3, is_train=False)
# yapf: disable
if cfg.weights:
restore, _ = fluid.load_dygraph(cfg.weights)
model.set_dict(restore)
model.eval()
input_size = cfg.input_size
# batch_size for test must be 1
test_reader = reader.test(input_size, 1)
label_names, label_ids = reader.get_label_infos()
if cfg.debug:
print("Load in labels {} with ids {}".format(label_names, label_ids))
def get_pred_result(boxes, scores, labels, im_id):
result = []
for box, score, label in zip(boxes, scores, labels):
x1, y1, x2, y2 = box
w = x2 - x1 + 1
h = y2 - y1 + 1
bbox = [x1, y1, w, h]
res = {
'image_id': int(im_id),
'category_id': label_ids[int(label)],
'bbox': list(map(float, bbox)),
'score': float(score)
}
result.append(res)
return result
dts_res = []
total_time = 0
for iter_id, data in enumerate(test_reader()):
start_time = time.time()
img_data = np.array([x[0] for x in data]).astype('float32')
img = to_variable(img_data)
im_id_data = np.array([x[1] for x in data]).astype('int32')
im_id = to_variable(im_id_data)
im_shape_data = np.array([x[2] for x in data]).astype('int32')
im_shape = to_variable(im_shape_data)
batch_outputs = model(img, None, None, None, im_id, im_shape)
nmsed_boxes = batch_outputs.numpy()
if nmsed_boxes.shape[1] != 6:
continue
im_id = data[0][1]
nmsed_box = nmsed_boxes
labels = nmsed_box[:, 0]
scores = nmsed_box[:, 1]
boxes = nmsed_box[:, 2:6]
dts_res += get_pred_result(boxes, scores, labels, im_id)
end_time = time.time()
print("batch id: {}, time: {}".format(iter_id, end_time - start_time))
total_time += end_time - start_time
with open("yolov3_result.json", 'w') as outfile:
json.dump(dts_res, outfile)
print("start evaluate detection result with coco api")
coco = COCO(os.path.join(cfg.data_dir, test_list))
cocoDt = coco.loadRes("yolov3_result.json")
cocoEval = COCOeval(coco, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
print("evaluate done.")
print("Time per batch: {}".format(total_time / iter_id))
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
eval()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
from PIL import Image, ImageEnhance
import random
import box_utils
def random_distort(img):
def random_brightness(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
np.random.shuffle(ops)
img = Image.fromarray(img)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
img = np.asarray(img)
return img
def random_crop(img,
boxes,
labels,
scores,
scales=[0.3, 1.0],
max_ratio=2.0,
constraints=None,
max_trial=50):
if len(boxes) == 0:
return img, boxes
if not constraints:
constraints = [(0.1, 1.0), (0.3, 1.0), (0.5, 1.0), (0.7, 1.0),
(0.9, 1.0), (0.0, 1.0)]
img = Image.fromarray(img)
w, h = img.size
crops = [(0, 0, w, h)]
for min_iou, max_iou in constraints:
for _ in range(max_trial):
scale = random.uniform(scales[0], scales[1])
aspect_ratio = random.uniform(max(1 / max_ratio, scale * scale), \
min(max_ratio, 1 / scale / scale))
crop_h = int(h * scale / np.sqrt(aspect_ratio))
crop_w = int(w * scale * np.sqrt(aspect_ratio))
crop_x = random.randrange(w - crop_w)
crop_y = random.randrange(h - crop_h)
crop_box = np.array([[(crop_x + crop_w / 2.0) / w,
(crop_y + crop_h / 2.0) / h,
crop_w / float(w), crop_h / float(h)]])
iou = box_utils.box_iou_xywh(crop_box, boxes)
if min_iou <= iou.min() and max_iou >= iou.max():
crops.append((crop_x, crop_y, crop_w, crop_h))
break
while crops:
crop = crops.pop(np.random.randint(0, len(crops)))
crop_boxes, crop_labels, crop_scores, box_num = \
box_utils.box_crop(boxes, labels, scores, crop, (w, h))
if box_num < 1:
continue
img = img.crop((crop[0], crop[1], crop[0] + crop[2],
crop[1] + crop[3])).resize(img.size, Image.LANCZOS)
img = np.asarray(img)
return img, crop_boxes, crop_labels, crop_scores
img = np.asarray(img)
return img, boxes, labels, scores
def random_flip(img, gtboxes, thresh=0.5):
if random.random() > thresh:
img = img[:, ::-1, :]
gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
return img, gtboxes
def random_interp(img, size, interp=None):
interp_method = [
cv2.INTER_NEAREST,
cv2.INTER_LINEAR,
cv2.INTER_AREA,
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
]
if not interp or interp not in interp_method:
interp = interp_method[random.randint(0, len(interp_method) - 1)]
h, w, _ = img.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
img = cv2.resize(
img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
return img
def random_expand(img,
gtboxes,
max_ratio=4.,
fill=None,
keep_ratio=True,
thresh=0.5):
if random.random() > thresh:
return img, gtboxes
if max_ratio < 1.0:
return img, gtboxes
h, w, c = img.shape
ratio_x = random.uniform(1, max_ratio)
if keep_ratio:
ratio_y = ratio_x
else:
ratio_y = random.uniform(1, max_ratio)
oh = int(h * ratio_y)
ow = int(w * ratio_x)
off_x = random.randint(0, ow - w)
off_y = random.randint(0, oh - h)
out_img = np.zeros((oh, ow, c))
if fill and len(fill) == c:
for i in range(c):
out_img[:, :, i] = fill[i] * 255.0
out_img[off_y:off_y + h, off_x:off_x + w, :] = img
gtboxes[:, 0] = ((gtboxes[:, 0] * w) + off_x) / float(ow)
gtboxes[:, 1] = ((gtboxes[:, 1] * h) + off_y) / float(oh)
gtboxes[:, 2] = gtboxes[:, 2] / ratio_x
gtboxes[:, 3] = gtboxes[:, 3] / ratio_y
return out_img.astype('uint8'), gtboxes
def shuffle_gtbox(gtbox, gtlabel, gtscore):
gt = np.concatenate(
[gtbox, gtlabel[:, np.newaxis], gtscore[:, np.newaxis]], axis=1)
idx = np.arange(gt.shape[0])
np.random.shuffle(idx)
gt = gt[idx, :]
return gt[:, :4], gt[:, 4], gt[:, 5]
def image_mixup(img1, gtboxes1, gtlabels1, gtscores1, img2, gtboxes2, gtlabels2,
gtscores2):
factor = np.random.beta(1.5, 1.5)
factor = max(0.0, min(1.0, factor))
if factor >= 1.0:
return img1, gtboxes1, gtlabels1, gtscores1
if factor <= 0.0:
return img2, gtboxes2, gtlabels2, gtscores2
gtscores1 = gtscores1 * factor
gtscores2 = gtscores2 * (1.0 - factor)
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
img = np.zeros((h, w, img1.shape[2]), 'float32')
img[:img1.shape[0], :img1.shape[1], :] = img1.astype('float32') * factor
img[:img2.shape[0], :img2.shape[1], :] += \
img2.astype('float32') * (1.0 - factor)
gtboxes = np.zeros_like(gtboxes1)
gtlabels = np.zeros_like(gtlabels1)
gtscores = np.zeros_like(gtscores1)
gt_valid_mask1 = np.logical_and(gtboxes1[:, 2] > 0, gtboxes1[:, 3] > 0)
gtboxes1 = gtboxes1[gt_valid_mask1]
gtlabels1 = gtlabels1[gt_valid_mask1]
gtscores1 = gtscores1[gt_valid_mask1]
gtboxes1[:, 0] = gtboxes1[:, 0] * img1.shape[1] / w
gtboxes1[:, 1] = gtboxes1[:, 1] * img1.shape[0] / h
gtboxes1[:, 2] = gtboxes1[:, 2] * img1.shape[1] / w
gtboxes1[:, 3] = gtboxes1[:, 3] * img1.shape[0] / h
gt_valid_mask2 = np.logical_and(gtboxes2[:, 2] > 0, gtboxes2[:, 3] > 0)
gtboxes2 = gtboxes2[gt_valid_mask2]
gtlabels2 = gtlabels2[gt_valid_mask2]
gtscores2 = gtscores2[gt_valid_mask2]
gtboxes2[:, 0] = gtboxes2[:, 0] * img2.shape[1] / w
gtboxes2[:, 1] = gtboxes2[:, 1] * img2.shape[0] / h
gtboxes2[:, 2] = gtboxes2[:, 2] * img2.shape[1] / w
gtboxes2[:, 3] = gtboxes2[:, 3] * img2.shape[0] / h
gtboxes_all = np.concatenate((gtboxes1, gtboxes2), axis=0)
gtlabels_all = np.concatenate((gtlabels1, gtlabels2), axis=0)
gtscores_all = np.concatenate((gtscores1, gtscores2), axis=0)
gt_num = min(len(gtboxes), len(gtboxes_all))
gtboxes[:gt_num] = gtboxes_all[:gt_num]
gtlabels[:gt_num] = gtlabels_all[:gt_num]
gtscores[:gt_num] = gtscores_all[:gt_num]
return img.astype('uint8'), gtboxes, gtlabels, gtscores
def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None):
img = random_distort(img)
img, gtboxes = random_expand(img, gtboxes, fill=means)
img, gtboxes, gtlabels, gtscores = \
random_crop(img, gtboxes, gtlabels, gtscores)
img = random_interp(img, size)
img, gtboxes = random_flip(img, gtboxes)
gtboxes, gtlabels, gtscores = shuffle_gtbox(gtboxes, gtlabels, gtscores)
return img.astype('float32'), gtboxes.astype('float32'), \
gtlabels.astype('int32'), gtscores.astype('float32')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import numpy as np
import paddle
import paddle.fluid as fluid
import box_utils
import reader
from utility import print_arguments, parse_args, check_gpu
from models.yolov3 import YOLOv3
from paddle.fluid.dygraph.base import to_variable
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config import cfg
def infer():
# check whether use_gpu=True is set with a CPU-only PaddlePaddle build
check_gpu(cfg.use_gpu)
if not os.path.exists('output'):
os.mkdir('output')
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
with fluid.dygraph.guard(place):
model = YOLOv3(3, is_train=False)
input_size = cfg.input_size
# yapf: disable
if cfg.weights:
restore, _ = fluid.load_dygraph(cfg.weights)
model.set_dict(restore)
# yapf: enable
# you can save inference model by following code
# fluid.io.save_inference_model("./output/yolov3",
# feeded_var_names=['image', 'im_shape'],
# target_vars=outputs,
# executor=exe)
image_names = []
if cfg.image_name is not None:
image_names.append(cfg.image_name)
else:
for image_name in os.listdir(cfg.image_path):
if image_name.split('.')[-1] in ['jpg', 'png']:
image_names.append(image_name)
for image_name in image_names:
infer_reader = reader.infer(input_size,
os.path.join(cfg.image_path, image_name))
label_names, _ = reader.get_label_infos()
data = next(infer_reader())
img_data = np.array([x[0] for x in data]).astype('float32')
img = to_variable(img_data)
im_shape_data = np.array([x[2] for x in data]).astype('int32')
im_shape = to_variable(im_shape_data)
outputs = model(img, None, None, None, None, im_shape)
bboxes = outputs.numpy()
if bboxes.shape[1] != 6:
print("No object found in {}".format(image_name))
continue
labels = bboxes[:, 0].astype('int32')
scores = bboxes[:, 1].astype('float32')
boxes = bboxes[:, 2:].astype('float32')
path = os.path.join(cfg.image_path, image_name)
box_utils.draw_boxes_on_image(path, boxes, scores, labels, label_names,
cfg.draw_thresh)
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
infer()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
from paddle.fluid.dygraph.base import to_variable
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=1,
groups=1,
padding=0,
act="leaky",
is_test=True):
super(ConvBNLayer, self).__init__()
self.conv = Conv2D(
num_channels=ch_in,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
bias_attr=False,
act=None)
self.batch_norm = BatchNorm(
num_channels=ch_out,
is_test=is_test,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02),
regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.)))
self.act = act
def forward(self, inputs):
out = self.conv(inputs)
out = self.batch_norm(out)
if self.act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
class DownSample(fluid.dygraph.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=2,
padding=1,
is_test=True):
super(DownSample, self).__init__()
self.conv_bn_layer = ConvBNLayer(
ch_in=ch_in,
ch_out=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
is_test=is_test)
self.ch_out = ch_out
def forward(self, inputs):
out = self.conv_bn_layer(inputs)
return out
class BasicBlock(fluid.dygraph.Layer):
def __init__(self, ch_in, ch_out, is_test=True):
super(BasicBlock, self).__init__()
self.conv1 = ConvBNLayer(
ch_in=ch_in,
ch_out=ch_out,
filter_size=1,
stride=1,
padding=0,
is_test=is_test
)
self.conv2 = ConvBNLayer(
ch_in=ch_out,
ch_out=ch_out*2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test
)
def forward(self, inputs):
conv1 = self.conv1(inputs)
conv2 = self.conv2(conv1)
out = fluid.layers.elementwise_add(x=inputs, y=conv2, act=None)
return out
class LayerWarp(fluid.dygraph.Layer):
def __init__(self, ch_in, ch_out, count, is_test=True):
super(LayerWarp,self).__init__()
self.basicblock0 = BasicBlock(ch_in,
ch_out,
is_test=is_test)
self.res_out_list = []
for i in range(1,count):
res_out = self.add_sublayer("basic_block_%d" % (i),
BasicBlock(
ch_out*2,
ch_out,
is_test=is_test))
self.res_out_list.append(res_out)
self.ch_out = ch_out
def forward(self,inputs):
y = self.basicblock0(inputs)
for basic_block_i in self.res_out_list:
y = basic_block_i(y)
return y
# number of residual blocks in each of the five stages of DarkNet-53
DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
class DarkNet53_conv_body(fluid.dygraph.Layer):
def __init__(self,
ch_in=3,
is_test=True):
super(DarkNet53_conv_body, self).__init__()
self.stages = DarkNet_cfg[53]
self.stages = self.stages[0:5]
self.conv0 = ConvBNLayer(
ch_in=ch_in,
ch_out=32,
filter_size=3,
stride=1,
padding=1,
is_test=is_test)
self.downsample0 = DownSample(
ch_in=32,
ch_out=32 * 2,
is_test=is_test)
self.darknet53_conv_block_list = []
self.downsample_list = []
ch_in = [64,128,256,512,1024]
for i, stage in enumerate(self.stages):
conv_block = self.add_sublayer(
"stage_%d" % (i),
LayerWarp(
int(ch_in[i]),
32*(2**i),
stage,
is_test=is_test))
self.darknet53_conv_block_list.append(conv_block)
for i in range(len(self.stages) - 1):
downsample = self.add_sublayer(
"stage_%d_downsample" % i,
DownSample(
ch_in = 32*(2**(i+1)),
ch_out = 32*(2**(i+2)),
is_test=is_test))
self.downsample_list.append(downsample)
def forward(self,inputs):
out = self.conv0(inputs)
out = self.downsample0(out)
blocks = []
for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
out = conv_block_i(out)
blocks.append(out)
if i < len(self.stages) - 1:
out = self.downsample_list[i](out)
return blocks[-1:-4:-1]  # deepest three stage outputs: strides 32, 16, 8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.fluid.initializer import Normal
from paddle.fluid.regularizer import L2Decay
from config import cfg
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm
from darknet import DarkNet53_conv_body
from darknet import ConvBNLayer
from paddle.fluid.dygraph.base import to_variable
class YoloDetectionBlock(fluid.dygraph.Layer):
def __init__(self,ch_in,channel,is_test=True):
super(YoloDetectionBlock, self).__init__()
assert channel % 2 == 0, \
"channel {} cannot be divided by 2".format(channel)
self.conv0 = ConvBNLayer(
ch_in=ch_in,
ch_out=channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test
)
self.conv1 = ConvBNLayer(
ch_in=channel,
ch_out=channel*2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test
)
self.conv2 = ConvBNLayer(
ch_in=channel*2,
ch_out=channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test
)
self.conv3 = ConvBNLayer(
ch_in=channel,
ch_out=channel*2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test
)
self.route = ConvBNLayer(
ch_in=channel*2,
ch_out=channel,
filter_size=1,
stride=1,
padding=0,
is_test=is_test
)
self.tip = ConvBNLayer(
ch_in=channel,
ch_out=channel*2,
filter_size=3,
stride=1,
padding=1,
is_test=is_test
)
def forward(self, inputs):
out = self.conv0(inputs)
out = self.conv1(out)
out = self.conv2(out)
out = self.conv3(out)
route = self.route(out)
tip = self.tip(route)
return route, tip
class Upsample(fluid.dygraph.Layer):
def __init__(self,scale=2):
super(Upsample,self).__init__()
self.scale = scale
def forward(self, inputs):
# get dynamic upsample output shape
shape_nchw = fluid.layers.shape(inputs)
shape_hw = fluid.layers.slice(shape_nchw, axes=[0], starts=[2], ends=[4])
shape_hw.stop_gradient = True
in_shape = fluid.layers.cast(shape_hw, dtype='int32')
out_shape = in_shape * self.scale
out_shape.stop_gradient = True
# resize by actual_shape
out = fluid.layers.resize_nearest(
input=inputs, scale=self.scale, actual_shape=out_shape)
return out
class YOLOv3(fluid.dygraph.Layer):
def __init__(self,ch_in,is_train=True, use_random=False):
super(YOLOv3,self).__init__()
self.is_train = is_train
self.use_random = use_random
self.block = DarkNet53_conv_body(ch_in=ch_in,
is_test = not self.is_train)
self.block_outputs = []
self.yolo_blocks = []
self.route_blocks_2 = []
ch_in_list = [1024,768,384]
for i in range(3):
yolo_block = self.add_sublayer(
"yolo_detecton_block_%d" % (i),
YoloDetectionBlock(ch_in_list[i],
channel = 512//(2**i),
is_test = not self.is_train))
self.yolo_blocks.append(yolo_block)
num_filters = len(cfg.anchor_masks[i]) * (cfg.class_num + 5)
block_out = self.add_sublayer(
"block_out_%d" % (i),
Conv2D(num_channels=1024//(2**i),
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.))))
self.block_outputs.append(block_out)
if i < 2:
route = self.add_sublayer("route2_%d"%i,
ConvBNLayer(ch_in=512//(2**i),
ch_out=256//(2**i),
filter_size=1,
stride=1,
padding=0,
is_test=(not self.is_train)))
self.route_blocks_2.append(route)
self.upsample = Upsample()
def forward(self, inputs, gtbox=None, gtlabel=None, gtscore=None, im_id=None, im_shape=None ):
self.outputs = []
self.boxes = []
self.scores = []
self.losses = []
self.downsample = 32
blocks = self.block(inputs)
for i, block in enumerate(blocks):
if i > 0:
block = fluid.layers.concat(input=[route, block], axis=1)
route, tip = self.yolo_blocks[i](block)
block_out = self.block_outputs[i](tip)
self.outputs.append(block_out)
if i < 2:
route = self.route_blocks_2[i](route)
route = self.upsample(route)
self.gtbox = gtbox
self.gtlabel = gtlabel
self.gtscore = gtscore
self.im_id = im_id
self.im_shape = im_shape
# cal loss
for i,out in enumerate(self.outputs):
anchor_mask = cfg.anchor_masks[i]
if self.is_train:
loss = fluid.layers.yolov3_loss(
x=out,
gt_box=self.gtbox,
gt_label=self.gtlabel,
gt_score=self.gtscore,
anchors=cfg.anchors,
anchor_mask=anchor_mask,
class_num=cfg.class_num,
ignore_thresh=cfg.ignore_thresh,
downsample_ratio=self.downsample,
use_label_smooth=cfg.label_smooth)
self.losses.append(fluid.layers.reduce_mean(loss))
else:
mask_anchors = []
for m in anchor_mask:
mask_anchors.append(cfg.anchors[2 * m])
mask_anchors.append(cfg.anchors[2 * m + 1])
boxes, scores = fluid.layers.yolo_box(
x=out,
img_size=self.im_shape,
anchors=mask_anchors,
class_num=cfg.class_num,
conf_thresh=cfg.valid_thresh,
downsample_ratio=self.downsample,
name="yolo_box" + str(i))
self.boxes.append(boxes)
self.scores.append(
fluid.layers.transpose(
scores, perm=[0, 2, 1]))
self.downsample //= 2
if not self.is_train:
# get pred
yolo_boxes = fluid.layers.concat(self.boxes, axis=1)
yolo_scores = fluid.layers.concat(self.scores, axis=2)
pred = fluid.layers.multiclass_nms(
bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=cfg.valid_thresh,
nms_top_k=cfg.nms_topk,
keep_top_k=cfg.nms_posk,
nms_threshold=cfg.nms_thresh,
background_label=-1)
return pred
else:
return sum(self.losses)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import os
import sys
import random
import time
import copy
import cv2
import box_utils
import image_utils
from pycocotools.coco import COCO
from data_utils import GeneratorEnqueuer
from config import cfg
import paddle.fluid as fluid
class DataSetReader(object):
"""A class for parsing and read COCO dataset"""
def __init__(self):
self.has_parsed_category = False
def _parse_dataset_dir(self, mode):
if 'coco2014' in cfg.dataset:
cfg.train_file_list = 'annotations/instances_train2014.json'
cfg.train_data_dir = 'train2014'
cfg.val_file_list = 'annotations/instances_val2014.json'
cfg.val_data_dir = 'val2014'
elif 'coco2017' in cfg.dataset:
cfg.train_file_list = 'annotations/instances_train2017.json'
cfg.train_data_dir = 'train2017'
cfg.val_file_list = 'annotations/instances_val2017.json'
cfg.val_data_dir = 'val2017'
else:
raise NotImplementedError('Dataset {} not supported'.format(
cfg.dataset))
if mode == 'train':
cfg.train_file_list = os.path.join(cfg.data_dir,
cfg.train_file_list)
cfg.train_data_dir = os.path.join(cfg.data_dir, cfg.train_data_dir)
self.COCO = COCO(cfg.train_file_list)
self.img_dir = cfg.train_data_dir
elif mode == 'test' or mode == 'infer':
cfg.val_file_list = os.path.join(cfg.data_dir, cfg.val_file_list)
cfg.val_data_dir = os.path.join(cfg.data_dir, cfg.val_data_dir)
self.COCO = COCO(cfg.val_file_list)
self.img_dir = cfg.val_data_dir
def _parse_dataset_category(self):
self.categories = self.COCO.loadCats(self.COCO.getCatIds())
self.num_category = len(self.categories)
self.label_names = []
self.label_ids = []
for category in self.categories:
self.label_names.append(category['name'])
self.label_ids.append(int(category['id']))
self.category_to_id_map = {v: i for i, v in enumerate(self.label_ids)}
print("Load in {} categories.".format(self.num_category))
if self.num_category != cfg.class_num:
raise ValueError("category number({}) in your dataset is not equal "
"to --class_num={} settting, which may incur errors in "
"eval/infer or cause precision loss.".format(
self.num_category, cfg.class_num))
self.has_parsed_category = True
def get_label_infos(self):
if not self.has_parsed_category:
self._parse_dataset_dir("test")
self._parse_dataset_category()
return (self.label_names, self.label_ids)
def _parse_gt_annotations(self, img):
img_height = img['height']
img_width = img['width']
anno = self.COCO.loadAnns(
self.COCO.getAnnIds(
imgIds=img['id'], iscrowd=None))
gt_index = 0
for target in anno:
if target['area'] < cfg.gt_min_area:
continue
if 'ignore' in target and target['ignore']:
continue
box = box_utils.coco_anno_box_to_center_relative(
target['bbox'], img_height, img_width)
if box[2] <= 0 or box[3] <= 0:
continue
img['gt_boxes'][gt_index] = box
img['gt_labels'][gt_index] = \
self.category_to_id_map[target['category_id']]
gt_index += 1
if gt_index >= cfg.max_box_num:
break
def _parse_images(self, is_train):
image_ids = self.COCO.getImgIds()
image_ids.sort()
imgs = copy.deepcopy(self.COCO.loadImgs(image_ids))
for img in imgs:
img['image'] = os.path.join(self.img_dir, img['file_name'])
assert os.path.exists(img['image']), \
"image {} not found.".format(img['image'])
box_num = cfg.max_box_num
img['gt_boxes'] = np.zeros((cfg.max_box_num, 4), dtype=np.float32)
img['gt_labels'] = np.zeros((cfg.max_box_num), dtype=np.int32)
for k in ['date_captured', 'url', 'license', 'file_name']:
if k in img:
del img[k]
if is_train:
self._parse_gt_annotations(img)
print("Loaded {0} images from {1}.".format(len(imgs), cfg.dataset))
return imgs
def _parse_images_by_mode(self, mode):
if mode == 'infer':
return []
else:
return self._parse_images(is_train=(mode == 'train'))
def get_reader(self,
mode,
size=416,
batch_size=None,
shuffle=False,
shuffle_seed=None,
mixup_iter=0,
random_sizes=[],
image=None):
assert mode in ['train', 'test', 'infer'], "Unknown mode type!"
if mode != 'infer':
assert batch_size is not None, \
"batch size connot be None in mode {}".format(mode)
self._parse_dataset_dir(mode)
self._parse_dataset_category()
def img_reader(img, size, mean, std):
im_path = img['image']
im = cv2.imread(im_path).astype('float32')
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
h, w, _ = im.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
out_img = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_CUBIC)
mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1))
out_img = (out_img / 255.0 - mean) / std
out_img = out_img.transpose((2, 0, 1))
return (out_img, int(img['id']), (h, w))
def img_reader_with_augment(img, size, mean, std, mixup_img):
im_path = img['image']
im = cv2.imread(im_path)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
gt_boxes = img['gt_boxes'].copy()
gt_labels = img['gt_labels'].copy()
gt_scores = np.ones_like(gt_labels)
if mixup_img:
mixup_im = cv2.imread(mixup_img['image'])
mixup_im = cv2.cvtColor(mixup_im, cv2.COLOR_BGR2RGB)
mixup_gt_boxes = np.array(mixup_img['gt_boxes']).copy()
mixup_gt_labels = np.array(mixup_img['gt_labels']).copy()
mixup_gt_scores = np.ones_like(mixup_gt_labels)
im, gt_boxes, gt_labels, gt_scores = \
image_utils.image_mixup(im, gt_boxes, gt_labels,
gt_scores, mixup_im, mixup_gt_boxes,
mixup_gt_labels, mixup_gt_scores)
im, gt_boxes, gt_labels, gt_scores = \
image_utils.image_augment(im, gt_boxes, gt_labels,
gt_scores, size, mean)
mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1))
out_img = (im / 255.0 - mean) / std
out_img = out_img.astype('float32').transpose((2, 0, 1))
return (out_img, gt_boxes, gt_labels, gt_scores)
def get_img_size(size, random_sizes=[]):
if len(random_sizes):
return np.random.choice(random_sizes)
return size
def get_mixup_img(imgs, mixup_iter, total_iter, read_cnt):
if total_iter >= mixup_iter:
return None
mixup_idx = np.random.randint(1, len(imgs))
mixup_img = imgs[(read_cnt + mixup_idx) % len(imgs)]
return mixup_img
def reader():
if mode == 'train':
imgs = self._parse_images_by_mode(mode)
if shuffle:
if shuffle_seed is not None:
np.random.seed(shuffle_seed)
np.random.shuffle(imgs)
read_cnt = 0
total_iter = 0
batch_out = []
img_size = get_img_size(size, random_sizes)
while True:
img = imgs[read_cnt % len(imgs)]
mixup_img = get_mixup_img(imgs, mixup_iter, total_iter,
read_cnt)
read_cnt += 1
if read_cnt % len(imgs) == 0 and shuffle:
np.random.shuffle(imgs)
im, gt_boxes, gt_labels, gt_scores = \
img_reader_with_augment(img, img_size, cfg.pixel_means,
cfg.pixel_stds, mixup_img)
batch_out.append([im, gt_boxes, gt_labels, gt_scores])
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
total_iter += 1
img_size = get_img_size(size, random_sizes)
elif mode == 'test':
imgs = self._parse_images_by_mode(mode)
batch_out = []
for img in imgs:
im, im_id, im_shape = img_reader(img, size, cfg.pixel_means,
cfg.pixel_stds)
batch_out.append((im, im_id, im_shape))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
if len(batch_out) != 0:
yield batch_out
else:
img = {}
img['image'] = image
img['id'] = 0
im, im_id, im_shape = img_reader(img, size, cfg.pixel_means,
cfg.pixel_stds)
batch_out = [(im, im_id, im_shape)]
yield batch_out
# NOTE: yolov3 is a special model, if num_trainers > 1, each process
# trains on the complete dataset.
# num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
# if mode == 'train' and num_trainers > 1:
# assert shuffle_seed is not None, \
# "If num_trainers > 1, the shuffle_seed must be set, because " \
# "the order of batch data generated by reader " \
# "must be the same in the respective processes."
# reader = fluid.contrib.reader.distributed_batch_reader(reader)
return reader
dsr = DataSetReader()
def train(size=416,
batch_size=64,
shuffle=True,
shuffle_seed=None,
total_iter=0,
mixup_iter=0,
random_sizes=[],
num_workers=8,
max_queue=32,
use_multiprocess_reader=True):
generator = dsr.get_reader('train', size, batch_size, shuffle, shuffle_seed,
int(mixup_iter / num_workers), random_sizes)
if not use_multiprocess_reader:
return generator
else:
if sys.platform == "win32":
print("multiprocess is not fully compatible with Windows, "
"you can set --use_multiprocess_reader=False if you "
"are training on Windows and there are errors incured "
"by multiprocess.")
print("multiprocess reader starting up, it takes a while...")
def infinite_reader():
while True:
for data in generator():
yield data
def reader():
cnt = 0
try:
enqueuer = GeneratorEnqueuer(
infinite_reader(), use_multiprocessing=use_multiprocess_reader)
enqueuer.start(max_queue_size=max_queue, workers=num_workers)
generator_out = None
while True:
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_out = enqueuer.queue.get()
break
else:
time.sleep(0.02)
yield generator_out
cnt += 1
if cnt >= total_iter:
enqueuer.stop()
return
generator_out = None
except Exception as e:
print("Exception occured in reader: {}".format(str(e)))
finally:
if enqueuer:
enqueuer.stop()
return reader
def test(size=416, batch_size=1):
return dsr.get_reader('test', size, batch_size)
def infer(size=416, image=None):
return dsr.get_reader('infer', size, image=image)
def get_label_infos():
return dsr.get_label_infos()
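# A minimal usage sketch of the entry points above (illustrative only, not
# part of the original module; the sizes and counts are example values):
#
#     import reader
#     train_loader = reader.train(size=608, batch_size=8,
#                                 random_sizes=[320, 416, 608],
#                                 use_multiprocess_reader=False)
#     for batch in train_loader():
#         im, gt_boxes, gt_labels, gt_scores = batch[0]
#         break  # the single-process reader loops forever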
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --started_port=9999 train.py --batch_size=16 --use_data_parallel=1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
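# NOTE: these FLAGS_* environment variables only take effect if they are set
# before paddle is imported, so this helper runs ahead of the paddle imports
# below.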
def set_paddle_flags(flags):
for key, value in flags.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
set_paddle_flags({
'FLAGS_eager_delete_tensor_gb': 0, # enable gc
'FLAGS_memory_fraction_of_eager_deletion': 1,
'FLAGS_fraction_of_gpu_memory_to_use': 0.98
})
import sys
import numpy as np
import random
import time
import shutil
import subprocess
from utility import (parse_args, print_arguments,
SmoothedValue, check_gpu)
import paddle
import paddle.fluid as fluid
import reader
from models.yolov3 import YOLOv3
from config import cfg
import dist_utils
from paddle.fluid.dygraph.base import to_variable
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
def get_device_num():
# NOTE(zcd): for multi-process training, each process uses one GPU card.
if num_trainers > 1:
return 1
return fluid.core.get_cuda_device_count()
def train():
# check whether use_gpu=True is set with the CPU version of PaddlePaddle
check_gpu(cfg.use_gpu)
devices_num = get_device_num() if cfg.use_gpu else 1
print("Found {} CUDA/CPU devices.".format(devices_num))
if cfg.debug or args.enable_ce:
fluid.default_startup_program().random_seed = 1000
fluid.default_main_program().random_seed = 1000
random.seed(0)
np.random.seed(0)
if not os.path.exists(cfg.model_save_dir):
os.makedirs(cfg.model_save_dir)
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
model = YOLOv3(3, is_train=True)
if args.use_data_parallel:
model = fluid.dygraph.parallel.DataParallel(model, strategy)
if cfg.pretrain:
restore, _ = fluid.load_dygraph(cfg.pretrain)
model.blocks.set_dict(restore)
if cfg.finetune:
restore, _ = fluid.load_dygraph(cfg.finetune)
model.set_dict(restore)
boundaries = cfg.lr_steps
gamma = cfg.lr_gamma
step_num = len(cfg.lr_steps)
learning_rate = cfg.learning_rate
values = [learning_rate * (gamma ** i) for i in range(step_num + 1)]
lr = fluid.dygraph.PiecewiseDecay(
boundaries=boundaries,
values=values,
begin=args.start_iter)
lr = fluid.layers.linear_lr_warmup(
learning_rate=lr,
warmup_steps=cfg.warm_up_iter,
start_lr=0.0,
end_lr=cfg.learning_rate,
)
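# Illustrative example of the schedule above (assumed values, not
# necessarily the repo defaults): with learning_rate=0.001,
# lr_steps=[400000, 450000] and lr_gamma=0.1, values becomes
# [0.001, 0.0001, 0.00001], i.e. a 10x drop at each boundary, while the
# warmup wrapper ramps the rate linearly from 0.0 to 0.001 over the first
# warm_up_iter steps.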
optimizer = fluid.optimizer.Momentum(
learning_rate=lr,
regularization=fluid.regularizer.L2Decay(cfg.weight_decay),
momentum=cfg.momentum,
parameter_list=model.parameters()
)
start_time = time.time()
snapshot_loss = 0
snapshot_time = 0
total_sample = 0
input_size = cfg.input_size
shuffle = True
shuffle_seed = None
total_iter = cfg.max_iter - cfg.start_iter
mixup_iter = total_iter - cfg.no_mixup_iter
random_sizes = [cfg.input_size]
if cfg.random_shape:
random_sizes = [32 * i for i in range(10, 20)]
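# i.e. 320, 352, ..., 608: multiples of 32, matching the network's 32x
# downsampling stride.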
train_reader = reader.train(
input_size,
batch_size=cfg.batch_size,
shuffle=shuffle,
shuffle_seed=shuffle_seed,
total_iter=total_iter * devices_num,
mixup_iter=mixup_iter * devices_num,
random_sizes=random_sizes,
use_multiprocess_reader=cfg.use_multiprocess_reader,
num_workers=cfg.worker_num)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(train_reader)
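# distributed_batch_reader splits each batch across the trainers so every
# process consumes its own slice, which is why all processes must generate
# batches in the same order (see the shuffle_seed note in reader.py).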
smoothed_loss = SmoothedValue()
for iter_id, data in enumerate(train_reader()):
prev_start_time = start_time
start_time = time.time()
img = np.array([x[0] for x in data]).astype('float32')
img = to_variable(img)
gt_box = np.array([x[1] for x in data]).astype('float32')
gt_box = to_variable(gt_box)
gt_label = np.array([x[2] for x in data]).astype('int32')
gt_label = to_variable(gt_label)
gt_score = np.array([x[3] for x in data]).astype('float32')
gt_score = to_variable(gt_score)
loss = model(img, gt_box, gt_label, gt_score, None, None)
smoothed_loss.add_value(np.mean(loss.numpy()))
snapshot_loss += loss.numpy()
snapshot_time += start_time - prev_start_time
total_sample += 1
print("Iter {:d}, loss {:.6f}, time {:.5f}".format(
iter_id,
smoothed_loss.get_mean_value(),
start_time-prev_start_time))
if args.use_data_parallel:
loss = model.scale_loss(loss)
loss.backward()
model.apply_collective_grads()
else:
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters and iter_id > 1 and iter_id % cfg.snapshot_iter == 0:
fluid.save_dygraph(model.state_dict(), args.model_save_dir + "/yolov3_{}".format(iter_id))
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
train()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains common utility functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import distutils.util
import numpy as np
import six
import ast
from collections import deque
import paddle.fluid as fluid
import argparse
import functools
from config import *
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self):
self.loss_sum = 0.0
self.iter_cnt = 0
def add_value(self, value):
self.loss_sum += np.mean(value)
self.iter_cnt += 1
def get_mean_value(self):
return self.loss_sum / self.iter_cnt
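# Example: the running mean over all added values:
#     s = SmoothedValue()
#     s.add_value(2.0)
#     s.add_value(4.0)
#     s.get_mean_value()  # -> 3.0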
def check_gpu(use_gpu):
"""
Log an error and exit when use_gpu=True is set with the CPU
version of PaddlePaddle.
"""
err = "Config use_gpu cannot be set as True while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set --use_gpu=False to run model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
print(err)
sys.exit(1)
except Exception:
# Older PaddlePaddle versions may not provide is_compiled_with_cuda();
# skip the check in that case.
pass
def parse_args():
"""return all args
"""
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
# ENV
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('model_save_dir', str, 'checkpoints', "The path to save model.")
add_arg('pretrain', str, 'weights/darknet53', "The pretrain model path.")
add_arg('finetune', str, False, "The finetune model path.")
add_arg('weights', str, 'weights/yolov3', "The weights path.")
add_arg('dataset', str, 'coco2017', "Dataset: coco2014, coco2017.")
add_arg('class_num', int, 80, "Class number.")
add_arg('data_dir', str, 'dataset/coco', "The data root path.")
add_arg('start_iter', int, 0, "Start iteration.")
add_arg('use_multiprocess_reader', bool, True, "Whether to use the multiprocess reader.")
add_arg('worker_num', int, 8, "Worker number for the multiprocess reader.")
add_arg('use_data_parallel', ast.literal_eval, False, "Whether to train the model with data parallelism.")
#SOLVER
add_arg('batch_size', int, 8, "Mini-batch size per device.")
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('max_iter', int, 500200, "Iter number.")
add_arg('snapshot_iter', int, 2000, "Save the model every snapshot_iter iterations.")
add_arg('label_smooth', bool, True, "Use label smooth in class label.")
add_arg('no_mixup_iter', int, 40000, "Disable mixup in last N iter.")
# TRAIN TEST INFER
add_arg('input_size', int, 608, "Image input size of YOLOv3.")
add_arg('syncbn', bool, True, "Whether to use synchronized batch normalization.")
add_arg('random_shape', bool, True, "Resize to random shape for train reader.")
add_arg('valid_thresh', float, 0.005, "Valid confidence score for NMS.")
add_arg('nms_thresh', float, 0.45, "NMS threshold.")
add_arg('nms_topk', int, 400, "The number of boxes to perform NMS.")
add_arg('nms_posk', int, 100, "The number of boxes of NMS output.")
add_arg('debug', bool, False, "Debug mode.")
# SINGLE EVAL AND DRAW
add_arg('image_path', str, 'image',
"The image path used to inference and visualize.")
add_arg('image_name', str, None,
"The single image used for inference and visualization. None to infer all images under image_path.")
add_arg('draw_thresh', float, 0.5,
"Confidence score threshold for drawing prediction boxes in debug mode.")
add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.")
# yapf: enable
args = parser.parse_args()
merge_cfg_from_args(args)
return args
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"
# Download the pretrain weights.
echo "Downloading..."
wget https://paddlemodels.bj.bcebos.com/yolo/darknet53.pdparams