Commit 2c3a3b36 authored by dengkaipeng

add yolov3 object detection model

Parent 2c65b659
*.pyc
*.swp
*.log
*.json
*.jpg
*.png
output/
test/
checkpoints/
weights/
!weights/*.py
!weights/*.sh
dataset/coco/
log*
output*
# YOLO V3 Object Detection
---
## Table of Contents
- [Installation](#installation)
- [Introduction](#introduction)
- [Data preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [Inference and Visualization](#inference-and-visualization)
- [Appendix](#appendix)
## Installation
Running the sample code in this directory requires PaddlePaddle Fluid v1.1.0 or later. If the PaddlePaddle version on your device is lower than this, please follow the instructions in the [installation document](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle) to update it.
## Introduction
[YOLOv3](https://arxiv.org/abs/1804.02767) is a one-stage, end-to-end object detector. The detection principle of YOLOv3 is as follows:
<p align="center">
<img src"image/YOLOv3.jpg" height=400 width=400 hspace='10'/> <br />
YOLOv3 detection principle
</p>
YOLOv3 divides the input image into S\*S grid cells and predicts B bounding boxes in each cell. Each box prediction includes the location (x, y, w, h), a confidence score, and probabilities for C classes, so the output layer carries B\*(5 + C) channels at each of the S\*S grid positions, i.e. S\*S\*B\*(5 + C) predictions in total. The YOLOv3 loss consists of three parts: location loss, confidence (IoU) loss, and classification loss.
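For example, with the COCO settings used later in this repository (B = 3 anchors per scale and C = 80 classes), each scale's output layer carries 255 channels, which matches the `filters=255` entries before every `[yolo]` block in the Darknet config files:

```python
num_anchors = 3   # B: boxes predicted per grid cell at one scale
num_classes = 80  # C: COCO classes
channels = num_anchors * (5 + num_classes)  # 5 = (x, y, w, h) + confidence
print(channels)   # 255
```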
The backbone network of YOLOv3 is darknet53; the structure of YOLOv3 is as follows:
<p align="center">
<img src"image/YOLOv3_structure.jpg" height=400 width=400 hspace='10'/> <br />
YOLOv3 structure
</p>
## Data preparation
Train the model on the [MS-COCO dataset](http://cocodataset.org/#download); download the dataset as follows:
cd dataset/coco
./download.sh
## Training
After data preparation, one can start training as follows:
python train.py \
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model} \
--data_dir=${path_to_data}
- Set ```export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7``` to specify 8 GPUs for training.
- For more help on arguments:
python train.py --help
**Download the pre-trained model:** This sample provides a darknet53 pre-trained model, which is converted from the darknet53 ImageNet weights released by the author. One can download the pre-trained model as:
sh ./weights/download_pretrain_weights.sh
Set `pretrained_model` to load the pre-trained model. The same option is also used to load a trained model when fine-tuning.
Please make sure the pre-trained model is downloaded and loaded correctly, otherwise the loss may be NaN during training.
**Install the [cocoapi](https://github.com/cocodataset/cocoapi):**
To train the model, [cocoapi](https://github.com/cocodataset/cocoapi) is needed. Install the cocoapi:
# COCOAPI=/path/to/clone/cocoapi
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
cd $COCOAPI/PythonAPI
# if cython is not installed
pip install Cython
# Install into global site-packages
make install
# Alternatively, if you do not have permissions or prefer
# not to install the COCO API into global site-packages
python2 setup.py install --user
**training strategy:**
* Use momentum optimizer with momentum=0.9.
* In the first 4000 iterations, the learning rate increases linearly from 0.0 to 0.01. It is then decayed at iterations 400000 and 450000 with multipliers 0.1 and 0.01 respectively; training runs for at most 500000 iterations. A sketch of this schedule is given below.
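A minimal sketch of this schedule in plain Python (mirroring the `warm_up_iter`, `lr_steps`, and `lr_gamma` values in `config/config.py`; the base learning rate of 0.01 follows the text above):

```python
def lr_at(it, base_lr=0.01, warmup_iter=4000, steps=(400000, 450000), gamma=0.1):
    """Linear warmup followed by step decay."""
    if it < warmup_iter:
        return base_lr * float(it) / warmup_iter  # linear warmup from 0.0
    lr = base_lr
    for boundary in steps:
        if it >= boundary:
            lr *= gamma  # overall multipliers 0.1, then 0.01
    return lr
```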
The training result is shown below:
<p align="center">
<img src="image/train_loss.jpg" height=500 width=650 hspace='10'/> <br />
YOLOv3
</p>
## Evaluation
Evaluation measures the performance of a trained model. This sample provides `eval.py`, which uses the COCO mAP metric defined by the [COCO committee](http://cocodataset.org/#detections-eval).
`eval.py` is the main executor for evaluation; one can start evaluation by:
python eval.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model}
- Set ```export CUDA_VISIBLE_DEVICES=0``` to specify one GPU for evaluation.
The evaluation result is shown below:
<p align="center">
<img src="image/mAP.jpg" height=500 width=650 hspace='10'/> <br />
YOLOv3 mAP
</p>
## Inference and Visualization
Inference is used to get prediction scores or image features based on trained models. `infer.py` is the main executor for inference; one can start inference by:
python infer.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model} \
--image_path=data/COCO17/val2017/ \
--image_name=000000000139.jpg \
--draw_threshold=0.5
Visualizations of inference results are shown below:
<p align="center">
<img src="image/000000000139.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000127517.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000203864.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000515077.jpg" height=300 width=400 hspace='10'/> <br />
YOLOv3 Visualization Examples
</p>
# YOLO V3 Object Detection
---
## Contents
- [Installation](#installation)
- [Introduction](#introduction)
- [Data preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [Inference and Visualization](#inference-and-visualization)
- [Appendix](#appendix)
## Installation
Running the sample code in this directory requires PaddlePaddle Fluid v1.1.0 or later. If the PaddlePaddle version in your environment is lower than this, please update PaddlePaddle following the instructions in the [installation document](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle).
## Introduction
[YOLOv3](https://arxiv.org/abs/1804.02767) is a one-stage, end-to-end object detector. Its detection principle is shown below:
<p align="center">
<img src="image/YOLOv3.jpg" height=400 width=400 hspace='10'/> <br />
YOLOv3 detection principle
</p>
YOLOv3 divides the input image into S\*S grid cells and predicts B bounding boxes in each cell. Each bounding box prediction includes: the location (x, y, w, h), a confidence score, and probabilities for C classes, so the output layer carries B\*(5 + C) channels at each of the S\*S grid positions, i.e. S\*S\*B\*(5 + C) predictions in total. The YOLOv3 loss function consists of three parts: location loss, IoU (confidence) loss, and classification loss.
The backbone network of YOLOv3 is darknet53; its network structure is shown below:
<p align="center">
<img src="image/YOLOv3_structure.jpg" height=400 width=400 hspace='10'/> <br />
YOLOv3 network structure
</p>
On top of darknet53, detection is performed at three different scales.
## Data preparation
Train on the [MS-COCO dataset](http://cocodataset.org/#download); download the dataset as follows:
cd dataset/coco
./download.sh
## Training
After data preparation, training can be started as follows:
python train.py \
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model} \
--data_dir=${path_to_data}
- Set ```export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7``` to train with 8 GPUs.
- For the full list of arguments:
python train.py --help
**Download the pre-trained model:** This sample provides a darknet53 pre-trained model, which is converted from the darknet53 ImageNet weights released by the author. Download the pre-trained model as:
sh ./weights/download_pretrained_weight.sh
Set `pretrained_model` to load the pre-trained model; the same option is also used to load a trained model for fine-tuning.
Please make sure the pre-trained model is downloaded and loaded correctly before training, otherwise the loss may be NaN during training.
**Install the [cocoapi](https://github.com/cocodataset/cocoapi):**
The [cocoapi](https://github.com/cocodataset/cocoapi) must be downloaded and installed before training:
# COCOAPI=/path/to/clone/cocoapi
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
cd $COCOAPI/PythonAPI
# if cython is not installed
pip install Cython
# Install into global site-packages
make install
# Alternatively, if you do not have permissions or prefer
# not to install the COCO API into global site-packages
python2 setup.py install --user
**Training strategy:**
* Train YOLOv3 with the momentum optimizer, momentum=0.9.
* The learning rate uses a warmup schedule: in the first 4000 iterations it increases linearly from 0.0 to 0.01; it is then decayed at iterations 400000 and 450000 with multipliers 0.1 and 0.01. Training runs for at most 500000 iterations.
## Evaluation
Evaluation measures the performance of a trained model. This sample uses the [official COCO evaluation](http://cocodataset.org/#detections-eval).
`eval.py` is the main executor for evaluation; a sample invocation:
python eval.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model}
- Set ```export CUDA_VISIBLE_DEVICES=0``` to evaluate on a single GPU.
The evaluation result is shown below:
<p align="center">
<img src="image/mAP.jpg" height=500 width=650 hspace='10'/> <br />
YOLOv3 mAP
</p>
## Inference and Visualization
Inference retrieves the objects in an image together with their classes. `infer.py` is the main executor; a sample invocation:
python infer.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model} \
--image_path=data/COCO17/val2017/ \
--image_name=000000000139.jpg \
--draw_threshold=0.5
The visualized prediction results are shown below:
<p align="center">
<img src="image/000000000139.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000127517.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000203864.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000515077.jpg" height=300 width=400 hspace='10'/> <br />
YOLOv3 visualization examples
</p>
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from PIL import Image
def sigmoid(x):
"""Perform sigmoid to input numpy array"""
return 1.0 / (1.0 + np.exp(-1.0 * x))
def coco_anno_box_to_center_relative(box, img_height, img_width):
"""
Convert COCO annotations box with format [x1, y1, w, h] to
center mode [center_x, center_y, w, h] and divide image width
and height to get relative value in range[0, 1]
"""
assert len(box) == 4, "box should be a len(4) list or tuple"
x, y, w, h = box
x1 = max(x, 0)
x2 = min(x + w, img_width - 1)
y1 = max(y, 0)
y2 = min(y + h, img_height - 1)
x = (x1 + x2) / 2 / img_width
y = (y1 + y2) / 2 / img_height
w = (x2 - x1) / img_width
h = (y2 - y1) / img_height
return np.array([x, y, w, h])
def clip_relative_box_in_image(x, y, w, h):
"""Clip relative box coordinates x, y, w, h to [0, 1]"""
x1 = max(x - w / 2, 0.)
x2 = min(x + w / 2, 1.)
y1 = max(y - h / 2, 0.)
y2 = min(y + h / 2, 1.)
x = (x1 + x2) / 2
y = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
return x, y, w, h
def box_xywh_to_xyxy(box):
shape = box.shape
assert shape[-1] == 4, "Box shape[-1] should be 4."
box = box.reshape((-1, 4))
box[:, 0], box[:, 2] = box[:, 0] - box[:, 2] / 2, box[:, 0] + box[:, 2] / 2
box[:, 1], box[:, 3] = box[:, 1] - box[:, 3] / 2, box[:, 1] + box[:, 3] / 2
box = box.reshape(shape)
return box
def box_iou_xywh(box1, box2):
assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
inter_x1 = np.maximum(b1_x1, b2_x1)
inter_x2 = np.minimum(b1_x2, b2_x2)
inter_y1 = np.maximum(b1_y1, b2_y1)
inter_y2 = np.minimum(b1_y2, b2_y2)
inter_w = inter_x2 - inter_x1 + 1
inter_h = inter_y2 - inter_y1 + 1
# zero out non-overlapping boxes (negative intersection extents)
inter_w[inter_w < 0] = 0
inter_h[inter_h < 0] = 0
inter_area = inter_w * inter_h
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
return inter_area / (b1_area + b2_area - inter_area)
def box_iou_xyxy(box1, box2):
assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
inter_x1 = np.maximum(b1_x1, b2_x1)
inter_x2 = np.minimum(b1_x2, b2_x2)
inter_y1 = np.maximum(b1_y1, b2_y1)
inter_y2 = np.minimum(b1_y2, b2_y2)
inter_w = inter_x2 - inter_x1 + 1
inter_h = inter_y2 - inter_y1 + 1
# zero out non-overlapping boxes (negative intersection extents)
inter_w[inter_w < 0] = 0
inter_h[inter_h < 0] = 0
inter_area = inter_w * inter_h
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
return inter_area / (b1_area + b2_area - inter_area)
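# A quick sanity check for the IoU helpers above (editor's sketch with
# hypothetical boxes). With the inclusive "+1" pixel convention used here,
# two 10x10 boxes that overlap in a 5x5 corner give IoU = 25 / 175:
#
#   >>> box_a = np.array([[0., 0., 9., 9.]])
#   >>> box_b = np.array([[5., 5., 14., 14.]])
#   >>> box_iou_xyxy(box_a, box_b)
#   array([0.14285714])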
def rescale_box_in_input_image(boxes, im_shape, input_size):
"""Scale (x1, x2, y1, y2) box of yolo output to input image"""
h, w = im_shape
fx = w / input_size
fy = h / input_size
boxes[:, 0] *= fx
boxes[:, 1] *= fy
boxes[:, 2] *= fx
boxes[:, 3] *= fy
boxes[boxes<0] = 0
return boxes
def box_crop(boxes, labels, crop, img_shape):
x, y, w, h = map(float, crop)
im_w, im_h = map(float, img_shape)
boxes = boxes.copy()
boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (boxes[:, 0] + boxes[:, 2] / 2) * im_w
boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (boxes[:, 1] + boxes[:, 3] / 2) * im_h
crop_box = np.array([x, y, x + w, y + h])
centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(axis=1)
boxes[:, :2] = np.maximum(boxes[:, :2], crop_box[:2])
boxes[:, 2:] = np.minimum(boxes[:, 2:], crop_box[2:])
boxes[:, :2] -= crop_box[:2]
boxes[:, 2:] -= crop_box[:2]
mask = np.logical_and(mask, (boxes[:, :2] < boxes[:, 2:]).all(axis=1))
boxes = boxes * np.expand_dims(mask.astype('float32'), axis=1)
labels = labels * mask.astype('float32')
boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (boxes[:, 2] - boxes[:, 0]) / w
boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (boxes[:, 3] - boxes[:, 1]) / h
return boxes, labels, mask.sum()
def get_yolo_detection(preds, anchors, class_num, img_width, img_height):
"""Get yolo box, confidence score, class label from Darknet53 output"""
preds_n = np.array(preds)
n, c, h, w = preds_n.shape
anchor_num = len(anchors) // 2
preds_n = preds_n.reshape([n, anchor_num, class_num + 5, h, w]) \
.transpose((0, 1, 3, 4, 2))
preds_n[:, :, :, :, :2] = sigmoid(preds_n[:, :, :, :, :2])
preds_n[:, :, :, :, 4:] = sigmoid(preds_n[:, :, :, :, 4:])
pred_boxes = preds_n[:, :, :, :, :4]
pred_confs = preds_n[:, :, :, :, 4]
pred_scores = preds_n[:, :, :, :, 5:] * np.expand_dims(pred_confs, axis=4)
grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1))
grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w))
anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)]
anchors_s = np.array([(an_w, an_h) for an_w, an_h in anchors])
anchor_w = anchors_s[:, 0:1].reshape((1, anchor_num, 1, 1))
anchor_h = anchors_s[:, 1:2].reshape((1, anchor_num, 1, 1))
pred_boxes[:, :, :, :, 0] += grid_x
pred_boxes[:, :, :, :, 1] += grid_y
pred_boxes[:, :, :, :, 2] = np.exp(pred_boxes[:, :, :, :, 2]) * anchor_w
pred_boxes[:, :, :, :, 3] = np.exp(pred_boxes[:, :, :, :, 3]) * anchor_h
pred_boxes[:, :, :, :, 0] = pred_boxes[:, :, :, :, 0] * img_width / w
pred_boxes[:, :, :, :, 1] = pred_boxes[:, :, :, :, 1] * img_height / h
pred_boxes = box_xywh_to_xyxy(pred_boxes)
pred_boxes = np.tile(np.expand_dims(pred_boxes, axis=4), (1, 1, 1, 1, class_num, 1))
pred_labels = np.zeros_like(pred_scores) + np.arange(class_num)
return (
pred_boxes.reshape((n, -1, 4)),
pred_scores.reshape((n, -1)),
pred_labels.reshape((n, -1)),
)
def get_all_yolo_pred(outputs, yolo_anchors, yolo_classes, input_shape):
all_pred_boxes = []
all_pred_scores = []
all_pred_labels = []
for output, anchors, classes in zip(outputs, yolo_anchors, yolo_classes):
pred_boxes, pred_scores, pred_labels = get_yolo_detection(output, anchors, classes, input_shape[0], input_shape[1])
all_pred_boxes.append(pred_boxes)
all_pred_labels.append(pred_labels)
all_pred_scores.append(pred_scores)
pred_boxes = np.concatenate(all_pred_boxes, axis=1)
pred_scores = np.concatenate(all_pred_scores, axis=1)
pred_labels = np.concatenate(all_pred_labels, axis=1)
return (pred_boxes, pred_scores, pred_labels)
def calc_nms_box_new(pred_boxes, pred_scores, pred_labels, valid_thresh=0.01, nms_thresh=0.4, nms_topk=400, nms_posk=100):
output_boxes = np.empty((0, 4))
output_scores = np.empty(0)
output_labels = np.empty(0)
for boxes, labels, scores in zip(pred_boxes, pred_labels, pred_scores):
valid_mask = scores > valid_thresh
boxes = boxes[valid_mask]
scores = scores[valid_mask]
labels = labels[valid_mask]
score_sort_index = np.argsort(scores)[::-1]
boxes = boxes[score_sort_index][:nms_topk]
scores = scores[score_sort_index][:nms_topk]
labels = labels[score_sort_index][:nms_topk]
for c in np.unique(labels):
c_mask = labels == c
c_boxes = boxes[c_mask]
c_scores = scores[c_mask]
detect_boxes = []
detect_scores = []
detect_labels = []
while c_boxes.shape[0]:
detect_boxes.append(c_boxes[0])
detect_scores.append(c_scores[0])
detect_labels.append(c)
if c_boxes.shape[0] == 1:
break
iou = box_iou_xyxy(detect_boxes[-1].reshape((1, 4)), c_boxes[1:])
c_boxes = c_boxes[1:][iou < nms_thresh]
c_scores = c_scores[1:][iou < nms_thresh]
output_boxes = np.append(output_boxes, detect_boxes, axis=0)
output_scores = np.append(output_scores, detect_scores)
output_labels = np.append(output_labels, detect_labels)
return (output_boxes, output_scores, output_labels)
def calc_nms_box(pred_boxes, pred_confs, pred_labels, im_shape, input_size, valid_thresh=0.8, nms_thresh=0.4, nms_topk=400, nms_posk=100):
"""
Remove detections whose confidence score is below valid_thresh and
perform Non-Maximum Suppression (NMS) on the remaining boxes
"""
_, box_num, class_num = pred_labels.shape
pred_boxes = box_xywh_to_xyxy(pred_boxes)
output_boxes = np.empty((0, 4))
output_scores = np.empty(0)
output_labels = np.empty((0))
for i, (boxes, confs, classes) in enumerate(zip(pred_boxes, pred_confs, pred_labels)):
conf_mask = confs > valid_thresh
if conf_mask.sum() == 0:
continue
boxes = boxes[conf_mask]
classes = classes[conf_mask]
confs = confs[conf_mask]
conf_sort_index = np.argsort(confs)[::-1]
boxes = boxes[conf_sort_index][:nms_topk]
classes = classes[conf_sort_index][:nms_topk]
confs = confs[conf_sort_index][:nms_topk]
cls_score = np.max(classes, axis=1)
cls_pred = np.argmax(classes, axis=1)
for c in np.unique(cls_pred):
c_mask = cls_pred == c
c_confs = confs[c_mask]
c_boxes = boxes[c_mask]
c_scores = cls_score[c_mask]
c_score_index = np.argsort(c_scores)
c_boxes_s = c_boxes[c_score_index[::-1]]
c_confs_s = c_confs[c_score_index[::-1]]
c_scores_s = c_scores[c_score_index[::-1]]
detect_boxes = []
detect_scores = []
detect_labels = []
while c_boxes_s.shape[0]:
detect_boxes.append(c_boxes_s[0])
detect_scores.append(c_scores_s[0])
detect_labels.append(c)
if c_boxes_s.shape[0] == 1:
break
iou = box_iou_xyxy(detect_boxes[-1].reshape((1, 4)), c_boxes_s[1:])
c_boxes_s = c_boxes_s[1:][iou < nms_thresh]
c_confs_s = c_confs_s[1:][iou < nms_thresh]
c_scores_s = c_scores_s[1:][iou < nms_thresh]
output_boxes = np.append(output_boxes, detect_boxes, axis=0)
output_scores = np.append(output_scores, detect_scores)
output_labels = np.append(output_labels, detect_labels)
output_boxes = output_boxes[:nms_posk]
output_scores = output_scores[:nms_posk]
output_labels = output_labels[:nms_posk]
output_boxes = rescale_box_in_input_image(output_boxes, im_shape, input_size)
return (output_boxes, output_scores, output_labels)
def draw_boxes_on_image(image_path, boxes, scores, labels, label_names, score_thresh=0.5):
image = np.array(Image.open(image_path))
plt.figure()
_, ax = plt.subplots(1)
ax.imshow(image)
image_name = image_path.split('/')[-1]
print("Image {} detect: ".format(image_name))
colors = {}
for box, score, label in zip(boxes, scores, labels):
if score < score_thresh:
continue
if box[2] <= box[0] or box[3] <= box[1]:
continue
label = int(label)
if label not in colors:
colors[label] = plt.get_cmap('hsv')(label / len(label_names))
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
fill=False, linewidth=2.0,
edgecolor=colors[label])
ax.add_patch(rect)
ax.text(x1, y1, '{} {:.4f}'.format(label_names[label], score),
verticalalignment='bottom', horizontalalignment='left',
bbox={'facecolor': colors[label], 'alpha': 0.5, 'pad': 0},
fontsize=8, color='white')
print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(label)], map(int, list(box)), score))
image_name = image_name.replace('jpg', 'png')
plt.axis('off')
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
print("Detect result save at ./output/{}\n".format(image_name))
plt.cla()
plt.close('all')
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from .edict import AttrDict
import six
import numpy as np
_C = AttrDict()
cfg = _C
#
# Training options
#
# Snapshot period
_C.snapshot_iter = 2000
# min valid area for gt boxes
_C.gt_min_area = -1
# max target box number in an image
_C.max_box_num = 50
#
# Inference options
#
# valid score threshold to include boxes
_C.valid_thresh = 0.01
# threshold value for box non-max suppression
_C.nms_thresh = 0.45
# the number of top k boxes to perform nms
_C.nms_topk = 400
# the number of output boxes after nms
_C.nms_posk = 100
# score threshold for draw box in debug mode
_C.conf_thresh = 0.5
#
# Model options
#
# pixel mean values
_C.pixel_means = [0.485, 0.456, 0.406]
# pixel std values
_C.pixel_stds = [0.229, 0.224, 0.225]
#
# SOLVER options
#
# base learning rate used to derive the final learning rate
_C.learning_rate = 0.001
# maximum number of iterations
_C.max_iter = 500200
# warm up to learning rate
_C.warm_up_iter = 4000
_C.warm_up_factor = 0.
# lr steps_with_decay
_C.lr_steps = [400000, 450000]
_C.lr_gamma = 0.1
# L2 regularization hyperparameter
_C.weight_decay = 0.0005
# momentum with SGD
_C.momentum = 0.9
#
# ENV options
#
# support both CPU and GPU
_C.use_gpu = True
# Whether use parallel
_C.parallel = True
# Class number
_C.class_num = 80
# support pyreader
_C.use_pyreader = True
# dataset path
_C.train_file_list = 'annotations/instances_train2017.json'
_C.train_data_dir = 'train2017'
_C.val_file_list = 'annotations/instances_val2017.json'
_C.val_data_dir = 'val2017'
def merge_cfg_from_args(args):
"""Merge config keys, values in args into the global config."""
for k, v in sorted(six.iteritems(vars(args))):
try:
value = eval(v)
except:
value = v
_C[k] = value
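# Usage sketch (editor's illustration, not part of the original commit):
# command-line values are merged into the global cfg; eval() turns numeric
# strings into numbers, while plain names fall back to raw strings.
if __name__ == '__main__':
    import argparse
    args = argparse.Namespace(learning_rate='0.01', dataset='coco2017')
    merge_cfg_from_args(args)
    print(cfg.learning_rate, cfg.dataset)  # 0.01 coco2017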
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
LAYER_TYPES = [
"net",
"convolutional",
"shortcut",
"route",
"upsample",
"maxpool",
"yolo",
]
class ConfigPaser(object):
def __init__(self, config_path):
self.config_path = config_path
def parse(self):
with open(self.config_path) as cfg_file:
model_defs = []
for line in cfg_file.readlines():
line = line.strip()
if len(line) == 0:
continue
if line.startswith('#'):
continue
if line.startswith('['):
layer_type = line[1:-1].strip()
if layer_type not in LAYER_TYPES:
print("Unknow config layer type: ", layer_type)
return None
model_defs.append({})
model_defs[-1]['type'] = layer_type
else:
key, value = line.split('=')
model_defs[-1][key.strip()] = value.strip()
return model_defs
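# Usage sketch (editor's illustration, not part of the original commit):
# parse a minimal Darknet-style config; note every value stays a string.
if __name__ == '__main__':
    import tempfile
    cfg_text = "[net]\nwidth=416\nheight=416\n\n[convolutional]\nfilters=32\nsize=3\n"
    with tempfile.NamedTemporaryFile('w', suffix='.cfg', delete=False) as f:
        f.write(cfg_text)
    print(ConfigPaser(f.name).parse())
    # [{'type': 'net', 'width': '416', 'height': '416'},
    #  {'type': 'convolutional', 'filters': '32', 'size': '3'}]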
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
elif name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name, value):
if name in self.__dict__:
self.__dict__[name] = value
else:
self[name] = value
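# Usage sketch (editor's illustration, not part of the original commit):
# config keys are readable and writable both as attributes and as items.
if __name__ == '__main__':
    d = AttrDict(lr=0.001)
    d.momentum = 0.9            # stored as a dict item ('momentum' is not in __dict__)
    print(d['lr'], d.momentum)  # 0.001 0.9
    d['lr'] = 0.01
    print(d.lr)                 # 0.01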
[net]
# Testing
batch=1
subdivisions=1
# Training
# batch=64
# subdivisions=2
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1
[convolutional]
batch_normalize=1
filters=16
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=1
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
###########
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 3,4,5
anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
classes=80
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 8
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 0,1,2
anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
classes=80
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[net]
# Testing
# batch=1
# subdivisions=1
# Training
batch=64
subdivisions=16
width=608
height=608
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
# Downsample
[convolutional]
batch_normalize=1
filters=64
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=32
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=128
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=256
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=512
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
######################
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 6,7,8
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 61
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 3,4,5
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 36
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 0,1,2
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
"""
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""
import time
import numpy as np
import threading
import multiprocessing
try:
import queue
except ImportError:
import Queue as queue
class GeneratorEnqueuer(object):
"""
Builds a queue out of a data generator.
Args:
generator: a generator function which endlessly yields data
use_multiprocessing (bool): use multiprocessing if True,
otherwise use threading.
wait_time (float): time to sleep in-between calls to `put()`.
random_seed (int): Initial seed for workers,
will be incremented by one for each worker.
"""
def __init__(self,
generator,
use_multiprocessing=False,
wait_time=0.05,
random_seed=None):
self.wait_time = wait_time
self._generator = generator
self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self.queue = None
self._manager = None
self.seed = random_seed
def start(self, workers=1, max_queue_size=10):
"""
Start worker threads which add data from the generator into the queue.
Args:
workers (int): number of worker threads
max_queue_size (int): queue size
(when full, threads could block on `put()`)
"""
def data_generator_task():
"""
Data generator task.
"""
def task():
if (self.queue is not None and
self.queue.qsize() < max_queue_size):
generator_output = next(self._generator)
self.queue.put((generator_output))
else:
time.sleep(self.wait_time)
if not self._use_multiprocessing:
while not self._stop_event.is_set():
with self.genlock:
try:
task()
except Exception:
self._stop_event.set()
break
else:
while not self._stop_event.is_set():
try:
task()
except Exception:
self._stop_event.set()
break
try:
if self._use_multiprocessing:
self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
self.genlock = threading.Lock()
self.queue = queue.Queue()
self._stop_event = threading.Event()
for _ in range(workers):
if self._use_multiprocessing:
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.seed)
thread = multiprocessing.Process(target=data_generator_task)
thread.daemon = True
if self.seed is not None:
self.seed += 1
else:
thread = threading.Thread(target=data_generator_task)
self._threads.append(thread)
thread.start()
except:
self.stop()
raise
def is_running(self):
"""
Returns:
bool: Whether the worker threads are running.
"""
return self._stop_event is not None and not self._stop_event.is_set()
def stop(self, timeout=None):
"""
Stops running threads and waits for them to exit, if necessary.
Should be called by the same thread which called `start()`.
Args:
timeout(int|None): maximum time to wait on `thread.join()`.
"""
if self.is_running():
self._stop_event.set()
for thread in self._threads:
if self._use_multiprocessing:
if thread.is_alive():
thread.terminate()
else:
thread.join(timeout)
if self._manager:
self._manager.shutdown()
self._threads = []
self._stop_event = None
self.queue = None
def get(self):
"""
Creates a generator to extract data from the queue.
Skip the data if it is `None`.
# Yields
tuple of data in the queue.
"""
while self.is_running():
if not self.queue.empty():
inputs = self.queue.get()
if inputs is not None:
yield inputs
else:
time.sleep(self.wait_time)
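# Usage sketch (editor's illustration, not part of the original commit):
# wrap a toy generator in a background worker and pull a few items.
if __name__ == '__main__':
    import itertools

    def toy_batches():
        for i in itertools.count():
            yield i  # stand-in for a (image, gtbox, gtlabel) batch

    enqueuer = GeneratorEnqueuer(toy_batches(), use_multiprocessing=False)
    enqueuer.start(workers=1, max_queue_size=4)
    batches = enqueuer.get()
    for _ in range(3):
        print(next(batches))  # 0, 1, 2
    enqueuer.stop()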
coco2014/
!coco2014/*.sh
!coco2014/*.py
!coco2014/coco.*
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"
# Download the data.
echo "Downloading..."
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
# Extract the data.
echo "Extracting..."
unzip train2014.zip
unzip val2014.zip
unzip train2017.zip
unzip val2017.zip
unzip annotations_trainval2014.zip
unzip annotations_trainval2017.zip
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
import paddle
import paddle.fluid as fluid
import box_utils
import reader
import models
from utility import print_arguments, parse_args
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config.config import cfg
def eval():
if '2014' in cfg.dataset:
test_list = 'annotations/instances_val2014.json'
elif '2017' in cfg.dataset:
test_list = 'annotations/instances_val2017.json'
if cfg.debug:
if not os.path.exists('output'):
os.mkdir('output')
model = models.YOLOv3(cfg.model_cfg_path, is_train=False)
model.build_model()
outputs = model.get_pred()
hyperparams = model.get_hyperparams()
yolo_anchors = model.get_yolo_anchors()
yolo_classes = model.get_yolo_classes()
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
if cfg.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)
# yapf: enable
input_size = model.get_input_size()
test_reader = reader.test(input_size, 1)
label_names, label_ids = reader.get_label_infos()
if cfg.debug:
print("Load in labels {} with ids {}".format(label_names, label_ids))
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
def get_pred_result(boxes, confs, labels, im_id):
result = []
for box, conf, label in zip(boxes, confs, labels):
x1, y1, x2, y2 = box
w = x2 - x1 + 1
h = y2 - y1 + 1
bbox = [x1, y1, w, h]
res = {
'image_id': im_id,
'category_id': label_ids[int(label)],
'bbox': bbox,
'score': conf
}
result.append(res)
return result
dts_res = []
fetch_list = outputs
total_time = 0
for batch_id, batch_data in enumerate(test_reader()):
start_time = time.time()
batch_outputs = exe.run(
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(batch_data),
return_numpy=False)
for data, outputs in zip(batch_data, batch_outputs):
im_id = data[1]
im_shape = data[2]
pred_boxes, pred_scores, pred_labels = box_utils.get_all_yolo_pred(
batch_outputs, yolo_anchors, yolo_classes, (input_size, input_size))
boxes, scores, labels = box_utils.calc_nms_box_new(pred_boxes, pred_scores, pred_labels,
cfg.valid_thresh, cfg.nms_thresh)
boxes = box_utils.rescale_box_in_input_image(boxes, im_shape, input_size)
dts_res += get_pred_result(boxes, scores, labels, im_id)
end_time = time.time()
print("batch id: {}, time: {}".format(batch_id, end_time - start_time))
total_time += (end_time - start_time)
if cfg.debug:
if '2014' in cfg.dataset:
img_name = "COCO_val2014_{:012d}.jpg".format(im_id)
box_utils.draw_boxes_on_image(os.path.join("./dataset/coco/val2014", img_name), boxes, scores, labels, label_names)
if '2017' in cfg.dataset:
img_name = "{:012d}.jpg".format(im_id)
box_utils.draw_boxes_on_image(os.path.join("./dataset/coco/val2017", img_name), boxes, scores, labels, label_names)
with open("yolov3_result.json", 'w') as outfile:
json.dump(dts_res, outfile)
print("start evaluate detection result with coco api")
coco = COCO(os.path.join(cfg.data_dir, test_list))
cocoDt = coco.loadRes("yolov3_result.json")
cocoEval = COCOeval(coco, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
print("evaluate done.")
print("Time per batch: {}".format(total_time / batch_id))
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
eval()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
from PIL import Image, ImageEnhance
import random
import box_utils
def random_distort(img):
def random_brightness(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
np.random.shuffle(ops)
img = Image.fromarray(img)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
img = np.asarray(img)
return img
def random_crop(img, boxes, labels, scales=[0.3, 1.0], max_ratio=2.0, constraints=None, max_trial=50):
if len(boxes) == 0:
return img, boxes, labels
if not constraints:
constraints = [
(0.1, 1.0),
(0.3, 1.0),
(0.5, 1.0),
(0.7, 1.0),
(0.9, 1.0),
(0.0, 1.0)]
img = Image.fromarray(img)
w, h = map(float, img.size)
crops = [(0, 0, w, h)]
for min_iou, max_iou in constraints:
for _ in range(max_trial):
scale = random.uniform(scales[0], scales[1])
aspect_ratio = random.uniform(max(1 / max_ratio, scale * scale), \
min(max_ratio, 1 / scale / scale))
crop_h = int(h * scale / np.sqrt(aspect_ratio))
crop_w = int(w * scale * np.sqrt(aspect_ratio))
crop_x = random.randrange(w - crop_w)
crop_y = random.randrange(h - crop_h)
crop_box = np.array([[
(crop_x + crop_w / 2.0) / w,
(crop_y + crop_h / 2.0) / h,
crop_w / w,
crop_h / h
]])
iou = box_utils.box_iou_xywh(crop_box, boxes)
if min_iou <= iou.min() and max_iou >= iou.max():
crops.append((crop_x, crop_y, crop_w, crop_h))
break
while crops:
crop = crops.pop(np.random.randint(0, len(crops)))
crop_boxes, crop_labels, box_num = box_utils.box_crop(boxes, labels, crop, (w, h))
if box_num < 1:
continue
img = img.crop((crop[0], crop[1], crop[0] + crop[2], crop[1] + crop[3])).resize(img.size, Image.LANCZOS)
img = np.asarray(img)
return img, crop_boxes, crop_labels
img = np.asarray(img)
return img, boxes, labels
def random_flip(img, gtboxes, thresh=0.5):
if random.random() > thresh:
img = img[:, ::-1, :]
gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
return img, gtboxes
def random_interp(img, size):
interp_method = [
cv2.INTER_NEAREST,
cv2.INTER_LINEAR,
cv2.INTER_AREA,
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
]
interp = interp_method[random.randint(0, len(interp_method) - 1)]
h, w, _ = img.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
img = cv2.resize(img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
return img
def random_expand(img, gtboxes, max_ratio=4., fill=None, keep_ratio=True, thresh=0.5):
if random.random() > thresh:
return img, gtboxes
if max_ratio < 1.0:
return img, gtboxes
h, w, c = img.shape
ratio_x = random.uniform(1, max_ratio)
if keep_ratio:
ratio_y = ratio_x
else:
ratio_y = random.uniform(1, max_ratio)
oh = int(h * ratio_y)
ow = int(w * ratio_x)
off_x = random.randint(0, ow - w)
off_y = random.randint(0, oh - h)
out_img = np.zeros((oh, ow, c))
if fill and len(fill) == c:
for i in range(c):
out_img[:, :, i] = fill[i] * 255.0
out_img[off_y: off_y + h, off_x: off_x + w, :] = img
gtboxes[:, 0] = ((gtboxes[:, 0] * w) + off_x) / float(ow)
gtboxes[:, 1] = ((gtboxes[:, 1] * h) + off_y) / float(oh)
gtboxes[:, 2] = gtboxes[:, 2] / ratio_x
gtboxes[:, 3] = gtboxes[:, 3] / ratio_y
return out_img.astype('uint8'), gtboxes
def image_mixup(img1, gtboxes1, gtlabels1, img2, gtboxes2, gtlabels2):
factor = np.random.beta(1.5, 1.5)
factor = max(0.0, min(1.0, factor))
if factor >= 1.0:
return img1, gtboxes1, gtlabels1
if factor <= 0.0:
return img2, gtboxes2, gtlabels2
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
img = np.zeros((h, w, img1.shape[2]), 'float32')
img[:img1.shape[0], :img1.shape[1], :] = img1.astype('float32') * factor
img[:img2.shape[0], :img2.shape[1], :] += img2.astype('float32') * (1.0 - factor)
gtboxes = np.zeros_like(gtboxes1)
gtlabels = np.zeros_like(gtlabels1)
gt_valid_mask1 = np.logical_and(gtboxes1[:, 2] > 0, gtboxes1[:, 3] > 0)
gtboxes1 = gtboxes1[gt_valid_mask1]
gtlabels1 = gtlabels1[gt_valid_mask1]
gtboxes1[:, 0] = gtboxes1[:, 0] * img1.shape[1] / w
gtboxes1[:, 1] = gtboxes1[:, 1] * img1.shape[0] / h
gtboxes1[:, 2] = gtboxes1[:, 2] * img1.shape[1] / w
gtboxes1[:, 3] = gtboxes1[:, 3] * img1.shape[0] / h
gt_valid_mask2 = np.logical_and(gtboxes2[:, 2] > 0, gtboxes2[:, 3] > 0)
gtboxes2 = gtboxes2[gt_valid_mask2]
gtlabels2 = gtlabels2[gt_valid_mask2]
gtboxes2[:, 0] = gtboxes2[:, 0] * img2.shape[1] / w
gtboxes2[:, 1] = gtboxes2[:, 1] * img2.shape[0] / h
gtboxes2[:, 2] = gtboxes2[:, 2] * img2.shape[1] / w
gtboxes2[:, 3] = gtboxes2[:, 3] * img2.shape[0] / h
gtboxes_all = np.concatenate((gtboxes1, gtboxes2), axis=0)
gtlabels_all = np.concatenate((gtlabels1, gtlabels2), axis=0)
gt_num = min(len(gtboxes), len(gtboxes_all))
gtboxes[:gt_num] = gtboxes_all[:gt_num]
gtlabels[:gt_num] = gtlabels_all[:gt_num]
return img.astype('uint8'), gtboxes, gtlabels
def image_augment(img, gtboxes, gtlabels, size, means=None):
img = random_distort(img)
img, gtboxes = random_expand(img, gtboxes, fill=means)
img, gtboxes, gtlabels = random_crop(img, gtboxes, gtlabels)
img = random_interp(img, size)
img, gtboxes = random_flip(img, gtboxes)
return img.astype('float32'), gtboxes.astype('float32'), gtlabels.astype('int32')
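# Usage sketch (editor's illustration, not part of the original commit):
# run the full augmentation pipeline on dummy data; the values are
# hypothetical and assume this commit's Python 2-era dependencies.
if __name__ == '__main__':
    img = np.random.randint(0, 256, (480, 640, 3)).astype('uint8')
    gtboxes = np.array([[0.5, 0.5, 0.4, 0.3]], dtype='float32')  # relative (cx, cy, w, h)
    gtlabels = np.array([7], dtype='int32')
    out_img, out_boxes, out_labels = image_augment(
        img, gtboxes, gtlabels, size=608, means=[0.485, 0.456, 0.406])
    print(out_img.shape, out_boxes.shape, out_labels.shape)  # (608, 608, 3) (1, 4) (1,)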
import os
import time
import numpy as np
import paddle
import paddle.fluid as fluid
import box_utils
import reader
from utility import print_arguments, parse_args
import models
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config.config import cfg
def infer():
if not os.path.exists('output'):
os.mkdir('output')
model = models.YOLOv3(cfg.model_cfg_path, is_train=False)
model.build_model()
outputs = model.get_pred()
input_size = model.get_input_size()
yolo_anchors = model.get_yolo_anchors()
yolo_classes = model.get_yolo_classes()
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
if cfg.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)
# yapf: enable
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
fetch_list = outputs
image_names = []
if cfg.image_name is not None:
image_names.append(cfg.image_name)
else:
for image_name in os.listdir(cfg.image_path):
if image_name.split('.')[-1] in ['jpg', 'png']:
image_names.append(image_name)
for image_name in image_names:
infer_reader = reader.infer(input_size, os.path.join(cfg.image_path, image_name))
label_names, _ = reader.get_label_infos()
data = next(infer_reader())
im_shape = data[0][2]
outputs = exe.run(
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data),
return_numpy=True)
pred_boxes, pred_scores, pred_labels = box_utils.get_all_yolo_pred(outputs, yolo_anchors,
yolo_classes, (input_size, input_size))
boxes, scores, labels = box_utils.calc_nms_box_new(pred_boxes, pred_scores, pred_labels,
cfg.valid_thresh, cfg.nms_thresh)
boxes = box_utils.rescale_box_in_input_image(boxes, im_shape, input_size)
path = os.path.join(cfg.image_path, image_name)
box_utils.draw_boxes_on_image(path, boxes, scores, labels, label_names, cfg.draw_thresh)
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
infer()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler
from paddle.fluid.layers import control_flow
def exponential_with_warmup_decay(learning_rate, boundaries, values,
warmup_iter, warmup_factor, start_step):
global_step = lr_scheduler._decay_step_counter() + start_step
lr = fluid.layers.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="learning_rate")
warmup_iter_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(warmup_iter), force_cpu=True)
with control_flow.Switch() as switch:
with switch.case(global_step < warmup_iter_var):
alpha = global_step / warmup_iter_var
factor = warmup_factor * (1 - alpha) + alpha
decayed_lr = learning_rate * factor
fluid.layers.assign(decayed_lr, lr)
for i in range(len(boundaries)):
boundary_val = fluid.layers.fill_constant(
shape=[1],
dtype='float32',
value=float(boundaries[i]),
force_cpu=True)
value_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(global_step < boundary_val):
fluid.layers.assign(value_var, lr)
last_value_var = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=float(values[len(values) - 1]))
with switch.default():
fluid.layers.assign(last_value_var, lr)
return lr
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.fluid.initializer import Normal
from paddle.fluid.regularizer import L2Decay
import box_utils
from config.config_parser import ConfigPaser
from config.config import cfg
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act=None,
bn=False,
name=None,
is_train=True):
if bn:
out = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02),
name=name + "_weights"),
bias_attr=False,
name=name + '.conv2d.output.1')
bn_name = "bn" + name[4:]
out = fluid.layers.batch_norm(input=out,
act=None,
is_test=not is_train,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02),
regularizer=L2Decay(0.),
name=bn_name + '_scale'),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.),
name=bn_name + '_offset'),
moving_mean_name=bn_name+'_mean',
moving_variance_name=bn_name+'_var',
name=bn_name+'.output')
else:
out = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=None,
param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02),
name=name + "_weights"),
bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.),
name=name + "_bias"),
name=name + '.conv2d.output.1')
if act == 'relu':
out = fluid.layers.relu(x=out)
if act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
class YOLOv3(object):
def __init__(self,
model_cfg_path,
is_train=True,
use_pyreader=True,
use_random=True):
self.model_cfg_path = model_cfg_path
self.config_parser = ConfigPaser(model_cfg_path)
self.is_train = is_train
self.use_pyreader = use_pyreader
self.use_random = use_random
self.outputs = []
self.losses = []
self.downsample = 32
def build_model(self):
model_defs = self.config_parser.parse()
if model_defs is None:
return None
self.hyperparams = model_defs.pop(0)
assert self.hyperparams['type'].lower() == "net", \
"net config params should be given in the first segment named 'net'"
self.img_height = cfg.input_size
self.img_width = cfg.input_size
self.build_input()
out = self.image
layer_outputs = []
self.yolo_layer_defs = []
self.yolo_anchors = []
self.yolo_classes = []
self.outputs = []
for i, layer_def in enumerate(model_defs):
if layer_def['type'] == 'convolutional':
bn = int(layer_def.get('batch_normalize', 0))
ch_out = int(layer_def['filters'])
filter_size = int(layer_def['size'])
stride = int(layer_def['stride'])
padding = (filter_size - 1) // 2 if int(layer_def['pad']) else 0
act = layer_def['activation']
out = conv_bn_layer(
input=out,
ch_out=ch_out,
filter_size=filter_size,
stride=stride,
padding=padding,
act=act,
bn=bool(bn),
name="conv"+str(i),
is_train=self.is_train)
elif layer_def['type'] == 'shortcut':
layer_from = int(layer_def['from'])
out = fluid.layers.elementwise_add(
x=out,
y=layer_outputs[layer_from],
name="res"+str(i))
elif layer_def['type'] == 'route':
                layers = list(map(int, layer_def['layers'].split(",")))
                out = fluid.layers.concat(
                    input=[layer_outputs[li] for li in layers],
                    axis=1)
elif layer_def['type'] == 'upsample':
scale = int(layer_def['stride'])
# get dynamic upsample output shape
shape_nchw = fluid.layers.shape(out)
shape_hw = fluid.layers.slice(shape_nchw, axes=[0], \
starts=[2], ends=[4])
shape_hw.stop_gradient = True
in_shape = fluid.layers.cast(shape_hw, dtype='int32')
out_shape = in_shape * scale
out_shape.stop_gradient = True
                # resize by actual_shape
out = fluid.layers.resize_nearest(
input=out,
scale=scale,
actual_shape=out_shape,
name="upsample"+str(i))
elif layer_def['type'] == 'maxpool':
pool_size = int(layer_def['size'])
pool_stride = int(layer_def['stride'])
pool_padding = 0
if pool_stride == 1 and pool_size == 2:
pool_padding = 1
out = fluid.layers.pool2d(
input=out,
pool_type='max',
pool_size=pool_size,
pool_stride=pool_stride,
pool_padding=pool_padding)
elif layer_def['type'] == 'yolo':
self.yolo_layer_defs.append(layer_def)
self.outputs.append(out)
                anchor_mask = list(map(int, layer_def['mask'].split(',')))
                anchors = list(map(int, layer_def['anchors'].split(',')))
mask_anchors = []
for m in anchor_mask:
mask_anchors.append(anchors[2 * m])
mask_anchors.append(anchors[2 * m + 1])
self.yolo_anchors.append(mask_anchors)
class_num = int(layer_def['classes'])
self.yolo_classes.append(class_num)
if self.is_train:
ignore_thresh = float(layer_def['ignore_thresh'])
loss = fluid.layers.yolov3_loss(
x=out,
gtbox=self.gtbox,
gtlabel=self.gtlabel,
anchors=anchors,
anchor_mask=anchor_mask,
class_num=class_num,
ignore_thresh=ignore_thresh,
downsample=self.downsample,
name="yolo_loss"+str(i))
self.losses.append(fluid.layers.reduce_mean(loss))
self.downsample //= 2
layer_outputs.append(out)
def loss(self):
return sum(self.losses)
def get_pred(self):
return self.outputs
def get_yolo_anchors(self):
return self.yolo_anchors
def get_yolo_classes(self):
return self.yolo_classes
def build_input(self):
self.image_shape = [3, self.img_height, self.img_width]
if self.use_pyreader and self.is_train:
self.py_reader = fluid.layers.py_reader(
capacity=64,
                shapes=[[-1] + self.image_shape, [-1, cfg.max_box_num, 4], [-1, cfg.max_box_num]],
lod_levels=[0, 0, 0],
dtypes=['float32'] * 2 + ['int32'],
use_double_buffer=True)
self.image, self.gtbox, self.gtlabel = fluid.layers.read_file(self.py_reader)
else:
self.image = fluid.layers.data(
name='image', shape=self.image_shape, dtype='float32'
)
self.gtbox = fluid.layers.data(
name='gtbox', shape=[cfg.max_box_num, 4], dtype='float32'
)
self.gtlabel = fluid.layers.data(
name='gtlabel', shape=[cfg.max_box_num], dtype='int32'
)
self.im_shape = fluid.layers.data(
name="im_shape", shape=[2], dtype='int32')
self.im_id = fluid.layers.data(
name="im_id", shape=[1], dtype='int32')
def feeds(self):
if not self.is_train:
return [self.image, self.im_id, self.im_shape]
return [self.image, self.gtbox, self.gtlabel]
def get_hyperparams(self):
return self.hyperparams
def get_input_size(self):
return cfg.input_size
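# Typical construction, as done in train.py (a sketch; the cfg path shown is
# the repository default, not the only option):
#
#   model = YOLOv3('config/yolov3.cfg', is_train=True)
#   model.build_model()
#   total_loss = model.loss()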
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import os
import random
import time
import copy
import cv2
import box_utils
import image_utils
from pycocotools.coco import COCO
from data_utils import GeneratorEnqueuer
from config.config import cfg
class DataSetReader(object):
"""A class for parsing and read COCO dataset"""
def __init__(self):
        self.has_parsed_category = False
def _parse_dataset_dir(self, mode):
# cfg.data_dir = "dataset/coco"
# cfg.train_file_list = 'annotations/instances_val2017.json'
# cfg.train_data_dir = 'val2017'
cfg.dataset = "coco2017"
if 'coco2014' in cfg.dataset:
cfg.train_file_list = 'annotations/instances_train2014.json'
cfg.train_data_dir = 'train2014'
cfg.val_file_list = 'annotations/instances_val2014.json'
cfg.val_data_dir = 'val2014'
elif 'coco2017' in cfg.dataset:
cfg.train_file_list = 'annotations/instances_train2017.json'
cfg.train_data_dir = 'train2017'
cfg.val_file_list = 'annotations/instances_val2017.json'
cfg.val_data_dir = 'val2017'
else:
raise NotImplementedError('Dataset {} not supported'.format(
cfg.dataset))
if mode == 'train':
cfg.train_file_list = os.path.join(cfg.data_dir, cfg.train_file_list)
cfg.train_data_dir = os.path.join(cfg.data_dir, cfg.train_data_dir)
self.COCO = COCO(cfg.train_file_list)
self.img_dir = cfg.train_data_dir
elif mode == 'test' or mode == 'infer':
cfg.val_file_list = os.path.join(cfg.data_dir, cfg.val_file_list)
cfg.val_data_dir = os.path.join(cfg.data_dir, cfg.val_data_dir)
self.COCO = COCO(cfg.val_file_list)
self.img_dir = cfg.val_data_dir
    def _parse_dataset_category(self):
self.categories = self.COCO.loadCats(self.COCO.getCatIds())
self.num_category = len(self.categories)
self.label_names = []
self.label_ids = []
for category in self.categories:
self.label_names.append(category['name'])
self.label_ids.append(int(category['id']))
self.category_to_id_map = {
v: i
for i, v in enumerate(self.label_ids)
}
print("Load in {} categories.".format(self.num_category))
self.has_parsed_categpry = True
def get_label_infos(self):
        if not self.has_parsed_category:
            self._parse_dataset_dir("test")
            self._parse_dataset_category()
return (self.label_names, self.label_ids)
def _parse_gt_annotations(self, img):
img_height = img['height']
img_width = img['width']
anno = self.COCO.loadAnns(self.COCO.getAnnIds(imgIds=img['id'], iscrowd=None))
gt_index = 0
for target in anno:
if target['area'] < cfg.gt_min_area:
continue
            if target.get('ignore', False):
continue
box = box_utils.coco_anno_box_to_center_relative(target['bbox'], img_height, img_width)
            # skip degenerate boxes (non-positive width or height)
            if box[2] <= 0 or box[3] <= 0:
continue
img['gt_id'][gt_index] = np.int32(target['id'])
img['gt_boxes'][gt_index] = box
img['gt_labels'][gt_index] = self.category_to_id_map[target['category_id']]
gt_index += 1
if gt_index >= cfg.max_box_num:
break
def _parse_images(self, is_train):
image_ids = self.COCO.getImgIds()
image_ids.sort()
imgs = copy.deepcopy(self.COCO.loadImgs(image_ids))
for img in imgs:
img['image'] = os.path.join(self.img_dir, img['file_name'])
assert os.path.exists(img['image']), \
"image {} not found.".format(img['image'])
img['gt_id'] = np.zeros((cfg.max_box_num), dtype=np.int32)
img['gt_boxes'] = np.zeros((cfg.max_box_num, 4), dtype=np.float32)
img['gt_labels'] = np.zeros((cfg.max_box_num), dtype=np.int32)
for k in ['date_captured', 'url', 'license', 'file_name']:
                if k in img:
del img[k]
if is_train:
self._parse_gt_annotations(img)
print("Loaded {0} images from {1}.".format(len(imgs), cfg.dataset))
return imgs
def _parse_images_by_mode(self, mode):
if mode == 'infer':
return []
else:
return self._parse_images(is_train=(mode=='train'))
def get_reader(self, mode, size=416, batch_size=None, shuffle=False, mixup_iter=0, random_sizes=[], image=None):
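        """Return a batched reader for the given mode.

        'train' yields (image, gt_boxes, gt_labels) batches endlessly, with
        optional shuffling, mixup for the first `mixup_iter` samples, and a
        per-batch input size drawn from `random_sizes` when given; 'test'
        yields (image, im_id, im_shape) batches over the validation set once;
        'infer' yields a single batch for the image path passed via `image`.
        """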
        assert mode in ['train', 'test', 'infer'], "Unknown mode type!"
        if mode != 'infer':
            assert batch_size is not None, "batch size cannot be None in mode {}".format(mode)
self._parse_dataset_dir(mode)
        self._parse_dataset_category()
def img_reader(img, size, mean, std):
im_path = img['image']
im = cv2.imread(im_path).astype('float32')
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
h, w, _ = im.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
out_img = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=cv2.INTER_CUBIC)
mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1))
out_img = (out_img / 255.0 - mean) / std
out_img = out_img.transpose((2, 0, 1))
return (out_img, int(img['id']), (h, w))
def img_reader_with_augment(img, size, mean, std, mixup_img):
im_path = img['image']
im = cv2.imread(im_path)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
gt_boxes = img['gt_boxes'].copy()
gt_labels = img['gt_labels'].copy()
if mixup_img:
mixup_im = cv2.imread(mixup_img['image'])
mixup_im = cv2.cvtColor(mixup_im, cv2.COLOR_BGR2RGB)
mixup_gt_boxes = mixup_img['gt_boxes'].copy()
mixup_gt_labels = mixup_img['gt_labels'].copy()
im, gt_boxes, gt_labels = image_utils.image_mixup(im, gt_boxes, gt_labels, \
mixup_im, mixup_gt_boxes, mixup_gt_labels)
im, gt_boxes, gt_labels = image_utils.image_augment(im, gt_boxes, gt_labels, size, mean)
mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1))
out_img = (im / 255.0 - mean) / std
out_img = out_img.transpose((2, 0, 1)).astype('float32')
return (out_img, gt_boxes, gt_labels)
def get_img_size(size, random_sizes=[]):
if len(random_sizes):
return np.random.choice(random_sizes)
return size
def get_mixup_img(imgs, mixup_iter, total_read_cnt):
if total_read_cnt >= mixup_iter:
return None
mixup_idx = np.random.randint(1, len(imgs))
mixup_img = imgs[(total_read_cnt + mixup_idx) % len(imgs)]
return mixup_img
def reader():
if mode == 'train':
imgs = self._parse_images_by_mode(mode)
if shuffle:
np.random.shuffle(imgs)
read_cnt = 0
total_read_cnt = 0
batch_out = []
img_size = get_img_size(size, random_sizes)
while True:
img = imgs[read_cnt % len(imgs)]
mixup_img = get_mixup_img(imgs, mixup_iter, total_read_cnt)
read_cnt += 1
total_read_cnt += 1
if read_cnt % len(imgs) == 0 and shuffle:
np.random.shuffle(imgs)
im, gt_boxes, gt_labels = img_reader_with_augment(img, img_size, cfg.pixel_means, cfg.pixel_stds, mixup_img)
batch_out.append((im, gt_boxes, gt_labels))
if len(batch_out) == batch_size:
# print("img_ids: ", img_ids)
yield batch_out
batch_out = []
img_size = get_img_size(size, random_sizes)
elif mode == 'test':
imgs = self._parse_images_by_mode(mode)
batch_out = []
for img in imgs:
im, im_id, im_shape = img_reader(img, size, cfg.pixel_means, cfg.pixel_stds)
batch_out.append((im, im_id, im_shape))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
if len(batch_out) != 0:
yield batch_out
else:
img = {}
img['image'] = image
img['id'] = 0
im, im_id, im_shape = img_reader(img, size, cfg.pixel_means, cfg.pixel_stds)
batch_out = [(im, im_id, im_shape)]
yield batch_out
return reader
dsr = DataSetReader()
def train(size=416,
batch_size=64,
shuffle=True,
mixup_iter=0,
random_sizes=[],
use_multiprocessing=True,
num_workers=8,
max_queue=24):
generator = dsr.get_reader('train', size, batch_size, shuffle, mixup_iter, random_sizes)
if not use_multiprocessing:
return generator
def infinite_reader():
while True:
for data in generator():
yield data
def reader():
        enqueuer = None
        try:
enqueuer = GeneratorEnqueuer(
infinite_reader(), use_multiprocessing=use_multiprocessing)
enqueuer.start(max_queue_size=max_queue, workers=num_workers)
generator_out = None
while True:
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_out = enqueuer.queue.get()
break
else:
time.sleep(0.02)
yield generator_out
generator_out = None
finally:
if enqueuer is not None:
enqueuer.stop()
return reader
def test(size=416, batch_size=1):
return dsr.get_reader('test', size, batch_size)
def infer(size=416, image=None):
return dsr.get_reader('infer', size, image=image)
def get_label_infos():
return dsr.get_label_infos()
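# Module-level usage sketch (the size and batch size here are illustrative):
#
#   train_reader = train(size=608, batch_size=8, shuffle=True)
#   for batch in train_reader():
#       pass  # each batch is a list of (image, gt_boxes, gt_labels) tuples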
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
import random
import time
import shutil
from utility import parse_args, print_arguments, SmoothedValue
import paddle
import paddle.fluid as fluid
import reader
import models
from learning_rate import exponential_with_warmup_decay
from config.config import cfg
def train():
if cfg.debug:
fluid.default_startup_program().random_seed = 1000
fluid.default_main_program().random_seed = 1000
random.seed(0)
np.random.seed(0)
if not os.path.exists(cfg.model_save_dir):
os.makedirs(cfg.model_save_dir)
model = models.YOLOv3(cfg.model_cfg_path, use_pyreader=cfg.use_pyreader)
model.build_model()
input_size = model.get_input_size()
loss = model.loss()
loss.persistable = True
hyperparams = model.get_hyperparams()
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
print("Found {} CUDA devices.".format(devices_num))
learning_rate = float(hyperparams['learning_rate'])
boundaries = cfg.lr_steps
gamma = cfg.lr_gamma
step_num = len(cfg.lr_steps)
values = [learning_rate * (gamma**i) for i in range(step_num + 1)]
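    # e.g. with learning_rate=0.001 and lr_gamma=0.1 this gives piecewise
    # values [0.001, 0.0001, 0.00001, ...]: the initial rate plus one decayed
    # value per boundary in cfg.lr_steps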
optimizer = fluid.optimizer.Momentum(
learning_rate=exponential_with_warmup_decay(
learning_rate=learning_rate,
boundaries=boundaries,
values=values,
warmup_iter=cfg.warm_up_iter,
warmup_factor=cfg.warm_up_factor,
start_step=cfg.start_iter),
regularization=fluid.regularizer.L2Decay(float(hyperparams['decay'])),
momentum=float(hyperparams['momentum']))
optimizer.minimize(loss)
fluid.memory_optimize(fluid.default_main_program())
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
base_exe = fluid.Executor(place)
base_exe.run(fluid.default_startup_program())
if cfg.pretrain_base:
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrain_base, var.name))
fluid.io.load_vars(base_exe, cfg.pretrain_base, predicate=if_exist)
if cfg.parallel:
        exe = fluid.ParallelExecutor(use_cuda=bool(cfg.use_gpu), loss_name=loss.name)
else:
exe = base_exe
random_sizes = []
if cfg.random_shape:
random_sizes = [32 * i for i in range(10, 20)]
mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter
if cfg.use_pyreader:
        train_reader = reader.train(input_size, batch_size=int(hyperparams['batch']) // devices_num, shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes)
py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader)
else:
train_reader = reader.train(input_size, batch_size=int(hyperparams['batch']), shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
def save_model(postfix):
model_path = os.path.join(cfg.model_save_dir, postfix)
if os.path.isdir(model_path):
shutil.rmtree(model_path)
fluid.io.save_persistables(base_exe, model_path)
fetch_list = [loss]
def train_loop_pyreader():
py_reader.start()
smoothed_loss = SmoothedValue(cfg.log_window)
try:
start_time = time.time()
prev_start_time = start_time
snapshot_loss = 0
snapshot_time = 0
for iter_id in range(cfg.start_iter, cfg.max_iter):
prev_start_time = start_time
start_time = time.time()
losses = exe.run(fetch_list=[v.name for v in fetch_list])
smoothed_loss.add_value(np.mean(np.array(losses[0])))
snapshot_loss += np.mean(np.array(losses[0]))
snapshot_time += start_time - prev_start_time
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
print("Iter {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format(
iter_id, lr[0],
smoothed_loss.get_median_value(), start_time - prev_start_time))
sys.stdout.flush()
if (iter_id + 1) % cfg.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id))
print("Snapshot {} saved, average loss: {}, average time: {}".format(
iter_id + 1, snapshot_loss / float(cfg.snapshot_iter),
snapshot_time / float(cfg.snapshot_iter)))
snapshot_loss = 0
snapshot_time = 0
except fluid.core.EOFException:
py_reader.reset()
def train_loop():
start_time = time.time()
prev_start_time = start_time
start = start_time
smoothed_loss = SmoothedValue(cfg.log_window)
snapshot_loss = 0
snapshot_time = 0
for iter_id, data in enumerate(train_reader()):
iter_id += cfg.start_iter
prev_start_time = start_time
start_time = time.time()
losses = exe.run(fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data))
smoothed_loss.add_value(losses[0])
snapshot_loss += losses[0]
snapshot_time += start_time - prev_start_time
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
print("Iter {:d}, lr: {:.6f}, loss: {:.4f}, time {:.5f}".format(
iter_id, lr[0], smoothed_loss.get_median_value(), start_time - prev_start_time))
sys.stdout.flush()
if (iter_id + 1) % cfg.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id))
print("Snapshot {} saved, average loss: {}, average time: {}".format(
iter_id + 1, snapshot_loss / float(cfg.snapshot_iter),
snapshot_time / float(cfg.snapshot_iter)))
snapshot_loss = 0
snapshot_time = 0
if (iter_id + 1) == cfg.max_iter:
print("Finish iter {}".format(iter_id))
break
if cfg.use_pyreader:
train_loop_pyreader()
else:
train_loop()
save_model('model_final')
if __name__ == '__main__':
args = parse_args()
print_arguments(args)
train()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains common utility functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import distutils.util
import numpy as np
import six
from collections import deque
from paddle.fluid import core
import argparse
import functools
from config.config import *
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size):
self.deque = deque(maxlen=window_size)
def add_value(self, value):
self.deque.append(value)
def get_median_value(self):
return np.median(self.deque)
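# Usage sketch: keep the last `window_size` loss values and log their median.
#
#   smoothed_loss = SmoothedValue(window_size=20)
#   smoothed_loss.add_value(0.5)
#   median = smoothed_loss.get_median_value()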
def parse_args():
"""return all args
"""
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
# ENV
    add_arg('parallel', bool, True, "Whether to use parallel execution.")
    add_arg('use_gpu', bool, True, "Whether to use GPU.")
add_arg('model_cfg_path', str, 'config/yolov3.cfg', "YOLO model config file path.")
add_arg('model_save_dir', str, 'checkpoints', "The path to save model.")
add_arg('pretrain_base', str, 'weights/darknet53', "The init model weights path.")
add_arg('pretrained_model', str, 'weights/mxnet', "The pretrained model path.")
add_arg('dataset', str, 'coco2017', "Dataset: coco2014, coco2017.")
add_arg('class_num', int, 80, "Class number.")
add_arg('data_dir', str, 'dataset/coco', "The data root path.")
add_arg('use_pyreader', bool, True, "Use pyreader.")
    add_arg('use_profile', bool, False, "Whether to use the profiler.")
add_arg('start_iter', int, 0, "Start iteration.")
#SOLVER
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('max_iter', int, 500200, "Iter number.")
add_arg('snapshot_iter', int, 2000, "Save model every snapshot stride.")
add_arg('log_window', int, 20, "Log smooth window, set 1 for debug, set 20 for train.")
# TRAIN TEST INFER
add_arg('input_size', int, 608, "Image input size of YOLOv3.")
    add_arg('random_shape', bool, False, "Resize to random shapes for the train reader.")
add_arg('no_mixup_iter', int, 4000, "Disable mixup in last N iter.")
add_arg('valid_thresh', float, 0.01, "Valid confidence score for NMS.")
add_arg('nms_thresh', float, 0.45, "NMS threshold.")
add_arg('nms_topk', int, 400, "The number of boxes to perform NMS.")
add_arg('nms_posk', int, 100, "The number of boxes of NMS output.")
add_arg('debug', bool, False, "Debug mode")
# SINGLE EVAL AND DRAW
    add_arg('image_path', str, 'image', "The image path used for inference and visualization.")
    add_arg('image_name', str, None, "A single image to run inference and visualization on; None to process all images in image_path.")
    add_arg('draw_thresh', float, 0.5, "Confidence score threshold for drawing prediction boxes in debug mode.")
# yapf: enable
args = parser.parse_args()
merge_cfg_from_args(args)
return args
#!/bin/bash
wget https://pjreddie.com/media/files/darknet53.conv.74 -O darknet53.pretrain
echo "download finish"
python weight_parser.py pretrain
echo "parse finish"
#!/bin/bash
wget https://pjreddie.com/media/files/yolov3.weights
echo "download finish"
python weight_parser.py yolov3
echo "parse finish"
#!/usr/bin/env bash
wget https://pjreddie.com/media/files/yolov3-tiny.weights
echo "download finish"
python weight_parser.py yolov3-tiny
echo "parse finish"
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
import shutil
import glob
import numpy as np
sys.path.append("..")
from config.config_parser import ConfigPaser
class WeightParser(object):
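    """Split a darknet .weights file into one raw binary file per parameter.

    The file is read as float32 and the first five 4-byte words (the darknet
    file header) are skipped; parameters then follow the cfg layer order. For
    a batch-normalized conv the order is offset (beta), scale (gamma), running
    mean, running variance, then the conv weights; otherwise conv bias, then
    conv weights. `conv_num` limits how many conv layers are parsed, which is
    used to take only the darknet53 backbone from the pretrain weights.
    """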
def __init__(self, weight_file, cfg_file, save_dir, conv_num=None):
self.weight_file = weight_file
self.cfg_file = cfg_file
self.save_dir = save_dir
self.conv_num = conv_num
self.cfg_parser = ConfigPaser(cfg_file)
def init_dir(self):
if os.path.exists(self.save_dir):
shutil.rmtree(self.save_dir)
os.mkdir(self.save_dir)
return self.save_dir
def parse_weight_to_separate_file(self):
self.save_dir = self.init_dir()
        # read all parameters as float32; the first five 4-byte words are the
        # darknet file header and are skipped
        weights = np.fromfile(
            open(self.weight_file, 'rb'),
            dtype=np.float32)[5:]
w_idx = 0
model_defs = self.cfg_parser.parse()
if model_defs is None:
return None
hyperparams = model_defs.pop(0)
in_channels = [int(hyperparams['channels'])]
parsed_conv_num = 0
for i, layer_def in enumerate(model_defs):
if layer_def['type'] == 'convolutional':
filters = int(layer_def['filters'])
size = int(layer_def['size'])
conv_name = "conv" + str(i)
if layer_def.get('batch_normalize', 0):
bn_name = "bn" + str(i)
offset = weights[w_idx: w_idx + filters]
offset.tofile(os.path.join(self.save_dir, bn_name+"_offset"))
w_idx += filters
scale = weights[w_idx: w_idx + filters]
scale.tofile(os.path.join(self.save_dir, bn_name+"_scale"))
w_idx += filters
mean = weights[w_idx: w_idx + filters]
mean.tofile(os.path.join(self.save_dir, bn_name+"_mean"))
w_idx += filters
var = weights[w_idx: w_idx + filters]
var.tofile(os.path.join(self.save_dir, bn_name+"_var"))
w_idx += filters
else:
conv_bias = weights[w_idx: w_idx + filters]
conv_bias.tofile(os.path.join(self.save_dir, conv_name+"_bias"))
w_idx += filters
conv_weight_num = in_channels[-1] * filters * size * size
conv_weight = weights[w_idx: w_idx + conv_weight_num]
conv_weight.tofile(os.path.join(self.save_dir, conv_name+"_weights"))
w_idx += conv_weight_num
in_channels.append(filters)
parsed_conv_num += 1
if self.conv_num is not None:
if parsed_conv_num >= self.conv_num:
break
if layer_def['type'] == 'route':
                layers = list(map(int, layer_def['layers'].split(',')))
out_channel = 0
for layer in layers:
if layer < 0:
out_channel += in_channels[layer]
else:
out_channel += in_channels[layer + 1]
in_channels.append(out_channel)
if layer_def['type'] in ['shortcut', 'yolo', 'upsample', 'maxpool']:
in_channels.append(in_channels[-1])
        assert w_idx == weights.shape[0], "weight parsing incomplete"
def convert_file_to_fluid(self):
filenames = glob.glob(self.save_dir+"/*")
for filename in filenames:
src_filename = "./test/" + filename.split("/")[-1]
assert os.path.exists(src_filename)
with open(src_filename, 'rb') as f:
src_data = f.read()
with open(filename, 'rb') as f:
data = f.read()
head_len = len(src_data) - len(data)
with open(filename, 'wb') as f:
f.write(src_data[:head_len])
f.write(data)
    def check_convert_result(self):
filenames = glob.glob(self.save_dir+"/*")
for filename in filenames:
src_filename = "./test/" + filename.split("/")[-1]
assert os.path.exists(src_filename)
f = np.fromfile(open(filename, 'rb'), dtype=np.int8)
sf = np.fromfile(open(src_filename, 'rb'), dtype=np.int8)
assert f.shape == sf.shape, "check {} failed {}, {}".format(filename, f.shape, sf.shape)
if __name__ == "__main__":
model = sys.argv[1]
if model == "pretrain":
weight_path = "darknet53.pretrain"
cfg_path = "../config/yolov3.cfg"
conv_num = 53 - 1
else:
weight_path = model + '.weights'
cfg_path = "../config/" + model + ".cfg"
conv_num = None
for path in [weight_path, cfg_path]:
if not os.path.isfile(path):
print(path, "not found!")
exit()
wp = WeightParser(weight_path, cfg_path, model, conv_num)
wp.parse_weight_to_separate_file()
wp.convert_file_to_fluid()
    wp.check_convert_result()
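# Invoked by the download scripts above, e.g.:
#   python weight_parser.py pretrain   # parses darknet53.pretrain
#   python weight_parser.py yolov3     # parses yolov3.weights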