Unverified commit 4dd4632b, authored by Kaipeng Deng, committed by GitHub

Merge pull request #1871 from heavengate/yolo_model

Yolo model
*.log
*.json
*.jpg
*.png
output/
checkpoints/
weights/
!weights/*.sh
dataset/coco/
log*
output*
# YOLO V3 Object Detection
---
## Table of Contents
- [Installation](#installation)
- [Introduction](#introduction)
- [Data preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [Inference and Visualization](#inference-and-visualization)
- [Appendix](#appendix)
## Installation
Running sample code in this directory requires PaddlePaddle Fluid v1.4 or later. If the PaddlePaddle on your device is lower than this version, please follow the instructions in the [installation document](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/install_doc.html#paddlepaddle) and make an update.
## Introduction
[YOLOv3](https://arxiv.org/abs/1804.02767) is a one-stage, end-to-end detector. The detection principle of YOLOv3 is as follows:
<p align="center">
<img src="image/YOLOv3.jpg" height=400 width=600 hspace='10'/> <br />
YOLOv3 detection principle
</p>
YOLOv3 divides the input image into S\*S grid cells and predicts B bounding boxes in each cell; each box prediction includes its location (x, y, w, h), a confidence score and the probabilities of C classes, so the YOLOv3 output layers predict S\*S\*B\*(5 + C) values in total. The YOLOv3 loss consists of three parts: location loss, confidence loss and classification loss.
The backbone network of YOLOv3 is DarkNet53; the structure of YOLOv3 is as follows:
<p align="center">
<img src="image/YOLOv3_structure.jpg" height=400 width=400 hspace='10'/> <br />
YOLOv3 structure
</p>
YOLOv3 networks are composed of base feature extraction network, multi-scale feature fusion layers, and output layers.
1. Feature extraction network: YOLOv3 uses [DarkNet53](https://arxiv.org/abs/1612.08242) for feature extraction. DarkNet53 uses a fully convolutional structure, replacing pooling layers with stride-2 convolutions and adding residual blocks to avoid vanishing gradients when the network gets very deep.
2. Feature fusion layer: to address the insensitivity of earlier YOLO versions to small objects, YOLOv3 detects on feature maps of three different scales, 13\*13, 26\*26 and 52\*52, for detecting large, medium and small objects respectively. The feature fusion layer takes the three scales of feature maps produced by DarkNet as input and, following the idea of FPN (feature pyramid networks), fuses the feature maps across scales through a series of convolutional layers and upsampling.
3. Output layer: the output layer also uses a fully convolutional structure. The last convolutional layer has 255 kernels: 3\*(80+4+1)=255, where 3 indicates that each grid cell contains 3 bounding boxes, 4 represents the four box coordinates, 1 represents the confidence score, and 80 represents the probabilities of the 80 categories in the COCO dataset (a sanity check of this count is sketched below).
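A quick way to verify the channel count (a minimal Python sketch; the variable names are illustrative):

```python
num_classes = 80      # COCO categories
boxes_per_cell = 3    # bounding boxes predicted per grid cell
channels = boxes_per_cell * (4 + 1 + num_classes)  # (x, y, w, h) + confidence + class probs
assert channels == 255
```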
## Data preparation
Train the model on the [MS-COCO dataset](http://cocodataset.org/#download); download the dataset as follows:

    cd dataset/coco
    ./download.sh
## Training
After data preparation, one can start training as follows:

    python train.py \
        --model_save_dir=output/ \
        --pretrain=${path_to_pretrain_model} \
        --data_dir=${path_to_data}
- Set ```export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7``` to specify 8 GPUs for training.
- For more help on arguments:

        python train.py --help
**Download the pre-trained model:** This sample provides a darknet53 pre-trained model, converted from the authors' darknet53 weights pre-trained on ImageNet. One can download the pre-trained model with:

    sh ./weights/download.sh
Set `pretrain` to load the pre-trained model. The same option is also used to load a trained model when fine-tuning.
Please make sure the pre-trained model is downloaded and loaded correctly, otherwise the loss may be NaN during training.
**Install the [cocoapi](https://github.com/cocodataset/cocoapi):**
To train the model, [cocoapi](https://github.com/cocodataset/cocoapi) is needed. Install the cocoapi as follows:

    git clone https://github.com/cocodataset/cocoapi.git
    cd cocoapi/PythonAPI
    # if cython is not installed
    pip install Cython
    # Install into global site-packages
    make install
    # Alternatively, if you do not have permissions or prefer
    # not to install the COCO API into global site-packages
    python2 setup.py install --user
**Data reader introduction:**
* The data reader is defined in `reader.py`.
**Model configuration:**
* The model uses 9 anchors generated from the COCO dataset: 10x13, 16x30, 33x23, 30x61, 62x45, 59x119, 116x90, 156x198, 373x326. They are assigned to the three detection scales as shown in the sketch after this list.
* In NMS: nms_thresh=0.45, valid_thresh=0.005, nms_topk=400, nms_posk=100.
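The 9 anchors are assigned to the three detection scales via `anchor_masks` in `config.py`; a minimal sketch of the grouping (the largest anchors go to the coarsest 13\*13 output):

```python
anchors = [(10, 13), (16, 30), (33, 23), (30, 61), (62, 45),
           (59, 119), (116, 90), (156, 198), (373, 326)]
anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]  # 13x13, 26x26, 52x52 outputs
for mask in anchor_masks:
    print([anchors[i] for i in mask])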
**Training strategy:**
* Use the momentum optimizer with momentum=0.9.
* In the first 4000 iterations, the learning rate increases linearly from 0.0 to 0.01. The learning rate is then decayed at iterations 400000 and 450000 with multipliers 0.1 and 0.01; the maximum number of iterations is 500000 (a sketch of this schedule follows the list).
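The schedule can be expressed as a pure function (a minimal sketch; the defaults mirror `learning_rate`, `warm_up_iter`, `lr_steps` and `lr_gamma` in `config.py`):

```python
def lr_at(iteration, base_lr=0.001, warmup_iter=4000, steps=(400000, 450000), gamma=0.1):
    """Linear warmup from 0 to base_lr, then decay by gamma at each boundary."""
    if iteration < warmup_iter:
        return base_lr * iteration / warmup_iter
    lr = base_lr
    for boundary in steps:
        if iteration >= boundary:
            lr *= gamma
    return lr
```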
The training results are shown below:
<p align="center">
<img src="image/train_loss.png" height="500" width="650" hspace="10"/><br />
Train Loss
</p>
## Evaluation
Evaluation measures the performance of a trained model. This sample provides `eval.py`, which uses a COCO-specific mAP metric defined by the [COCO committee](http://cocodataset.org/#detections-eval).
`eval.py` is the main executor for evaluation; one can start evaluation as follows:

    python eval.py \
        --dataset=coco2017 \
        --weights=${path_to_weights}
- Set ```export CUDA_VISIBLE_DEVICES=0``` to specify one GPU for evaluation.
Evaluation results are shown below:
| input size | mAP(IoU=0.50:0.95) | mAP(IoU=0.50) | mAP(IoU=0.75) |
| :------: | :------: | :------: | :------: |
| 608x608| 37.7 | 59.8 | 40.8 |
| 416x416 | 36.5 | 58.2 | 39.1 |
| 320x320 | 34.1 | 55.4 | 36.3 |
## Inference and Visualization
Inference is used to get prediction scores or image features from trained models. `infer.py` is the main executor for inference; one can start inference as follows:

    python infer.py \
        --dataset=coco2017 \
        --weights=${path_to_weights} \
        --image_path=data/COCO17/val2017/ \
        --image_name=000000000139.jpg \
        --draw_threshold=0.5
Inference speed:
| input size | 608x608 | 416x416 | 320x320 |
|:-------------:| :-----: | :-----: | :-----: |
| infer speed | 50 ms/frame | 29 ms/frame |24 ms/frame |
Visualized inference results are shown below:
<p align="center">
<img src="image/000000000139.png" height=300 width=400 hspace='10'/>
<img src="image/000000127517.png" height=300 width=400 hspace='10'/>
<img src="image/000000203864.png" height=300 width=400 hspace='10'/>
<img src="image/000000515077.png" height=300 width=400 hspace='10'/> <br />
YOLOv3 Visualization Examples
</p>
# YOLO V3 Object Detection
---
## Table of Contents
- [Installation](#installation)
- [Introduction](#introduction)
- [Data preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [Inference and Visualization](#inference-and-visualization)
- [Appendix](#appendix)
## Installation
Running the sample code in this directory requires PaddlePaddle Fluid v1.4 or later. If the PaddlePaddle version in your environment is lower than this, please follow the instructions in the [installation document](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/install_doc.html#paddlepaddle) to update PaddlePaddle.
## Introduction
[YOLOv3](https://arxiv.org/abs/1804.02767) is a one-stage, end-to-end object detector. Its detection principle is shown below:
<p align="center">
<img src="image/YOLOv3.jpg" height=400 width=600 hspace='10'/> <br />
YOLOv3 detection principle
</p>
YOLOv3 divides the input image into S\*S grid cells, and each cell predicts B bounding boxes. Each bounding box prediction consists of its location (x, y, w, h), a confidence score, and the probabilities of C classes, so the YOLOv3 output layers predict S\*S\*B\*(5 + C) values in total. The YOLOv3 loss function consists of three parts: location loss, confidence loss and classification loss.
The YOLOv3 network structure is shown below:
<p align="center">
<img src="image/YOLOv3_structure.jpg" height=400 width=400 hspace='10'/> <br />
YOLOv3 network structure
</p>
The YOLOv3 network consists of a base feature extraction network, multi-scale feature fusion layers, and output layers.
1. Feature extraction network. YOLOv3 uses [DarkNet53](https://arxiv.org/abs/1612.08242) as its feature extraction network. DarkNet53 is essentially fully convolutional, replacing pooling layers with stride-2 convolutions and adding residual units to avoid vanishing gradients when the network becomes very deep.
2. Feature fusion layers. To address the insensitivity of earlier YOLO versions to small objects, YOLOv3 detects on feature maps of three scales, 13\*13, 26\*26 and 52\*52, for large, medium and small objects respectively. The feature fusion layers take the three scales of feature maps produced by DarkNet as input and, following the idea of FPN (feature pyramid networks), fuse the feature maps across scales through a series of convolutional layers and upsampling.
3. Output layers. These also use a fully convolutional structure; the last convolutional layer has 255 kernels: 3\*(80+4+1)=255, where 3 means each grid cell contains 3 bounding boxes, 4 is the four box coordinates, 1 is the confidence score, and 80 is the number of categories in the COCO dataset.
## Data preparation
Train on the [MS-COCO dataset](http://cocodataset.org/#download); download the dataset as follows:

    cd dataset/coco
    ./download.sh

## Training
After data preparation, training can be started as follows:

    python train.py \
        --model_save_dir=output/ \
        --pretrain=${path_to_pretrain_model} \
        --data_dir=${path_to_data}

- Set export CUDA\_VISIBLE\_DEVICES=0,1,2,3,4,5,6,7 to train with 8 GPUs.
- For the full list of optional arguments:

        python train.py --help

**Download the pre-trained model:** This example provides a darknet53 pre-trained model, converted from the authors' darknet53 weights pre-trained on ImageNet. Download it with:

    sh ./weights/download_pretrained_weight.sh

Set `pretrain` to load the pre-trained model. The same option is used to load a trained model for fine-tuning.
Please make sure the pre-trained model is downloaded and loaded correctly before training; otherwise, the loss may become NaN during training.
**Install [cocoapi](https://github.com/cocodataset/cocoapi):**
[cocoapi](https://github.com/cocodataset/cocoapi) must be installed before training:

    git clone https://github.com/cocodataset/cocoapi.git
    cd cocoapi/PythonAPI
    # if cython is not installed
    pip install Cython
    # Install into global site-packages
    make install
    # Alternatively, if you do not have permissions or prefer
    # not to install the COCO API into global site-packages
    python2 setup.py install --user

**Data reader:**
* The data reader is defined in reader.py.
**Model configuration:**
* The model uses 9 anchors generated from the COCO dataset: 10x13, 16x30, 33x23, 30x61, 62x45, 59x119, 116x90, 156x198, 373x326.
* During inference: nms_topk=400, nms_posk=100, nms_thresh=0.45.
**Training strategy:**
* YOLOv3 is trained with the momentum optimizer, momentum=0.9.
* A warmup schedule is used: in the first 4000 iterations, the learning rate increases linearly from 0.0 to 0.01. It is then decayed at iterations 400000 and 450000 with multipliers 0.1 and 0.01; training runs for at most 500000 iterations.
The training results are shown below:
<p align="center">
<img src="image/train_loss.png" height="400" width="550" hspace="10"/><br />
Train Loss
</p>
## Evaluation
Evaluation measures the performance metrics of a trained model. This example uses the [COCO official evaluation](http://cocodataset.org/#detections-eval).
`eval.py` is the main executor for evaluation; an example invocation:

    python eval.py \
        --dataset=coco2017 \
        --weights=${path_to_weights}

- Set export CUDA\_VISIBLE\_DEVICES=0 to evaluate on a single GPU.
Evaluation results:
| input size | mAP(IoU=0.50:0.95) | mAP(IoU=0.50) | mAP(IoU=0.75) |
| :------: | :------: | :------: | :------: |
| 608x608| 37.7 | 59.8 | 40.8 |
| 416x416 | 36.5 | 58.2 | 39.1 |
| 320x320 | 34.1 | 55.4 | 36.3 |
## Inference and Visualization
Inference obtains the objects in an image and their categories. `infer.py` is the main executor; an example invocation:

    python infer.py \
        --dataset=coco2017 \
        --weights=${path_to_weights} \
        --image_path=data/COCO17/val2017/ \
        --image_name=000000000139.jpg \
        --draw_threshold=0.5

Inference speed:
| input size | 608x608 | 416x416 | 320x320 |
|:-------------:| :-----: | :-----: | :-----: |
| infer speed | 50 ms/frame | 29 ms/frame |24 ms/frame |
Visualized inference results are shown below:
<p align="center">
<img src="image/000000000139.png" height=300 width=400 hspace='10'/>
<img src="image/000000127517.png" height=300 width=400 hspace='10'/>
<img src="image/000000203864.png" height=300 width=400 hspace='10'/>
<img src="image/000000515077.png" height=300 width=400 hspace='10'/> <br />
YOLOv3 visualization examples
</p>
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from PIL import Image
def coco_anno_box_to_center_relative(box, img_height, img_width):
"""
Convert COCO annotations box with format [x1, y1, w, h] to
center mode [center_x, center_y, w, h] and divide image width
and height to get relative value in range[0, 1]
"""
assert len(box) == 4, "box should be a len(4) list or tuple"
x, y, w, h = box
x1 = max(x, 0)
x2 = min(x + w - 1, img_width - 1)
y1 = max(y, 0)
y2 = min(y + h - 1, img_height - 1)
x = (x1 + x2) / 2 / img_width
y = (y1 + y2) / 2 / img_height
w = (x2 - x1) / img_width
h = (y2 - y1) / img_height
return np.array([x, y, w, h])
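# Example (hedged): a COCO [x1, y1, w, h] box in a 100x100 image becomes a
# relative center-mode box; the values below are illustrative only.
#   >>> coco_anno_box_to_center_relative([0, 0, 50, 50], 100, 100)
#   array([0.245, 0.245, 0.49 , 0.49 ])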
def clip_relative_box_in_image(x, y, w, h):
    """Clip relative box coordinates x, y, w, h to [0, 1]"""
    x1 = max(x - w / 2, 0.)
    x2 = min(x + w / 2, 1.)
    y1 = max(y - h / 2, 0.)
    y2 = min(y + h / 2, 1.)
    x = (x1 + x2) / 2
    y = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    return x, y, w, h
def box_xywh_to_xyxy(box):
shape = box.shape
assert shape[-1] == 4, "Box shape[-1] should be 4."
box = box.reshape((-1, 4))
box[:, 0], box[:, 2] = box[:, 0] - box[:, 2] / 2, box[:, 0] + box[:, 2] / 2
box[:, 1], box[:, 3] = box[:, 1] - box[:, 3] / 2, box[:, 1] + box[:, 3] / 2
box = box.reshape(shape)
return box
def box_iou_xywh(box1, box2):
assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
inter_x1 = np.maximum(b1_x1, b2_x1)
inter_x2 = np.minimum(b1_x2, b2_x2)
inter_y1 = np.maximum(b1_y1, b2_y1)
inter_y2 = np.minimum(b1_y2, b2_y2)
inter_w = inter_x2 - inter_x1 + 1
inter_h = inter_y2 - inter_y1 + 1
inter_w[inter_w < 0] = 0
inter_h[inter_h < 0] = 0
inter_area = inter_w * inter_h
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
return inter_area / (b1_area + b2_area - inter_area)
def box_iou_xyxy(box1, box2):
assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
inter_x1 = np.maximum(b1_x1, b2_x1)
inter_x2 = np.minimum(b1_x2, b2_x2)
inter_y1 = np.maximum(b1_y1, b2_y1)
inter_y2 = np.minimum(b1_y2, b2_y2)
inter_w = inter_x2 - inter_x1
inter_h = inter_y2 - inter_y1
inter_w[inter_w < 0] = 0
inter_h[inter_h < 0] = 0
inter_area = inter_w * inter_h
b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
return inter_area / (b1_area + b2_area - inter_area)
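# Example (hedged): IoU of two partially overlapping boxes in xyxy format.
#   >>> b1 = np.array([[0., 0., 10., 10.]])
#   >>> b2 = np.array([[5., 5., 15., 15.]])
#   >>> box_iou_xyxy(b1, b2)  # 25 / (100 + 100 - 25) ~= 0.143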
def box_crop(boxes, labels, scores, crop, img_shape):
    """Crop relative center-mode boxes to a crop window given in pixels.

    Boxes are converted to absolute xyxy, clipped to the window, and converted
    back to relative center-mode xywh w.r.t. the window; boxes whose centers
    fall outside the window are zeroed out. Returns the cropped boxes, labels,
    scores, and the number of surviving boxes.
    """
    x, y, w, h = map(float, crop)
    im_w, im_h = map(float, img_shape)
boxes = boxes.copy()
boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (boxes[:, 0] + boxes[:, 2] / 2) * im_w
boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (boxes[:, 1] + boxes[:, 3] / 2) * im_h
crop_box = np.array([x, y, x + w, y + h])
centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(axis=1)
boxes[:, :2] = np.maximum(boxes[:, :2], crop_box[:2])
boxes[:, 2:] = np.minimum(boxes[:, 2:], crop_box[2:])
boxes[:, :2] -= crop_box[:2]
boxes[:, 2:] -= crop_box[:2]
mask = np.logical_and(mask, (boxes[:, :2] < boxes[:, 2:]).all(axis=1))
boxes = boxes * np.expand_dims(mask.astype('float32'), axis=1)
labels = labels * mask.astype('float32')
scores = scores * mask.astype('float32')
boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (boxes[:, 2] - boxes[:, 0]) / w
boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (boxes[:, 3] - boxes[:, 1]) / h
return boxes, labels, scores, mask.sum()
def draw_boxes_on_image(image_path, boxes, scores, labels, label_names, score_thresh=0.5):
image = np.array(Image.open(image_path))
plt.figure()
_, ax = plt.subplots(1)
ax.imshow(image)
image_name = image_path.split('/')[-1]
print("Image {} detect: ".format(image_name))
colors = {}
for box, score, label in zip(boxes, scores, labels):
if score < score_thresh:
continue
if box[2] <= box[0] or box[3] <= box[1]:
continue
label = int(label)
if label not in colors:
colors[label] = plt.get_cmap('hsv')(label / len(label_names))
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
fill=False, linewidth=2.0,
edgecolor=colors[label])
ax.add_patch(rect)
ax.text(x1, y1, '{} {:.4f}'.format(label_names[label], score),
verticalalignment='bottom', horizontalalignment='left',
bbox={'facecolor': colors[label], 'alpha': 0.5, 'pad': 0},
fontsize=8, color='white')
print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(label)], map(int, list(box)), score))
image_name = image_name.replace('jpg', 'png')
plt.axis('off')
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
print("Detect result save at ./output/{}\n".format(image_name))
plt.cla()
plt.close('all')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from edict import AttrDict
import six
import numpy as np
_C = AttrDict()
cfg = _C
#
# Training options
#
# Snapshot period
_C.snapshot_iter = 2000
# min valid area for gt boxes
_C.gt_min_area = -1
# max target box number in an image
_C.max_box_num = 50
#
# Inference options
#
# valid score threshold to include boxes
_C.valid_thresh = 0.005
# threshold value for box non-max suppression
_C.nms_thresh = 0.45
# the number of top k boxes to perform nms
_C.nms_topk = 400
# the number of output boxes after nms
_C.nms_posk = 100
# score threshold for draw box in debug mode
_C.draw_thresh = 0.5
#
# Model options
#
# pixel mean values
_C.pixel_means = [0.485, 0.456, 0.406]
# pixel std values
_C.pixel_stds = [0.229, 0.224, 0.225]
# anchor box widths and heights
_C.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
# anchor mask of each yolo layer
_C.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
# IoU threshold to ignore objectness loss of pred box
_C.ignore_thresh = .7
#
# SOLVER options
#
# batch size
_C.batch_size = 8
# base learning rate used to derive the final learning rate
_C.learning_rate = 0.001
# maximum number of iterations
_C.max_iter = 500200
# number of warmup iterations for the learning rate
_C.warm_up_iter = 4000
_C.warm_up_factor = 0.
# lr steps_with_decay
_C.lr_steps = [400000, 450000]
_C.lr_gamma = 0.1
# L2 regularization hyperparameter
_C.weight_decay = 0.0005
# momentum with SGD
_C.momentum = 0.9
#
# ENV options
#
# support both CPU and GPU
_C.use_gpu = True
# Class number
_C.class_num = 80
# dataset path
_C.train_file_list = 'annotations/instances_train2017.json'
_C.train_data_dir = 'train2017'
_C.val_file_list = 'annotations/instances_val2017.json'
_C.val_data_dir = 'val2017'
def merge_cfg_from_args(args):
"""Merge config keys, values in args into the global config."""
for k, v in sorted(six.iteritems(vars(args))):
try:
value = eval(v)
except:
value = v
_C[k] = value
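# Example (hedged sketch): merging parsed command-line arguments into cfg.
# The `--batch_size` flag below is illustrative; the real flags are defined
# in utility.py.
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--batch_size', default='8')
#   merge_cfg_from_args(parser.parse_args([]))
#   assert cfg.batch_size == 8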
"""
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""
import time
import numpy as np
import threading
import multiprocessing
try:
import queue
except ImportError:
import Queue as queue
class GeneratorEnqueuer(object):
"""
Builds a queue out of a data generator.
Args:
generator: a generator function which endlessly yields data
use_multiprocessing (bool): use multiprocessing if True,
otherwise use threading.
wait_time (float): time to sleep in-between calls to `put()`.
random_seed (int): Initial seed for workers,
will be incremented by one for each workers.
"""
def __init__(self,
generator,
use_multiprocessing=False,
wait_time=0.05,
random_seed=None):
self.wait_time = wait_time
self._generator = generator
self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self.queue = None
self._manager = None
self.seed = random_seed
def start(self, workers=1, max_queue_size=10):
"""
Start worker threads which add data from the generator into the queue.
Args:
workers (int): number of worker threads
max_queue_size (int): queue size
(when full, threads could block on `put()`)
"""
def data_generator_task():
"""
Data generator task.
"""
def task():
if (self.queue is not None and
self.queue.qsize() < max_queue_size):
generator_output = next(self._generator)
self.queue.put((generator_output))
else:
time.sleep(self.wait_time)
if not self._use_multiprocessing:
while not self._stop_event.is_set():
with self.genlock:
try:
task()
except Exception:
self._stop_event.set()
break
else:
while not self._stop_event.is_set():
try:
task()
except Exception:
self._stop_event.set()
break
try:
if self._use_multiprocessing:
self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
self.genlock = threading.Lock()
self.queue = queue.Queue()
self._stop_event = threading.Event()
for _ in range(workers):
if self._use_multiprocessing:
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.seed)
thread = multiprocessing.Process(target=data_generator_task)
thread.daemon = True
if self.seed is not None:
self.seed += 1
else:
thread = threading.Thread(target=data_generator_task)
self._threads.append(thread)
thread.start()
except:
self.stop()
raise
def is_running(self):
"""
Returns:
bool: Whether the worker theads are running.
"""
return self._stop_event is not None and not self._stop_event.is_set()
def stop(self, timeout=None):
"""
Stops running threads and wait for them to exit, if necessary.
Should be called by the same thread which called `start()`.
Args:
timeout(int|None): maximum time to wait on `thread.join()`.
"""
if self.is_running():
self._stop_event.set()
for thread in self._threads:
if self._use_multiprocessing:
if thread.is_alive():
thread.terminate()
else:
thread.join(timeout)
if self._manager:
self._manager.shutdown()
self._threads = []
self._stop_event = None
self.queue = None
def get(self):
"""
Creates a generator to extract data from the queue.
Skip the data if it is `None`.
# Yields
tuple of data in the queue.
"""
while self.is_running():
if not self.queue.empty():
inputs = self.queue.get()
if inputs is not None:
yield inputs
else:
time.sleep(self.wait_time)
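# Example (hedged sketch): pulling batches through worker threads.
# `my_generator` and `process` are hypothetical.
#   enqueuer = GeneratorEnqueuer(my_generator(), use_multiprocessing=False)
#   enqueuer.start(workers=2, max_queue_size=8)
#   for batch in enqueuer.get():
#       process(batch)
#   enqueuer.stop()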
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"
# Download the data.
echo "Downloading..."
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
# Extract the data.
echo "Extracting..."
unzip train2014.zip
unzip val2014.zip
unzip train2017.zip
unzip val2017.zip
unzip annotations_trainval2014.zip
unzip annotations_trainval2017.zip
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
elif name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name, value):
if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
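# Example (hedged): AttrDict allows both attribute and item access.
#   cfg = AttrDict()
#   cfg.batch_size = 8
#   assert cfg['batch_size'] == 8 and cfg.batch_size == 8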
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import json
import numpy as np
import paddle
import paddle.fluid as fluid
import reader
from models.yolov3 import YOLOv3
from utility import print_arguments, parse_args
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config import cfg
def eval():
if '2014' in cfg.dataset:
test_list = 'annotations/instances_val2014.json'
elif '2017' in cfg.dataset:
test_list = 'annotations/instances_val2017.json'
if cfg.debug:
if not os.path.exists('output'):
os.mkdir('output')
model = YOLOv3(is_train=False)
model.build_model()
outputs = model.get_pred()
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
if cfg.weights:
def if_exist(var):
return os.path.exists(os.path.join(cfg.weights, var.name))
fluid.io.load_vars(exe, cfg.weights, predicate=if_exist)
# yapf: enable
input_size = cfg.input_size
test_reader = reader.test(input_size, 1)
label_names, label_ids = reader.get_label_infos()
if cfg.debug:
print("Load in labels {} with ids {}".format(label_names, label_ids))
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
def get_pred_result(boxes, scores, labels, im_id):
result = []
for box, score, label in zip(boxes, scores, labels):
x1, y1, x2, y2 = box
w = x2 - x1 + 1
h = y2 - y1 + 1
bbox = [x1, y1, w, h]
res = {
'image_id': im_id,
'category_id': label_ids[int(label)],
                'bbox': list(map(float, bbox)),
'score': float(score)
}
result.append(res)
return result
dts_res = []
fetch_list = [outputs]
total_time = 0
for batch_id, batch_data in enumerate(test_reader()):
start_time = time.time()
batch_outputs = exe.run(
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(batch_data),
return_numpy=False,
use_program_cache=True)
lod = batch_outputs[0].lod()[0]
nmsed_boxes = np.array(batch_outputs[0])
if nmsed_boxes.shape[1] != 6:
continue
for i in range(len(lod) - 1):
im_id = batch_data[i][1]
start = lod[i]
end = lod[i + 1]
if start == end:
continue
nmsed_box = nmsed_boxes[start:end, :]
labels = nmsed_box[:, 0]
scores = nmsed_box[:, 1]
boxes = nmsed_box[:, 2:6]
dts_res += get_pred_result(boxes, scores, labels, im_id)
end_time = time.time()
print("batch id: {}, time: {}".format(batch_id, end_time - start_time))
total_time += end_time - start_time
with open("yolov3_result.json", 'w') as outfile:
json.dump(dts_res, outfile)
print("start evaluate detection result with coco api")
coco = COCO(os.path.join(cfg.data_dir, test_list))
cocoDt = coco.loadRes("yolov3_result.json")
cocoEval = COCOeval(coco, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
print("evaluate done.")
print("Time per batch: {}".format(total_time / batch_id))
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
eval()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
from PIL import Image, ImageEnhance
import random
import box_utils
def random_distort(img):
def random_brightness(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
np.random.shuffle(ops)
img = Image.fromarray(img)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
img = np.asarray(img)
return img
def random_crop(img, boxes, labels, scores, scales=[0.3, 1.0], max_ratio=2.0, constraints=None, max_trial=50):
    if len(boxes) == 0:
        return img, boxes, labels, scores
if not constraints:
constraints = [
(0.1, 1.0),
(0.3, 1.0),
(0.5, 1.0),
(0.7, 1.0),
(0.9, 1.0),
(0.0, 1.0)]
img = Image.fromarray(img)
w, h = img.size
crops = [(0, 0, w, h)]
for min_iou, max_iou in constraints:
for _ in range(max_trial):
scale = random.uniform(scales[0], scales[1])
aspect_ratio = random.uniform(max(1 / max_ratio, scale * scale), \
min(max_ratio, 1 / scale / scale))
crop_h = int(h * scale / np.sqrt(aspect_ratio))
crop_w = int(w * scale * np.sqrt(aspect_ratio))
crop_x = random.randrange(w - crop_w)
crop_y = random.randrange(h - crop_h)
crop_box = np.array([[
(crop_x + crop_w / 2.0) / w,
(crop_y + crop_h / 2.0) / h,
crop_w / float(w),
                crop_h / float(h)
]])
iou = box_utils.box_iou_xywh(crop_box, boxes)
if min_iou <= iou.min() and max_iou >= iou.max():
crops.append((crop_x, crop_y, crop_w, crop_h))
break
while crops:
crop = crops.pop(np.random.randint(0, len(crops)))
crop_boxes, crop_labels, crop_scores, box_num = box_utils.box_crop(boxes, labels, scores, crop, (w, h))
if box_num < 1:
continue
img = img.crop((crop[0], crop[1], crop[0] + crop[2], crop[1] + crop[3])).resize(img.size, Image.LANCZOS)
img = np.asarray(img)
return img, crop_boxes, crop_labels, crop_scores
img = np.asarray(img)
return img, boxes, labels, scores
def random_flip(img, gtboxes, thresh=0.5):
if random.random() > thresh:
img = img[:, ::-1, :]
gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
return img, gtboxes
def random_interp(img, size, interp=None):
interp_method = [
cv2.INTER_NEAREST,
cv2.INTER_LINEAR,
cv2.INTER_AREA,
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
]
if not interp or interp not in interp_method:
interp = interp_method[random.randint(0, len(interp_method) - 1)]
h, w, _ = img.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
img = cv2.resize(img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
return img
def random_expand(img, gtboxes, max_ratio=4., fill=None, keep_ratio=True, thresh=0.5):
if random.random() > thresh:
return img, gtboxes
if max_ratio < 1.0:
return img, gtboxes
h, w, c = img.shape
ratio_x = random.uniform(1, max_ratio)
if keep_ratio:
ratio_y = ratio_x
else:
ratio_y = random.uniform(1, max_ratio)
oh = int(h * ratio_y)
ow = int(w * ratio_x)
    off_x = random.randint(0, ow - w)
    off_y = random.randint(0, oh - h)
out_img = np.zeros((oh, ow, c))
if fill and len(fill) == c:
for i in range(c):
out_img[:, :, i] = fill[i] * 255.0
out_img[off_y: off_y + h, off_x: off_x + w, :] = img
gtboxes[:, 0] = ((gtboxes[:, 0] * w) + off_x) / float(ow)
gtboxes[:, 1] = ((gtboxes[:, 1] * h) + off_y) / float(oh)
gtboxes[:, 2] = gtboxes[:, 2] / ratio_x
gtboxes[:, 3] = gtboxes[:, 3] / ratio_y
return out_img.astype('uint8'), gtboxes
def shuffle_gtbox(gtbox, gtlabel, gtscore):
    """Shuffle ground-truth boxes, labels and scores in unison."""
    gt = np.concatenate([gtbox, gtlabel[:, np.newaxis], gtscore[:, np.newaxis]], axis=1)
    idx = np.arange(gt.shape[0])
    np.random.shuffle(idx)
    gt = gt[idx, :]
    return gt[:, :4], gt[:, 4], gt[:, 5]
def image_mixup(img1, gtboxes1, gtlabels1, gtscores1, img2, gtboxes2, gtlabels2, gtscores2):
factor = np.random.beta(1.5, 1.5)
factor = max(0.0, min(1.0, factor))
    if factor >= 1.0:
        return img1, gtboxes1, gtlabels1, gtscores1
    if factor <= 0.0:
        return img2, gtboxes2, gtlabels2, gtscores2
gtscores1 = gtscores1 * factor
gtscores2 = gtscores2 * (1.0 - factor)
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
img = np.zeros((h, w, img1.shape[2]), 'float32')
img[:img1.shape[0], :img1.shape[1], :] = img1.astype('float32') * factor
img[:img2.shape[0], :img2.shape[1], :] += img2.astype('float32') * (1.0 - factor)
gtboxes = np.zeros_like(gtboxes1)
gtlabels = np.zeros_like(gtlabels1)
gtscores = np.zeros_like(gtscores1)
gt_valid_mask1 = np.logical_and(gtboxes1[:, 2] > 0, gtboxes1[:, 3] > 0)
gtboxes1 = gtboxes1[gt_valid_mask1]
gtlabels1 = gtlabels1[gt_valid_mask1]
gtscores1 = gtscores1[gt_valid_mask1]
gtboxes1[:, 0] = gtboxes1[:, 0] * img1.shape[1] / w
gtboxes1[:, 1] = gtboxes1[:, 1] * img1.shape[0] / h
gtboxes1[:, 2] = gtboxes1[:, 2] * img1.shape[1] / w
gtboxes1[:, 3] = gtboxes1[:, 3] * img1.shape[0] / h
gt_valid_mask2 = np.logical_and(gtboxes2[:, 2] > 0, gtboxes2[:, 3] > 0)
gtboxes2 = gtboxes2[gt_valid_mask2]
gtlabels2 = gtlabels2[gt_valid_mask2]
gtscores2 = gtscores2[gt_valid_mask2]
gtboxes2[:, 0] = gtboxes2[:, 0] * img2.shape[1] / w
gtboxes2[:, 1] = gtboxes2[:, 1] * img2.shape[0] / h
gtboxes2[:, 2] = gtboxes2[:, 2] * img2.shape[1] / w
gtboxes2[:, 3] = gtboxes2[:, 3] * img2.shape[0] / h
gtboxes_all = np.concatenate((gtboxes1, gtboxes2), axis=0)
gtlabels_all = np.concatenate((gtlabels1, gtlabels2), axis=0)
gtscores_all = np.concatenate((gtscores1, gtscores2), axis=0)
gt_num = min(len(gtboxes), len(gtboxes_all))
gtboxes[:gt_num] = gtboxes_all[:gt_num]
gtlabels[:gt_num] = gtlabels_all[:gt_num]
gtscores[:gt_num] = gtscores_all[:gt_num]
return img.astype('uint8'), gtboxes, gtlabels, gtscores
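# Note (hedged): image_mixup blends two images with a Beta(1.5, 1.5) factor,
# scales each side's gt scores by its blend weight, and concatenates both
# ground-truth lists, truncated to the fixed gt buffer size.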
def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None):
img = random_distort(img)
img, gtboxes = random_expand(img, gtboxes, fill=means)
img, gtboxes, gtlabels, gtscores = random_crop(img, gtboxes, gtlabels, gtscores)
img = random_interp(img, size)
img, gtboxes = random_flip(img, gtboxes)
gtboxes, gtlabels, gtscores = shuffle_gtbox(gtboxes, gtlabels, gtscores)
return img.astype('float32'), gtboxes.astype('float32'), \
gtlabels.astype('int32'), gtscores.astype('float32')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import numpy as np
import paddle
import paddle.fluid as fluid
import box_utils
import reader
from utility import print_arguments, parse_args
from models.yolov3 import YOLOv3
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config import cfg
def infer():
if not os.path.exists('output'):
os.mkdir('output')
model = YOLOv3(is_train=False)
model.build_model()
outputs = model.get_pred()
input_size = cfg.input_size
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
if cfg.weights:
def if_exist(var):
return os.path.exists(os.path.join(cfg.weights, var.name))
fluid.io.load_vars(exe, cfg.weights, predicate=if_exist)
# yapf: enable
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
fetch_list = [outputs]
image_names = []
if cfg.image_name is not None:
image_names.append(cfg.image_name)
else:
for image_name in os.listdir(cfg.image_path):
if image_name.split('.')[-1] in ['jpg', 'png']:
image_names.append(image_name)
for image_name in image_names:
infer_reader = reader.infer(input_size, os.path.join(cfg.image_path, image_name))
label_names, _ = reader.get_label_infos()
data = next(infer_reader())
im_shape = data[0][2]
outputs = exe.run(
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data),
return_numpy=False)
bboxes = np.array(outputs[0])
if bboxes.shape[1] != 6:
print("No object found in {}".format(image_name))
continue
labels = bboxes[:, 0].astype('int32')
scores = bboxes[:, 1].astype('float32')
boxes = bboxes[:, 2:].astype('float32')
path = os.path.join(cfg.image_path, image_name)
box_utils.draw_boxes_on_image(path, boxes, scores, labels, label_names, cfg.draw_thresh)
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
infer()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler
from paddle.fluid.layers import control_flow
def exponential_with_warmup_decay(learning_rate, boundaries, values,
warmup_iter, warmup_factor):
global_step = lr_scheduler._decay_step_counter()
lr = fluid.layers.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
warmup_iter_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(warmup_iter), force_cpu=True)
with control_flow.Switch() as switch:
with switch.case(global_step < warmup_iter_var):
alpha = global_step / warmup_iter_var
factor = warmup_factor * (1 - alpha) + alpha
decayed_lr = learning_rate * factor
fluid.layers.assign(decayed_lr, lr)
for i in range(len(boundaries)):
boundary_val = fluid.layers.fill_constant(
shape=[1],
dtype='float32',
value=float(boundaries[i]),
force_cpu=True)
value_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(global_step < boundary_val):
fluid.layers.assign(value_var, lr)
last_value_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(values[len(values) - 1]))
with switch.default():
fluid.layers.assign(last_value_var, lr)
return lr
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.fluid.regularizer import L2Decay
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='leaky',
is_test=True,
name=None):
conv1 = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02),
name=name+".conv.weights"),
bias_attr=False)
bn_name = name + ".bn"
out = fluid.layers.batch_norm(
input=conv1,
act=None,
is_test=is_test,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02),
regularizer=L2Decay(0.),
name=bn_name + '.scale'),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.),
name=bn_name + '.offset'),
moving_mean_name=bn_name + '.mean',
moving_variance_name=bn_name + '.var')
if act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
def downsample(input, ch_out, filter_size=3, stride=2, padding=1, is_test=True, name=None):
return conv_bn_layer(input,
ch_out=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
is_test=is_test,
name=name)
def basicblock(input, ch_out, is_test=True, name=None):
conv1 = conv_bn_layer(input, ch_out, 1, 1, 0, is_test=is_test, name=name+".0")
conv2 = conv_bn_layer(conv1, ch_out*2, 3, 1, 1, is_test=is_test, name=name+".1")
out = fluid.layers.elementwise_add(x=input, y=conv2, act=None)
return out
def layer_warp(block_func, input, ch_out, count, is_test=True, name=None):
res_out = block_func(input, ch_out, is_test=is_test, name='{}.0'.format(name))
for j in range(1, count):
res_out = block_func(res_out, ch_out, is_test=is_test, name='{}.{}'.format(name, j))
return res_out
DarkNet_cfg = {
53: ([1,2,8,8,4],basicblock)
}
def add_DarkNet53_conv_body(body_input, is_test=True):
stages, block_func = DarkNet_cfg[53]
stages = stages[0:5]
conv1 = conv_bn_layer(
body_input, ch_out=32, filter_size=3, stride=1, padding=1, is_test=is_test, name="yolo_input")
downsample_ = downsample(conv1, ch_out=conv1.shape[1]*2, is_test=is_test, name="yolo_input.downsample")
blocks = []
for i, stage in enumerate(stages):
block = layer_warp(block_func, downsample_, 32 *(2**i), stage, is_test=is_test, name="stage.{}".format(i))
blocks.append(block)
        if i < len(stages) - 1: # do not downsample in the last stage
downsample_ = downsample(block, ch_out=block.shape[1]*2, is_test=is_test, name="stage.{}.downsample".format(i))
return blocks[-1:-4:-1]
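# Note (hedged): blocks[-1:-4:-1] returns the last three stage outputs from
# coarsest to finest, i.e. strides 32, 16 and 8 (13x13, 26x26 and 52x52
# feature maps for a 416x416 input), matching the three YOLO detection scales.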
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.fluid.initializer import Normal
from paddle.fluid.regularizer import L2Decay
from config import cfg
from .darknet import add_DarkNet53_conv_body
from .darknet import conv_bn_layer
def yolo_detection_block(input, channel, is_test=True, name=None):
assert channel % 2 == 0, "channel {} cannot be divided by 2".format(channel)
conv = input
for j in range(2):
conv = conv_bn_layer(conv, channel, filter_size=1, stride=1, padding=0, is_test=is_test, name='{}.{}.0'.format(name, j))
conv = conv_bn_layer(conv, channel*2, filter_size=3, stride=1, padding=1, is_test=is_test, name='{}.{}.1'.format(name, j))
    route = conv_bn_layer(conv, channel, filter_size=1, stride=1, padding=0, is_test=is_test, name='{}.2'.format(name))
    tip = conv_bn_layer(route, channel*2, filter_size=3, stride=1, padding=1, is_test=is_test, name='{}.tip'.format(name))
    return route, tip
def upsample(input, scale=2, name=None):
# get dynamic upsample output shape
shape_nchw = fluid.layers.shape(input)
shape_hw = fluid.layers.slice(shape_nchw, axes=[0], starts=[2], ends=[4])
shape_hw.stop_gradient = True
in_shape = fluid.layers.cast(shape_hw, dtype='int32')
out_shape = in_shape * scale
out_shape.stop_gradient = True
    # resize by actual_shape
out = fluid.layers.resize_nearest(
input=input,
scale=scale,
actual_shape=out_shape,
name=name)
return out
class YOLOv3(object):
def __init__(self,
is_train=True,
use_random=True):
self.is_train = is_train
self.use_random = use_random
self.outputs = []
self.losses = []
self.downsample = 32
def build_input(self):
self.image_shape = [3, cfg.input_size, cfg.input_size]
if self.is_train:
self.py_reader = fluid.layers.py_reader(
capacity=64,
shapes = [[-1] + self.image_shape, [-1, cfg.max_box_num, 4], [-1, cfg.max_box_num], [-1, cfg.max_box_num]],
lod_levels=[0, 0, 0, 0],
dtypes=['float32'] * 2 + ['int32'] + ['float32'],
use_double_buffer=True)
self.image, self.gtbox, self.gtlabel, self.gtscore = fluid.layers.read_file(self.py_reader)
else:
self.image = fluid.layers.data(
name='image', shape=self.image_shape, dtype='float32'
)
self.im_shape = fluid.layers.data(
name="im_shape", shape=[2], dtype='int32')
self.im_id = fluid.layers.data(
name="im_id", shape=[1], dtype='int32')
def feeds(self):
if not self.is_train:
return [self.image, self.im_id, self.im_shape]
return [self.image, self.gtbox, self.gtlabel, self.gtscore]
def build_model(self):
self.build_input()
self.outputs = []
self.boxes = []
self.scores = []
blocks = add_DarkNet53_conv_body(self.image, not self.is_train)
for i, block in enumerate(blocks):
if i > 0:
block = fluid.layers.concat(
input=[route, block],
axis=1)
route, tip = yolo_detection_block(block, channel=512//(2**i),
is_test=(not self.is_train),
name="yolo_block.{}".format(i))
block_out = fluid.layers.conv2d(
input=tip,
num_filters=255,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02),
name="yolo_output.{}.conv.weights".format(i)),
bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.),
name="yolo_output.{}.conv.bias".format(i)))
self.outputs.append(block_out)
if i < len(blocks) - 1:
route = conv_bn_layer(
input=route,
ch_out=256//(2**i),
filter_size=1,
stride=1,
padding=0,
is_test=(not self.is_train),
name="yolo_transition.{}".format(i))
# upsample
route = upsample(route)
for i, out in enumerate(self.outputs):
anchor_mask = cfg.anchor_masks[i]
if self.is_train:
                ignore_thresh = float(cfg.ignore_thresh)
loss = fluid.layers.yolov3_loss(
x=out,
gtbox=self.gtbox,
gtlabel=self.gtlabel,
gtscore=self.gtscore,
anchors=cfg.anchors,
anchor_mask=anchor_mask,
class_num=cfg.class_num,
                    ignore_thresh=ignore_thresh,
downsample_ratio=self.downsample,
use_label_smooth=cfg.label_smooth,
name="yolo_loss"+str(i))
self.losses.append(fluid.layers.reduce_mean(loss))
else:
mask_anchors=[]
for m in anchor_mask:
mask_anchors.append(cfg.anchors[2 * m])
mask_anchors.append(cfg.anchors[2 * m + 1])
boxes, scores = fluid.layers.yolo_box(
x=out,
img_size=self.im_shape,
anchors=mask_anchors,
class_num=cfg.class_num,
conf_thresh=cfg.valid_thresh,
downsample_ratio=self.downsample,
name="yolo_box"+str(i))
self.boxes.append(boxes)
self.scores.append(fluid.layers.transpose(scores, perm=[0, 2, 1]))
self.downsample //= 2
def loss(self):
return sum(self.losses)
def get_pred(self):
yolo_boxes = fluid.layers.concat(self.boxes, axis=1)
yolo_scores = fluid.layers.concat(self.scores, axis=2)
return fluid.layers.multiclass_nms(
bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=cfg.valid_thresh,
nms_top_k=cfg.nms_topk,
keep_top_k=cfg.nms_posk,
nms_threshold=cfg.nms_thresh,
background_label=-1,
name="multiclass_nms")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import os
import random
import time
import copy
import cv2
import box_utils
import image_utils
from pycocotools.coco import COCO
from data_utils import GeneratorEnqueuer
from config import cfg
class DataSetReader(object):
"""A class for parsing and read COCO dataset"""
def __init__(self):
        self.has_parsed_category = False
def _parse_dataset_dir(self, mode):
if 'coco2014' in cfg.dataset:
cfg.train_file_list = 'annotations/instances_train2014.json'
cfg.train_data_dir = 'train2014'
cfg.val_file_list = 'annotations/instances_val2014.json'
cfg.val_data_dir = 'val2014'
elif 'coco2017' in cfg.dataset:
cfg.train_file_list = 'annotations/instances_train2017.json'
cfg.train_data_dir = 'train2017'
cfg.val_file_list = 'annotations/instances_val2017.json'
cfg.val_data_dir = 'val2017'
else:
raise NotImplementedError('Dataset {} not supported'.format(
cfg.dataset))
if mode == 'train':
cfg.train_file_list = os.path.join(cfg.data_dir, cfg.train_file_list)
cfg.train_data_dir = os.path.join(cfg.data_dir, cfg.train_data_dir)
self.COCO = COCO(cfg.train_file_list)
self.img_dir = cfg.train_data_dir
elif mode == 'test' or mode == 'infer':
cfg.val_file_list = os.path.join(cfg.data_dir, cfg.val_file_list)
cfg.val_data_dir = os.path.join(cfg.data_dir, cfg.val_data_dir)
self.COCO = COCO(cfg.val_file_list)
self.img_dir = cfg.val_data_dir
    def _parse_dataset_category(self):
self.categories = self.COCO.loadCats(self.COCO.getCatIds())
self.num_category = len(self.categories)
self.label_names = []
self.label_ids = []
for category in self.categories:
self.label_names.append(category['name'])
self.label_ids.append(int(category['id']))
self.category_to_id_map = {
v: i
for i, v in enumerate(self.label_ids)
}
print("Load in {} categories.".format(self.num_category))
        self.has_parsed_category = True
def get_label_infos(self):
        if not self.has_parsed_category:
            self._parse_dataset_dir("test")
            self._parse_dataset_category()
return (self.label_names, self.label_ids)
def _parse_gt_annotations(self, img):
img_height = img['height']
img_width = img['width']
anno = self.COCO.loadAnns(self.COCO.getAnnIds(imgIds=img['id'], iscrowd=None))
gt_index = 0
for target in anno:
if target['area'] < cfg.gt_min_area:
continue
            if 'ignore' in target and target['ignore']:
continue
box = box_utils.coco_anno_box_to_center_relative(target['bbox'], img_height, img_width)
if box[2] <= 0 and box[3] <= 0:
continue
img['gt_id'][gt_index] = np.int32(target['id'])
img['gt_boxes'][gt_index] = box
img['gt_labels'][gt_index] = self.category_to_id_map[target['category_id']]
gt_index += 1
if gt_index >= cfg.max_box_num:
break
def _parse_images(self, is_train):
image_ids = self.COCO.getImgIds()
image_ids.sort()
imgs = copy.deepcopy(self.COCO.loadImgs(image_ids))
for img in imgs:
img['image'] = os.path.join(self.img_dir, img['file_name'])
assert os.path.exists(img['image']), \
"image {} not found.".format(img['image'])
box_num = cfg.max_box_num
img['gt_id'] = np.zeros((cfg.max_box_num), dtype=np.int32)
img['gt_boxes'] = np.zeros((cfg.max_box_num, 4), dtype=np.float32)
img['gt_labels'] = np.zeros((cfg.max_box_num), dtype=np.int32)
for k in ['date_captured', 'url', 'license', 'file_name']:
                if k in img:
del img[k]
if is_train:
self._parse_gt_annotations(img)
print("Loaded {0} images from {1}.".format(len(imgs), cfg.dataset))
return imgs
def _parse_images_by_mode(self, mode):
if mode == 'infer':
return []
else:
return self._parse_images(is_train=(mode=='train'))
def get_reader(self, mode, size=416, batch_size=None, shuffle=False, mixup_iter=0, random_sizes=[], image=None):
        assert mode in ['train', 'test', 'infer'], "Unknown mode type!"
        if mode != 'infer':
            assert batch_size is not None, "batch size cannot be None in mode {}".format(mode)
self._parse_dataset_dir(mode)
        self._parse_dataset_category()
def img_reader(img, size, mean, std):
im_path = img['image']
im = cv2.imread(im_path).astype('float32')
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
h, w, _ = im.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_CUBIC)
mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1))
out_img = (out_img / 255.0 - mean) / std
out_img = out_img.transpose((2, 0, 1))
return (out_img, int(img['id']), (h, w))
def img_reader_with_augment(img, size, mean, std, mixup_img):
im_path = img['image']
im = cv2.imread(im_path)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
gt_boxes = img['gt_boxes'].copy()
gt_labels = img['gt_labels'].copy()
gt_scores = np.ones_like(gt_labels)
if mixup_img:
mixup_im = cv2.imread(mixup_img['image'])
mixup_im = cv2.cvtColor(mixup_im, cv2.COLOR_BGR2RGB)
mixup_gt_boxes = np.array(mixup_img['gt_boxes']).copy()
mixup_gt_labels = np.array(mixup_img['gt_labels']).copy()
mixup_gt_scores = np.ones_like(mixup_gt_labels)
im, gt_boxes, gt_labels, gt_scores = image_utils.image_mixup(im, gt_boxes, \
gt_labels, gt_scores, mixup_im, mixup_gt_boxes, mixup_gt_labels, \
mixup_gt_scores)
im, gt_boxes, gt_labels, gt_scores = image_utils.image_augment(im, gt_boxes, gt_labels, gt_scores, size, mean)
mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1))
out_img = (im / 255.0 - mean) / std
out_img = out_img.astype('float32').transpose((2, 0, 1))
return (out_img, gt_boxes, gt_labels, gt_scores)
def get_img_size(size, random_sizes=[]):
if len(random_sizes):
return np.random.choice(random_sizes)
return size
def get_mixup_img(imgs, mixup_iter, total_iter, read_cnt):
if total_iter >= mixup_iter:
return None
mixup_idx = np.random.randint(1, len(imgs))
mixup_img = imgs[(read_cnt + mixup_idx) % len(imgs)]
return mixup_img
def reader():
if mode == 'train':
imgs = self._parse_images_by_mode(mode)
if shuffle:
np.random.shuffle(imgs)
read_cnt = 0
total_iter = 0
batch_out = []
img_size = get_img_size(size, random_sizes)
while True:
img = imgs[read_cnt % len(imgs)]
mixup_img = get_mixup_img(imgs, mixup_iter, total_iter, read_cnt)
read_cnt += 1
if read_cnt % len(imgs) == 0 and shuffle:
np.random.shuffle(imgs)
im, gt_boxes, gt_labels, gt_scores = img_reader_with_augment(img, img_size, cfg.pixel_means, cfg.pixel_stds, mixup_img)
batch_out.append([im, gt_boxes, gt_labels, gt_scores])
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
total_iter += 1
img_size = get_img_size(size, random_sizes)
elif mode == 'test':
imgs = self._parse_images_by_mode(mode)
batch_out = []
for img in imgs:
im, im_id, im_shape = img_reader(img, size, cfg.pixel_means, cfg.pixel_stds)
batch_out.append((im, im_id, im_shape))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
if len(batch_out) != 0:
yield batch_out
else:
img = {}
img['image'] = image
img['id'] = 0
im, im_id, im_shape = img_reader(img, size, cfg.pixel_means, cfg.pixel_stds)
batch_out = [(im, im_id, im_shape)]
yield batch_out
return reader
dsr = DataSetReader()
def train(size=416,
batch_size=64,
shuffle=True,
mixup_iter=0,
random_sizes=[],
num_workers=12,
max_queue=36,
use_multiprocessing=True):
generator = dsr.get_reader('train', size, batch_size, shuffle, int(mixup_iter/num_workers), random_sizes)
if not use_multiprocessing:
return generator
def infinite_reader():
while True:
for data in generator():
yield data
def reader():
try:
enqueuer = GeneratorEnqueuer(
infinite_reader(), use_multiprocessing=use_multiprocessing)
enqueuer.start(max_queue_size=max_queue, workers=num_workers)
generator_out = None
while True:
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_out = enqueuer.queue.get()
break
else:
time.sleep(0.02)
yield generator_out
generator_out = None
finally:
if enqueuer is not None:
enqueuer.stop()
return reader
def test(size=416, batch_size=1):
return dsr.get_reader('test', size, batch_size)
def infer(size=416, image=None):
return dsr.get_reader('infer', size, image=image)
def get_label_infos():
return dsr.get_label_infos()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
import random
import time
import shutil
from utility import parse_args, print_arguments, SmoothedValue
import paddle
import paddle.fluid as fluid
import reader
from models.yolov3 import YOLOv3
from learning_rate import exponential_with_warmup_decay
from config import cfg
def train():
if cfg.debug:
fluid.default_startup_program().random_seed = 1000
fluid.default_main_program().random_seed = 1000
random.seed(0)
np.random.seed(0)
if not os.path.exists(cfg.model_save_dir):
os.makedirs(cfg.model_save_dir)
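# Build the YOLOv3 network and its training loss.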
model = YOLOv3()
model.build_model()
input_size = cfg.input_size
loss = model.loss()
loss.persistable = True
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
print("Found {} CUDA devices.".format(devices_num))
learning_rate = cfg.learning_rate
boundaries = cfg.lr_steps
gamma = cfg.lr_gamma
step_num = len(cfg.lr_steps)
values = [learning_rate * (gamma**i) for i in range(step_num + 1)]
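# e.g. learning_rate=0.001 and gamma=0.1 with two boundaries give
# values = [0.001, 0.0001, 0.00001]; the learning rate also warms up
# linearly over the first warm_up_iter iterations.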
optimizer = fluid.optimizer.Momentum(
learning_rate=exponential_with_warmup_decay(
learning_rate=learning_rate,
boundaries=boundaries,
values=values,
warmup_iter=cfg.warm_up_iter,
warmup_factor=cfg.warm_up_factor),
regularization=fluid.regularizer.L2Decay(cfg.weight_decay),
momentum=cfg.momentum)
optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if cfg.pretrain:
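# Load only the variables present in the pretrain directory; layers without
# pretrained weights (e.g. the detection head) keep their fresh initialization.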
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrain, var.name))
fluid.io.load_vars(exe, cfg.pretrain, predicate=if_exist)
compile_program = fluid.compiler.CompiledProgram(
fluid.default_main_program()).with_data_parallel(
loss_name=loss.name)
random_sizes = [cfg.input_size]
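# Multi-scale training: sample input sizes from 320 to 608 in steps of 32
# (the network stride).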
if cfg.random_shape:
random_sizes = [32 * i for i in range(10, 20)]
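# Mixup runs for the first (max_iter - start_iter - no_mixup_iter) iterations
# and is turned off for the final no_mixup_iter ones.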
mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter
train_reader = reader.train(input_size, batch_size=cfg.batch_size, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess)
py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader)
def save_model(postfix):
model_path = os.path.join(cfg.model_save_dir, postfix)
if os.path.isdir(model_path):
shutil.rmtree(model_path)
fluid.io.save_persistables(exe, model_path)
fetch_list = [loss]
py_reader.start()
smoothed_loss = SmoothedValue()
try:
start_time = time.time()
prev_start_time = start_time
snapshot_loss = 0
snapshot_time = 0
for iter_id in range(cfg.start_iter, cfg.max_iter):
prev_start_time = start_time
start_time = time.time()
losses = exe.run(compile_program, fetch_list=[v.name for v in fetch_list])
smoothed_loss.add_value(np.mean(np.array(losses[0])))
snapshot_loss += np.mean(np.array(losses[0]))
snapshot_time += start_time - prev_start_time
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
print("Iter {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format(
iter_id, lr[0],
smoothed_loss.get_mean_value(), start_time - prev_start_time))
sys.stdout.flush()
if (iter_id + 1) % cfg.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id))
print("Snapshot {} saved, average loss: {}, average time: {}".format(
iter_id + 1, snapshot_loss / float(cfg.snapshot_iter),
snapshot_time / float(cfg.snapshot_iter)))
snapshot_loss = 0
snapshot_time = 0
except fluid.core.EOFException:
py_reader.reset()
save_model('model_final')
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
train()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains common utility functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import distutils.util
import numpy as np
import six
from collections import deque
from paddle.fluid import core
import argparse
import functools
from config import *
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self):
self.loss_sum = 0.0
self.iter_cnt = 0
def add_value(self, value):
self.loss_sum += np.mean(value)
self.iter_cnt += 1
def get_mean_value(self):
return self.loss_sum / self.iter_cnt
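# Usage sketch (illustrative): SmoothedValue keeps a running mean over every
# value added so far, e.g.
#
#   tracker = SmoothedValue()
#   for v in [2.5, 2.1, 1.9]:
#       tracker.add_value(v)
#   tracker.get_mean_value()   # -> 2.1666..., the mean of all values so far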
def parse_args():
"""return all args
"""
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
# ENV
add_arg('use_gpu', bool, True, "Whether to use GPU.")
add_arg('model_save_dir', str, 'checkpoints', "The path to save model.")
add_arg('pretrain', str, 'weights/darknet53', "The pretrain model path.")
add_arg('weights', str, 'weights/yolov3', "The weights path.")
add_arg('dataset', str, 'coco2017', "Dataset: coco2014, coco2017.")
add_arg('class_num', int, 80, "Class number.")
add_arg('data_dir', str, 'dataset/coco', "The data root path.")
add_arg('start_iter', int, 0, "Start iteration.")
add_arg('use_multiprocess', bool, True, "Whether to use multiprocessing for data reading.")
#SOLVER
add_arg('batch_size', int, 8, "Mini-batch size per device.")
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('max_iter', int, 500200, "Total training iterations.")
add_arg('snapshot_iter', int, 2000, "Save a model snapshot every snapshot_iter iterations.")
add_arg('label_smooth', bool, True, "Use label smoothing for class labels.")
add_arg('no_mixup_iter', int, 40000, "Disable mixup in the last N iterations.")
# TRAIN TEST INFER
add_arg('input_size', int, 608, "Image input size of YOLOv3.")
add_arg('random_shape', bool, True, "Resize to random shape for train reader.")
add_arg('valid_thresh', float, 0.005, "Valid confidence score for NMS.")
add_arg('nms_thresh', float, 0.45, "NMS threshold.")
add_arg('nms_topk', int, 400, "The number of boxes to perform NMS.")
add_arg('nms_posk', int, 100, "The number of boxes of NMS output.")
add_arg('debug', bool, False, "Debug mode.")
# SINGLE EVAL AND DRAW
add_arg('image_path', str, 'image', "The image path used for inference and visualization.")
add_arg('image_name', str, None, "The single image used for inference and visualization. None to infer all images under image_path.")
add_arg('draw_thresh', float, 0.5, "Confidence score threshold for drawing prediction boxes on the image.")
# yapf: enable
args = parser.parse_args()
file_name = sys.argv[0]
merge_cfg_from_args(args)
return args
DIR="$(dirname "$PWD -P")"
cd "$DIR"
# Download the pretrain weights.
echo "Downloading..."
wget https://paddlemodels.bj.bcebos.com/yolo/darknet53.tar.gz
wget https://paddlemodels.bj.bcebos.com/yolo/yolov3.tar.gz
echo "Extracting..."
tar -xf darknet53.tar.gz
tar -xf yolov3.tar.gz