未验证 提交 d48417e7 编写于 作者: J jerrywgz 提交者: GitHub

Detection model of Mask-RCNN. (#1501)

add mask-rcnn head 
上级 747f947e
# Faster RCNN Objective Detection
# RCNN Objective Detection
---
## Table of Contents
......@@ -17,17 +17,20 @@ Running sample code in this directory requires PaddelPaddle Fluid v.1.0.0 and la
## Introduction
[Faster Rcnn](https://arxiv.org/abs/1506.01497) is a typical two stage detector. The total framework of network can be divided into four parts, as shown below:
<p align="center">
<img src="image/Faster_RCNN.jpg" height=400 width=400 hspace='10'/> <br />
Faster RCNN model
</p>
Region Convolutional Neural Network (RCNN) models are two stages detector. According to proposals and feature extraction, obtain class and more precise proposals.
Now RCNN model contains two typical models: Faster RCNN and Mask RCNN.
[Faster RCNN](https://arxiv.org/abs/1506.01497), The total framework of network can be divided into four parts:
1. Base conv layer. As a CNN objective dection, Faster RCNN extract feature maps using a basic convolutional network. The feature maps then can be shared by RPN and fc layers. This sampel uses [ResNet-50](https://arxiv.org/abs/1512.03385) as base conv layer.
2. Region Proposal Network (RPN). RPN generates proposals for detection。This block generates anchors by a set of size and ratio and classifies anchors into fore-ground and back-ground by softmax. Then refine anchors to obtain more precise proposals using box regression.
3. RoI Align. This layer takes feature maps and proposals as input. The proposals are mapped to feature maps and pooled to the same size. The output are sent to fc layers for classification and regression. RoIPool and RoIAlign are used separately to this layer and it can be set in roi\_func in config.py.
4. Detection layer. Using the output of roi pooling to compute the class and locatoin of each proposal in two fc layers.
[Mask RCNN](https://arxiv.org/abs/1703.06870) is a classical instance segmentation model and an extension of Faster RCNN
Mask RCNN is a two stage model as well. At the first stage, it generates proposals from input images. At the second stage, it obtains class result, bbox and mask which is the result from segmentation branch on original Faster RCNN model. It decouples the relation between mask and classification.
## Data preparation
Train the model on [MS-COCO dataset](http://cocodataset.org/#download), download dataset as below:
......@@ -64,10 +67,12 @@ After data preparation, one can start the training step by:
python train.py \
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model}
--data_dir=${path_to_data}
--pretrained_model=${path_to_pretrain_model} \
--data_dir=${path_to_data} \
--MASK_ON=False
- Set ```export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7``` to specifiy 8 GPU to train.
- Set ```MASK\_ON``` to choose Faster RCNN or Mask RCNN model.
- For more help on arguments:
python train.py --help
......@@ -93,7 +98,6 @@ After data preparation, one can start the training step by:
* In first 500 iteration, the learning rate increases linearly from 0.00333 to 0.01. Then lr is decayed at 120000, 160000 iteration with multiplier 0.1, 0.01. The maximum iteration is 180000. Also, we released a 2x model which has 360000 iterations and lr is decayed at 240000, 320000. These configuration can be set by max_iter and lr_steps in config.py.
* Set the learning rate of bias to two times as global lr in non basic convolutional layers.
* In basic convolutional layers, parameters of affine layers and res body do not update.
* Use Nvidia Tesla V100 8GPU, total time for training is about 40 hours.
## Evaluation
......@@ -109,6 +113,8 @@ Evaluation is to evaluate the performance of a trained model. This sample provid
Evalutaion result is shown as below:
Faster RCNN:
| Model | RoI function | Batch size | Max iteration | mAP |
| :--------------- | :--------: | :------------: | :------------------: |------: |
| [Fluid RoIPool minibatch padding](http://paddlemodels.bj.bcebos.com/faster_rcnn/model_pool_minibatch_padding.tar.gz) | RoIPool | 8 | 180000 | 0.316 |
......@@ -121,6 +127,14 @@ Evalutaion result is shown as below:
* Fluid RoIAlign no padding: Images without padding.
* Fluid RoIAlign no padding 2x: Images without padding, train for 360000 iterations, learning rate is decayed at 240000, 320000.
Mask RCNN:
| Model | Batch size | Max iteration | box mAP | mask mAP |
| :--------------- | :--------: | :------------: | :--------: |------: |
| [Fluid mask no padding](https://paddlemodels.bj.bcebos.com/faster_rcnn/Fluid_mask_no_padding.tar.gz) | 8 | 180000 | 0.359 | 0.314 |
* Fluid mask no padding: Use RoIAlign. Images without padding.
## Inference and Visualization
Inference is used to get prediction score or image features based on trained models. `infer.py` is the main executor for inference, one can start infer step by:
......@@ -135,8 +149,12 @@ Inference is used to get prediction score or image features based on trained mod
Visualization of infer result is shown as below:
<p align="center">
<img src="image/000000000139.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000127517.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000203864.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000515077.jpg" height=300 width=400 hspace='10'/> <br />
<img src="image/000000127517.jpg" height=300 width=400 hspace='10'/> <br />
Faster RCNN Visualization Examples
</p>
<p align="center">
<img src="image/000000000139_mask.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000127517_mask.jpg" height=300 width=400 hspace='10'/> <br />
Mask RCNN Visualization Examples
</p>
# Faster RCNN 目标检测
# RCNN 系列目标检测
---
## 内容
......@@ -16,18 +16,21 @@
在当前目录下运行样例代码需要PadddlePaddle Fluid的v.1.0.0或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle)中的说明来更新PaddlePaddle。
## 简介
区域卷积神经网络(RCNN)系列模型为两阶段目标检测器。通过对图像生成候选区域,提取特征,判别特征类别并修正候选框位置。
RCNN系列目前包含两个代表模型:Faster RCNN,Mask RCNN
[Faster Rcnn](https://arxiv.org/abs/1506.01497) 是典型的两阶段目标检测器。如下图所示,整体网络可以分为4个主要内容:
<p align="center">
<img src="image/Faster_RCNN.jpg" height=400 width=400 hspace='10'/> <br />
Faster RCNN 目标检测模型
</p>
[Faster RCNN](https://arxiv.org/abs/1506.01497) 整体网络可以分为4个主要内容:
1. 基础卷积层。作为一种卷积神经网络目标检测方法,Faster RCNN首先使用一组基础的卷积网络提取图像的特征图。特征图被后续RPN层和全连接层共享。本示例采用[ResNet-50](https://arxiv.org/abs/1512.03385)作为基础卷积层。
2. 区域生成网络(RPN)。RPN网络用于生成候选区域(proposals)。该层通过一组固定的尺寸和比例得到一组锚点(anchors), 通过softmax判断锚点属于前景或者背景,再利用区域回归修正锚点从而获得精确的候选区域。
3. RoI Align。该层收集输入的特征图和候选区域,将候选区域映射到特征图中并池化为统一大小的区域特征图,送入全连接层判定目标类别, 该层可选用RoIPool和RoIAlign两种方式,在config.py中设置roi\_func。
4. 检测层。利用区域特征图计算候选区域的类别,同时再次通过区域回归获得检测框最终的精确位置。
[Mask RCNN](https://arxiv.org/abs/1703.06870) 扩展自Faster RCNN,是经典的实例分割模型。
Mask RCNN同样为两阶段框架,第一阶段扫描图像生成候选框;第二阶段根据候选框得到分类结果,边界框,同时在原有Faster RCNN模型基础上添加分割分支,得到掩码结果,实现了掩码和类别预测关系的解藕。
## 数据准备
[MS-COCO数据集](http://cocodataset.org/#download)上进行训练,通过如下方式下载数据集。
......@@ -63,10 +66,12 @@ Faster RCNN 目标检测模型
python train.py \
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model}
--data_dir=${path_to_data}
--pretrained_model=${path_to_pretrain_model} \
--data_dir=${path_to_data} \
--MASK_ON=False
- 通过设置export CUDA\_VISIBLE\_DEVICES=0,1,2,3,4,5,6,7指定8卡GPU训练。
- 通过设置MASK\_ON选择Faster RCNN和Mask RCNN模型。
- 可选参数见:
python train.py --help
......@@ -83,11 +88,10 @@ Faster RCNN 目标检测模型
**训练策略:**
* 采用momentum优化算法训练Faster RCNN,momentum=0.9。
* 采用momentum优化算法训练,momentum=0.9。
* 权重衰减系数为0.0001,前500轮学习率从0.00333线性增加至0.01。在120000,160000轮时使用0.1,0.01乘子进行学习率衰减,最大训练180000轮。同时我们也提供了2x模型,该模型采用更多的迭代轮数进行训练,训练360000轮,学习率在240000,320000轮衰减,其他参数不变,训练最大轮数和学习率策略可以在config.py中对max_iter和lr_steps进行设置。
* 非基础卷积层卷积bias学习率为整体学习率2倍。
* 基础卷积层中,affine_layers参数不更新,res2层参数不更新。
* 使用Nvidia Tesla V100 8卡并行,总共训练时长大约40小时。
## 模型评估
......@@ -103,6 +107,8 @@ Faster RCNN 目标检测模型
下表为模型评估结果:
Faster RCNN
| 模型 | RoI处理方式 | 批量大小 | 迭代次数 | mAP |
| :--------------- | :--------: | :------------: | :------------------: |------: |
| [Fluid RoIPool minibatch padding](http://paddlemodels.bj.bcebos.com/faster_rcnn/model_pool_minibatch_padding.tar.gz) | RoIPool | 8 | 180000 | 0.316 |
......@@ -117,6 +123,13 @@ Faster RCNN 目标检测模型
* Fluid RoIAlign no padding: 使用RoIAlign,不对图像做填充处理。
* Fluid RoIAlign no padding 2x: 使用RoIAlign,不对图像做填充处理。训练360000轮,学习率在240000,320000轮衰减。
Mask RCNN
| 模型 | 批量大小 | 迭代次数 | box mAP | mask mAP |
| :--------------- | :--------: | :------------: | :--------: |------: |
| [Fluid mask no padding](https://paddlemodels.bj.bcebos.com/faster_rcnn/Fluid_mask_no_padding.tar.gz) | 8 | 180000 | 0.359 | 0.314 |
* Fluid mask no padding: 使用RoIAlign,不对图像做填充处理
## 模型推断及可视化
模型推断可以获取图像中的物体及其对应的类别,`infer.py`是主要执行程序,调用示例如下:
......@@ -131,8 +144,12 @@ Faster RCNN 目标检测模型
下图为模型可视化预测结果:
<p align="center">
<img src="image/000000000139.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000127517.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000203864.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000515077.jpg" height=300 width=400 hspace='10'/> <br />
<img src="image/000000127517.jpg" height=300 width=400 hspace='10'/> <br />
Faster RCNN 预测可视化
</p>
<p align="center">
<img src="image/000000000139_mask.jpg" height=300 width=400 hspace='10'/>
<img src="image/000000127517_mask.jpg" height=300 width=400 hspace='10'/> <br />
Mask RCNN 预测可视化
</p>
......@@ -69,6 +69,7 @@ def clip_xyxy_to_image(x1, y1, x2, y2, height, width):
y2 = np.minimum(height - 1., np.maximum(0., y2))
return x1, y1, x2, y2
def nms(dets, thresh):
"""Apply classic DPM-style greedy NMS."""
if dets.shape[0] == 0:
......@@ -123,3 +124,21 @@ def nms(dets, thresh):
return np.where(suppressed == 0)[0]
def expand_boxes(boxes, scale):
"""Expand an array of boxes by a given scale."""
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half *= scale
h_half *= scale
boxes_exp = np.zeros(boxes.shape)
boxes_exp[:, 0] = x_c - w_half
boxes_exp[:, 2] = x_c + w_half
boxes_exp[:, 1] = y_c - h_half
boxes_exp[:, 3] = y_c + h_half
return boxes_exp
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#
# Based on:
# --------------------------------------------------------
# Detectron
# Copyright (c) 2017-present, Facebook, Inc.
# Licensed under the Apache License, Version 2.0;
# Written by Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
def colormap(rgb=False):
color_list = np.array([
0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494,
0.184, 0.556, 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078,
0.184, 0.300, 0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000,
1.000, 0.500, 0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000,
0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, 0.333, 0.667,
0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000,
0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000,
1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000,
0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500,
0.333, 1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500, 0.667,
0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333,
0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000,
0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333,
0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000,
1.000, 0.667, 0.333, 1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000,
1.000, 0.000, 1.000, 1.000, 0.333, 1.000, 1.000, 0.667, 1.000, 0.167,
0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000,
0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000,
0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000,
0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000,
0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833,
0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.286,
0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, 0.714, 0.714,
0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000
]).astype(np.float32)
color_list = color_list.reshape((-1, 3)) * 255
if not rgb:
color_list = color_list[:, ::-1]
return color_list
......@@ -90,6 +90,9 @@ _C.TRAIN.freeze_at = 2
# min area of ground truth box
_C.TRAIN.gt_min_area = -1
# Use horizontally-flipped images during training?
_C.TRAIN.use_flipped = True
#
# Inference options
#
......@@ -120,7 +123,7 @@ _C.TEST.rpn_post_nms_top_n = 1000
_C.TEST.rpn_min_size = 0.0
# max number of detections
_C.TEST.detectiions_per_im = 100
_C.TEST.detections_per_im = 100
# NMS threshold used on RPN proposals
_C.TEST.rpn_nms_thresh = 0.7
......@@ -129,6 +132,9 @@ _C.TEST.rpn_nms_thresh = 0.7
# Model options
#
# Whether use mask rcnn head
_C.MASK_ON = True
# weight for bbox regression targets
_C.bbox_reg_weights = [0.1, 0.1, 0.2, 0.2]
......@@ -156,6 +162,15 @@ _C.roi_resolution = 14
# spatial scale
_C.spatial_scale = 1. / 16.
# resolution to represent mask labels
_C.resolution = 14
# Number of channels in the mask head
_C.dim_reduced = 256
# Threshold for converting soft masks to hard masks
_C.mrcnn_thresh_binarize = 0.5
#
# SOLVER options
#
......
......@@ -18,8 +18,7 @@ from __future__ import print_function
import os
import time
import numpy as np
from eval_helper import get_nmsed_box
from eval_helper import get_dt_res
from eval_helper import *
import paddle
import paddle.fluid as fluid
import reader
......@@ -44,7 +43,7 @@ def eval():
devices_num = len(devices.split(","))
total_batch_size = devices_num * cfg.TRAIN.im_per_batch
cocoGt = COCO(os.path.join(cfg.data_dir, test_list))
numId_to_catId_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())}
num_id_to_cat_id_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())}
category_ids = cocoGt.getCatIds()
label_list = {
item['id']: item['name']
......@@ -58,45 +57,76 @@ def eval():
use_pyreader=False,
is_train=False)
model.build_model(image_shape)
rpn_rois, confs, locs = model.eval_out()
rpn_rois, confs, locs = model.eval_bbox_out()
pred_boxes = model.eval()
if cfg.MASK_ON:
masks = model.eval_mask_out()
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# yapf: disable
if cfg.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)
# yapf: enable
test_reader = reader.test(total_batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
dts_res = []
fetch_list = [rpn_rois, confs, locs]
segms_res = []
if cfg.MASK_ON:
fetch_list = [rpn_rois, confs, locs, pred_boxes, masks]
else:
fetch_list = [rpn_rois, confs, locs]
for batch_id, batch_data in enumerate(test_reader()):
start = time.time()
im_info = []
for data in batch_data:
im_info.append(data[1])
rpn_rois_v, confs_v, locs_v = exe.run(
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(batch_data),
return_numpy=False)
new_lod, nmsed_out = get_nmsed_box(rpn_rois_v, confs_v, locs_v,
class_nums, im_info,
numId_to_catId_map)
result = exe.run(fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(batch_data),
return_numpy=False)
rpn_rois_v = result[0]
confs_v = result[1]
locs_v = result[2]
if cfg.MASK_ON:
pred_boxes_v = result[3]
masks_v = result[4]
new_lod = pred_boxes_v.lod()
nmsed_out = pred_boxes_v
dts_res += get_dt_res(total_batch_size, new_lod, nmsed_out, batch_data)
dts_res += get_dt_res(total_batch_size, new_lod[0], nmsed_out,
batch_data, num_id_to_cat_id_map)
if cfg.MASK_ON and np.array(masks_v).shape != (1, 1):
segms_out = segm_results(nmsed_out, masks_v, im_info)
segms_res += get_segms_res(total_batch_size, new_lod[0], segms_out,
batch_data, num_id_to_cat_id_map)
end = time.time()
print('batch id: {}, time: {}'.format(batch_id, end - start))
with open("detection_result.json", 'w') as outfile:
with open("detection_bbox_result.json", 'w') as outfile:
json.dump(dts_res, outfile)
print("start evaluate using coco api")
cocoDt = cocoGt.loadRes("detection_result.json")
print("start evaluate bbox using coco api")
cocoDt = cocoGt.loadRes("detection_bbox_result.json")
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
if cfg.MASK_ON:
with open("detection_segms_result.json", 'w') as outfile:
json.dump(segms_res, outfile)
print("start evaluate mask using coco api")
cocoDt = cocoGt.loadRes("detection_segms_result.json")
cocoEval = COCOeval(cocoGt, cocoDt, 'segm')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
if __name__ == '__main__':
args = parse_args()
......
......@@ -21,6 +21,10 @@ from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
from config import cfg
import pycocotools.mask as mask_util
import six
from colormap import colormap
import cv2
def box_decoder(deltas, boxes, weights):
......@@ -80,8 +84,7 @@ def clip_tiled_boxes(boxes, im_shape):
return boxes
def get_nmsed_box(rpn_rois, confs, locs, class_nums, im_info,
numId_to_catId_map):
def get_nmsed_box(rpn_rois, confs, locs, class_nums, im_info):
lod = rpn_rois.lod()[0]
rpn_rois_v = np.array(rpn_rois)
variance_v = np.array(cfg.bbox_reg_weights)
......@@ -106,38 +109,41 @@ def get_nmsed_box(rpn_rois, confs, locs, class_nums, im_info,
inds = np.where(scores_n[:, j] > cfg.TEST.score_thresh)[0]
scores_j = scores_n[inds, j]
rois_j = rois_n[inds, j * 4:(j + 1) * 4]
dets_j = np.hstack((rois_j, scores_j[:, np.newaxis])).astype(
dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(
np.float32, copy=False)
keep = box_utils.nms(dets_j, cfg.TEST.nms_thresh)
nms_dets = dets_j[keep, :]
#add labels
cat_id = numId_to_catId_map[j]
label = np.array([cat_id for _ in range(len(keep))])
label = np.array([j for _ in range(len(keep))])
nms_dets = np.hstack((nms_dets, label[:, np.newaxis])).astype(
np.float32, copy=False)
cls_boxes[j] = nms_dets
# Limit to max_per_image detections **over all classes**
image_scores = np.hstack(
[cls_boxes[j][:, -2] for j in range(1, class_nums)])
if len(image_scores) > cfg.TEST.detectiions_per_im:
image_thresh = np.sort(image_scores)[-cfg.TEST.detectiions_per_im]
[cls_boxes[j][:, 1] for j in range(1, class_nums)])
if len(image_scores) > cfg.TEST.detections_per_im:
image_thresh = np.sort(image_scores)[-cfg.TEST.detections_per_im]
for j in range(1, class_nums):
keep = np.where(cls_boxes[j][:, -2] >= image_thresh)[0]
keep = np.where(cls_boxes[j][:, 1] >= image_thresh)[0]
cls_boxes[j] = cls_boxes[j][keep, :]
im_results_n = np.vstack([cls_boxes[j] for j in range(1, class_nums)])
im_results[i] = im_results_n
new_lod.append(len(im_results_n) + new_lod[-1])
boxes = im_results_n[:, :-2]
scores = im_results_n[:, -2]
labels = im_results_n[:, -1]
boxes = im_results_n[:, 2:]
scores = im_results_n[:, 1]
labels = im_results_n[:, 0]
im_results = np.vstack([im_results[k] for k in range(len(lod) - 1)])
return new_lod, im_results
def get_dt_res(batch_size, lod, nmsed_out, data):
def get_dt_res(batch_size, lod, nmsed_out, data, num_id_to_cat_id_map):
dts_res = []
nmsed_out_v = np.array(nmsed_out)
if nmsed_out_v.shape == (
1,
1, ):
return dts_res
assert (len(lod) == batch_size + 1), \
"Error Lod Tensor offset dimension. Lod({}) vs. batch_size({})"\
.format(len(lod), batch_size)
......@@ -150,7 +156,8 @@ def get_dt_res(batch_size, lod, nmsed_out, data):
for j in range(dt_num_this_img):
dt = nmsed_out_v[k]
k = k + 1
xmin, ymin, xmax, ymax, score, category_id = dt.tolist()
num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
category_id = num_id_to_cat_id_map[num_id]
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
......@@ -164,24 +171,131 @@ def get_dt_res(batch_size, lod, nmsed_out, data):
return dts_res
def draw_bounding_box_on_image(image_path, nms_out, draw_threshold, label_list):
image = Image.open(image_path)
def get_segms_res(batch_size, lod, segms_out, data, num_id_to_cat_id_map):
segms_res = []
segms_out_v = np.array(segms_out)
k = 0
for i in range(batch_size):
dt_num_this_img = lod[i + 1] - lod[i]
image_id = int(data[i][-1])
for j in range(dt_num_this_img):
dt = segms_out_v[k]
k = k + 1
segm, num_id, score = dt.tolist()
cat_id = num_id_to_cat_id_map[num_id]
if six.PY3:
if 'counts' in segm:
segm['counts'] = rle['counts'].decode("utf8")
segm_res = {
'image_id': image_id,
'category_id': cat_id,
'segmentation': segm,
'score': score
}
segms_res.append(segm_res)
return segms_res
def draw_bounding_box_on_image(image_path,
nms_out,
draw_threshold,
label_list,
num_id_to_cat_id_map,
image=None):
if image is None:
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
for dt in nms_out:
xmin, ymin, xmax, ymax, score, category_id = dt.tolist()
for dt in np.array(nms_out):
num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
category_id = num_id_to_cat_id_map[num_id]
if score < draw_threshold:
continue
bbox = dt[:4]
xmin, ymin, xmax, ymax = bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=4,
width=2,
fill='red')
if image.mode == 'RGB':
draw.text((xmin, ymin), label_list[int(category_id)], (255, 255, 0))
image_name = image_path.split('/')[-1]
print("image with bbox drawed saved as {}".format(image_name))
image.save(image_name)
def draw_mask_on_image(image_path, segms_out, draw_threshold, alpha=0.7):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
mask_color_id = 0
w_ratio = .4
image = np.array(image).astype('float32')
for dt in np.array(segms_out):
segm, num_id, score = dt.tolist()
if score < draw_threshold:
continue
mask = mask_util.decode(segm) * 255
color_list = colormap(rgb=True)
color_mask = color_list[mask_color_id % len(color_list), 0:3]
mask_color_id += 1
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask)
image[idx[0], idx[1], :] *= 1.0 - alpha
image[idx[0], idx[1], :] += alpha * color_mask
image = Image.fromarray(image.astype('uint8'))
return image
def segm_results(im_results, masks, im_info):
im_results = np.array(im_results)
class_num = cfg.class_num
M = cfg.resolution
scale = (M + 2.0) / M
lod = masks.lod()[0]
masks_v = np.array(masks)
boxes = im_results[:, 2:]
labels = im_results[:, 0]
segms_results = [[] for _ in range(len(lod) - 1)]
sum = 0
for i in range(len(lod) - 1):
im_results_n = im_results[lod[i]:lod[i + 1]]
cls_segms = []
masks_n = masks_v[lod[i]:lod[i + 1]]
boxes_n = boxes[lod[i]:lod[i + 1]]
labels_n = labels[lod[i]:lod[i + 1]]
im_h = int(round(im_info[i][0] / im_info[i][2]))
im_w = int(round(im_info[i][1] / im_info[i][2]))
boxes_n = box_utils.expand_boxes(boxes_n, scale)
boxes_n = boxes_n.astype(np.int32)
padded_mask = np.zeros((M + 2, M + 2), dtype=np.float32)
for j in range(len(im_results_n)):
class_id = int(labels_n[j])
padded_mask[1:-1, 1:-1] = masks_n[j, class_id, :, :]
ref_box = boxes_n[j, :]
w = ref_box[2] - ref_box[0] + 1
h = ref_box[3] - ref_box[1] + 1
w = np.maximum(w, 1)
h = np.maximum(h, 1)
mask = cv2.resize(padded_mask, (w, h))
mask = np.array(mask > cfg.mrcnn_thresh_binarize, dtype=np.uint8)
im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
x_0 = max(ref_box[0], 0)
x_1 = min(ref_box[2] + 1, im_w)
y_0 = max(ref_box[1], 0)
y_1 = min(ref_box[3] + 1, im_h)
im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - ref_box[1]):(y_1 - ref_box[
1]), (x_0 - ref_box[0]):(x_1 - ref_box[0])]
sum += im_mask.sum()
rle = mask_util.encode(
np.array(
im_mask[:, :, np.newaxis], order='F'))[0]
cls_segms.append(rle)
segms_results[i] = np.array(cls_segms)[:, np.newaxis]
segms_results = np.vstack([segms_results[k] for k in range(len(lod) - 1)])
im_results = np.hstack([segms_results, im_results])
return im_results[:, :3]
import os
import time
import numpy as np
from eval_helper import get_nmsed_box
from eval_helper import get_dt_res
from eval_helper import draw_bounding_box_on_image
from eval_helper import *
import paddle
import paddle.fluid as fluid
import reader
......@@ -24,7 +22,7 @@ def infer():
test_list = 'annotations/instances_val2017.json'
cocoGt = COCO(os.path.join(cfg.data_dir, test_list))
numId_to_catId_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())}
num_id_to_cat_id_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())}
category_ids = cocoGt.getCatIds()
label_list = {
item['id']: item['name']
......@@ -40,7 +38,10 @@ def infer():
use_pyreader=False,
is_train=False)
model.build_model(image_shape)
rpn_rois, confs, locs = model.eval_out()
rpn_rois, confs, locs = model.eval_bbox_out()
pred_boxes = model.eval()
if cfg.MASK_ON:
masks = model.eval_mask_out()
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
......@@ -53,17 +54,32 @@ def infer():
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
dts_res = []
fetch_list = [rpn_rois, confs, locs]
segms_res = []
if cfg.MASK_ON:
fetch_list = [rpn_rois, confs, locs, pred_boxes, masks]
else:
fetch_list = [rpn_rois, confs, locs]
data = next(infer_reader())
im_info = [data[0][1]]
rpn_rois_v, confs_v, locs_v = exe.run(
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data),
return_numpy=False)
new_lod, nmsed_out = get_nmsed_box(rpn_rois_v, confs_v, locs_v, class_nums,
im_info, numId_to_catId_map)
result = exe.run(fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data),
return_numpy=False)
rpn_rois_v = result[0]
confs_v = result[1]
locs_v = result[2]
if cfg.MASK_ON:
pred_boxes_v = result[3]
masks_v = result[4]
new_lod = pred_boxes_v.lod()
nmsed_out = pred_boxes_v
path = os.path.join(cfg.image_path, cfg.image_name)
draw_bounding_box_on_image(path, nmsed_out, cfg.draw_threshold, label_list)
image = None
if cfg.MASK_ON:
segms_out = segm_results(nmsed_out, masks_v, im_info)
image = draw_mask_on_image(path, segms_out, cfg.draw_threshold)
draw_bounding_box_on_image(path, nmsed_out, cfg.draw_threshold, label_list,
num_id_to_cat_id_map, image)
if __name__ == '__main__':
......
......@@ -16,8 +16,11 @@ import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.fluid.initializer import Normal
from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay
from config import cfg
import cPickle as cp
import numpy as np
class FasterRCNN(object):
......@@ -32,7 +35,6 @@ class FasterRCNN(object):
self.is_train = is_train
self.use_pyreader = use_pyreader
self.use_random = use_random
#self.py_reader = None
def build_model(self, image_shape):
self.build_input(image_shape)
......@@ -41,31 +43,64 @@ class FasterRCNN(object):
self.rpn_heads(body_conv)
# Fast RCNN
self.fast_rcnn_heads(body_conv)
# Mask RCNN
if cfg.MASK_ON:
self.mask_rcnn_heads(body_conv)
def loss(self):
losses = []
# Fast RCNN loss
loss_cls, loss_bbox = self.fast_rcnn_loss()
# RPN loss
rpn_cls_loss, rpn_reg_loss = self.rpn_loss()
return loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss,
losses = [loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss]
rkeys = ['loss', 'loss_cls', 'loss_bbox', \
'loss_rpn_cls', 'loss_rpn_bbox',]
if cfg.MASK_ON:
loss_mask = self.mask_rcnn_loss()
losses = losses + [loss_mask]
rkeys = rkeys + ["loss_mask"]
loss = fluid.layers.sum(losses)
rloss = [loss] + losses
return rloss, rkeys
def eval_out(self):
def eval_bbox_out(self):
cls_prob = fluid.layers.softmax(self.cls_score, use_cudnn=False)
return [self.rpn_rois, cls_prob, self.bbox_pred]
def eval_mask_out(self):
return self.mask_fcn_logits
def eval(self):
return self.pred_result
def build_input(self, image_shape):
if self.use_pyreader:
in_shapes = [[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1],
[-1, 3], [-1, 1]]
lod_levels = [0, 1, 1, 1, 0, 0]
dtypes = [
'float32', 'float32', 'int32', 'int32', 'float32', 'int32'
]
if cfg.MASK_ON:
in_shapes.append([-1, 2])
lod_levels.append(3)
dtypes.append('float32')
self.py_reader = fluid.layers.py_reader(
capacity=64,
shapes=[[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1], [-1, 3],
[-1, 1]],
lod_levels=[0, 1, 1, 1, 0, 0],
dtypes=[
"float32", "float32", "int32", "int32", "float32", "int32"
],
shapes=in_shapes,
lod_levels=lod_levels,
dtypes=dtypes,
use_double_buffer=True)
self.image, self.gt_box, self.gt_label, self.is_crowd, \
self.im_info, self.im_id = fluid.layers.read_file(self.py_reader)
ins = fluid.layers.read_file(self.py_reader)
self.image = ins[0]
self.gt_box = ins[1]
self.gt_label = ins[2]
self.is_crowd = ins[3]
self.im_info = ins[4]
self.im_id = ins[5]
if cfg.MASK_ON:
self.gt_masks = ins[6]
else:
self.image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
......@@ -74,22 +109,26 @@ class FasterRCNN(object):
self.gt_label = fluid.layers.data(
name='gt_label', shape=[1], dtype='int32', lod_level=1)
self.is_crowd = fluid.layers.data(
name='is_crowd',
shape=[-1],
dtype='int32',
lod_level=1,
append_batch_size=False)
name='is_crowd', shape=[1], dtype='int32', lod_level=1)
self.im_info = fluid.layers.data(
name='im_info', shape=[3], dtype='float32')
self.im_id = fluid.layers.data(
name='im_id', shape=[1], dtype='int32')
if cfg.MASK_ON:
self.gt_masks = fluid.layers.data(
name='gt_masks', shape=[2], dtype='float32', lod_level=3)
def feeds(self):
if not self.is_train:
return [self.image, self.im_info, self.im_id]
if not cfg.MASK_ON:
return [
self.image, self.gt_box, self.gt_label, self.is_crowd,
self.im_info, self.im_id
]
return [
self.image, self.gt_box, self.gt_label, self.is_crowd, self.im_info,
self.im_id
self.im_id, self.gt_masks
]
def rpn_heads(self, rpn_input):
......@@ -157,7 +196,7 @@ class FasterRCNN(object):
nms_thresh = param_obj.rpn_nms_thresh
min_size = param_obj.rpn_min_size
eta = param_obj.rpn_eta
rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals(
self.rpn_rois, self.rpn_roi_probs = fluid.layers.generate_proposals(
scores=rpn_cls_score_prob,
bbox_deltas=self.rpn_bbox_pred,
im_info=self.im_info,
......@@ -168,10 +207,9 @@ class FasterRCNN(object):
nms_thresh=nms_thresh,
min_size=min_size,
eta=eta)
self.rpn_rois = rpn_rois
if self.is_train:
outs = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
rpn_rois=self.rpn_rois,
gt_classes=self.gt_label,
is_crowd=self.is_crowd,
gt_boxes=self.gt_box,
......@@ -191,27 +229,28 @@ class FasterRCNN(object):
self.bbox_inside_weights = outs[3]
self.bbox_outside_weights = outs[4]
if cfg.MASK_ON:
mask_out = fluid.layers.generate_mask_labels(
im_info=self.im_info,
gt_classes=self.gt_label,
is_crowd=self.is_crowd,
gt_segms=self.gt_masks,
rois=self.rois,
labels_int32=self.labels_int32,
num_classes=cfg.class_num,
resolution=cfg.resolution)
self.mask_rois = mask_out[0]
self.roi_has_mask_int32 = mask_out[1]
self.mask_int32 = mask_out[2]
def fast_rcnn_heads(self, roi_input):
if self.is_train:
pool_rois = self.rois
else:
pool_rois = self.rpn_rois
if cfg.roi_func == 'RoIPool':
pool = fluid.layers.roi_pool(
input=roi_input,
rois=pool_rois,
pooled_height=cfg.roi_resolution,
pooled_width=cfg.roi_resolution,
spatial_scale=cfg.spatial_scale)
elif cfg.roi_func == 'RoIAlign':
pool = fluid.layers.roi_align(
input=roi_input,
rois=pool_rois,
pooled_height=cfg.roi_resolution,
pooled_width=cfg.roi_resolution,
spatial_scale=cfg.spatial_scale,
sampling_ratio=cfg.sampling_ratio)
rcnn_out = self.add_roi_box_head_func(pool)
self.res5_2_sum = self.add_roi_box_head_func(roi_input, pool_rois)
rcnn_out = fluid.layers.pool2d(
self.res5_2_sum, pool_type='avg', pool_size=7, name='res5_pool')
self.cls_score = fluid.layers.fc(input=rcnn_out,
size=cfg.class_num,
act=None,
......@@ -237,15 +276,110 @@ class FasterRCNN(object):
learning_rate=2.,
regularizer=L2Decay(0.)))
def SuffixNet(self, conv5):
mask_out = fluid.layers.conv2d_transpose(
input=conv5,
num_filters=cfg.dim_reduced,
filter_size=2,
stride=2,
act='relu',
param_attr=ParamAttr(
name='conv5_mask_w', initializer=MSRA(uniform=False)),
bias_attr=ParamAttr(
name='conv5_mask_b', learning_rate=2., regularizer=L2Decay(0.)))
act_func = None
if not self.is_train:
act_func = 'sigmoid'
mask_fcn_logits = fluid.layers.conv2d(
input=mask_out,
num_filters=cfg.class_num,
filter_size=1,
act=act_func,
param_attr=ParamAttr(
name='mask_fcn_logits_w', initializer=MSRA(uniform=False)),
bias_attr=ParamAttr(
name="mask_fcn_logits_b",
learning_rate=2.,
regularizer=L2Decay(0.)))
if not self.is_train:
mask_fcn_logits = fluid.layers.lod_reset(mask_fcn_logits,
self.pred_result)
return mask_fcn_logits
def mask_rcnn_heads(self, mask_input):
if self.is_train:
conv5 = fluid.layers.gather(self.res5_2_sum,
self.roi_has_mask_int32)
self.mask_fcn_logits = self.SuffixNet(conv5)
else:
im_scale = fluid.layers.slice(
self.im_info, [1], starts=[2], ends=[3])
im_scale_lod = fluid.layers.sequence_expand(im_scale, self.rpn_rois)
boxes = self.rpn_rois / im_scale_lod
cls_prob = fluid.layers.softmax(self.cls_score, use_cudnn=False)
bbox_pred_reshape = fluid.layers.reshape(self.bbox_pred,
(-1, cfg.class_num, 4))
decoded_box = fluid.layers.box_coder(
prior_box=boxes,
prior_box_var=cfg.bbox_reg_weights,
target_box=bbox_pred_reshape,
code_type='decode_center_size',
box_normalized=False,
axis=1)
cliped_box = fluid.layers.box_clip(
input=decoded_box, im_info=self.im_info)
self.pred_result = fluid.layers.multiclass_nms(
bboxes=cliped_box,
scores=cls_prob,
score_threshold=cfg.TEST.score_thresh,
nms_top_k=-1,
nms_threshold=cfg.TEST.nms_thresh,
keep_top_k=cfg.TEST.detections_per_im,
normalized=False)
pred_res_shape = fluid.layers.shape(self.pred_result)
shape = fluid.layers.reduce_prod(pred_res_shape)
shape = fluid.layers.reshape(shape, [1, 1])
ones = fluid.layers.fill_constant([1, 1], value=1, dtype='int32')
cond = fluid.layers.equal(x=shape, y=ones)
ie = fluid.layers.IfElse(cond)
with ie.true_block():
pred_res_null = ie.input(self.pred_result)
ie.output(pred_res_null)
with ie.false_block():
pred_res = ie.input(self.pred_result)
pred_boxes = fluid.layers.slice(
pred_res, [1], starts=[2], ends=[6])
im_scale_lod = fluid.layers.sequence_expand(im_scale,
pred_boxes)
mask_rois = pred_boxes * im_scale_lod
conv5 = self.add_roi_box_head_func(mask_input, mask_rois)
mask_fcn = self.SuffixNet(conv5)
ie.output(mask_fcn)
self.mask_fcn_logits = ie()[0]
def mask_rcnn_loss(self):
mask_label = fluid.layers.cast(x=self.mask_int32, dtype='float32')
reshape_dim = cfg.class_num * cfg.resolution * cfg.resolution
mask_fcn_logits_reshape = fluid.layers.reshape(self.mask_fcn_logits,
(-1, reshape_dim))
loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits(
x=mask_fcn_logits_reshape,
label=mask_label,
ignore_index=-1,
normalize=True)
loss_mask = fluid.layers.reduce_sum(loss_mask, name='loss_mask')
return loss_mask
def fast_rcnn_loss(self):
labels_int64 = fluid.layers.cast(x=self.labels_int32, dtype='int64')
labels_int64.stop_gradient = True
#loss_cls = fluid.layers.softmax_with_cross_entropy(
# logits=cls_score,
# label=labels_int64
# )
cls_prob = fluid.layers.softmax(self.cls_score, use_cudnn=False)
loss_cls = fluid.layers.cross_entropy(cls_prob, labels_int64)
loss_cls = fluid.layers.softmax_with_cross_entropy(
logits=self.cls_score,
label=labels_int64,
numeric_stable_mode=True, )
loss_cls = fluid.layers.reduce_mean(loss_cls)
loss_bbox = fluid.layers.smooth_l1(
x=self.bbox_pred,
......@@ -303,5 +437,4 @@ class FasterRCNN(object):
norm = fluid.layers.reduce_prod(score_shape)
norm.stop_gradient = True
rpn_reg_loss = rpn_reg_loss / norm
return rpn_cls_loss, rpn_reg_loss
......@@ -160,8 +160,22 @@ def add_ResNet50_conv4_body(body_input):
return res4
def add_ResNet_roi_conv5_head(head_input):
res5 = layer_warp(bottleneck, head_input, 512, 3, 2, name="res5")
res5_pool = fluid.layers.pool2d(
res5, pool_type='avg', pool_size=7, name='res5_pool')
return res5_pool
def add_ResNet_roi_conv5_head(head_input, rois):
if cfg.roi_func == 'RoIPool':
pool = fluid.layers.roi_pool(
input=head_input,
rois=rois,
pooled_height=cfg.roi_resolution,
pooled_width=cfg.roi_resolution,
spatial_scale=cfg.spatial_scale)
elif cfg.roi_func == 'RoIAlign':
pool = fluid.layers.roi_align(
input=head_input,
rois=rois,
pooled_height=cfg.roi_resolution,
pooled_width=cfg.roi_resolution,
spatial_scale=cfg.spatial_scale,
sampling_ratio=cfg.sampling_ratio)
res5 = layer_warp(bottleneck, pool, 512, 3, 2, name="res5")
return res5
......@@ -43,12 +43,9 @@ def train():
use_pyreader=cfg.use_pyreader,
use_random=False)
model.build_model(image_shape)
loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss = model.loss()
loss_cls.persistable = True
loss_bbox.persistable = True
rpn_cls_loss.persistable = True
rpn_reg_loss.persistable = True
loss = loss_cls + loss_bbox + rpn_cls_loss + rpn_reg_loss
losses, keys = model.loss()
loss = losses[0]
fetch_list = [loss]
boundaries = cfg.lr_steps
gamma = cfg.lr_gamma
......@@ -95,8 +92,6 @@ def train():
train_reader = reader.train(batch_size=total_batch_size, shuffle=False)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
fetch_list = [loss, loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss]
def run(iterations):
reader_time = []
run_time = []
......@@ -109,20 +104,16 @@ def train():
reader_time.append(end_time - start_time)
start_time = time.time()
if cfg.parallel:
losses = train_exe.run(fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data))
outs = train_exe.run(fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data))
else:
losses = exe.run(fluid.default_main_program(),
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data))
outs = exe.run(fluid.default_main_program(),
fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data))
end_time = time.time()
run_time.append(end_time - start_time)
total_images += len(data)
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
print("Batch {:d}, lr {:.6f}, loss {:.6f} ".format(batch_id, lr[0],
losses[0][0]))
print("Batch {:d}, loss {:.6f} ".format(batch_id, np.mean(outs[0])))
return reader_time, run_time, total_images
def run_pyreader(iterations):
......@@ -135,18 +126,16 @@ def train():
for batch_id in range(iterations):
start_time = time.time()
if cfg.parallel:
losses = train_exe.run(
outs = train_exe.run(
fetch_list=[v.name for v in fetch_list])
else:
losses = exe.run(fluid.default_main_program(),
fetch_list=[v.name for v in fetch_list])
outs = exe.run(fluid.default_main_program(),
fetch_list=[v.name for v in fetch_list])
end_time = time.time()
run_time.append(end_time - start_time)
total_images += devices_num
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
print("Batch {:d}, lr {:.6f}, loss {:.6f} ".format(batch_id, lr[
0], losses[0][0]))
print("Batch {:d}, loss {:.6f} ".format(batch_id,
np.mean(outs[0])))
except fluid.core.EOFException:
py_reader.reset()
......
......@@ -27,6 +27,46 @@ from collections import deque
from roidbs import JsonDataset
import data_utils
from config import cfg
import segm_utils
def roidb_reader(roidb, mode):
im, im_scales = data_utils.get_image_blob(roidb, mode)
im_id = roidb['id']
im_height = np.round(roidb['height'] * im_scales)
im_width = np.round(roidb['width'] * im_scales)
im_info = np.array([im_height, im_width, im_scales], dtype=np.float32)
if mode == 'test' or mode == 'infer':
return im, im_info, im_id
gt_boxes = roidb['gt_boxes'].astype('float32')
gt_classes = roidb['gt_classes'].astype('int32')
is_crowd = roidb['is_crowd'].astype('int32')
segms = roidb['segms']
outs = (im, gt_boxes, gt_classes, is_crowd, im_info, im_id)
if cfg.MASK_ON:
gt_masks = []
valid = True
segms = roidb['segms']
assert len(segms) == is_crowd.shape[0]
for i in range(len(roidb['segms'])):
segm, iscrowd = segms[i], is_crowd[i]
gt_segm = []
if iscrowd:
gt_segm.append([[0, 0]])
else:
for poly in segm:
if len(poly) == 0:
valid = False
break
gt_segm.append(np.array(poly).reshape(-1, 2))
if (not valid) or len(gt_segm) == 0:
break
gt_masks.append(gt_segm)
outs = outs + (gt_masks, )
return outs
def coco(mode,
......@@ -63,19 +103,6 @@ def coco(mode,
print("{} on {} with {} roidbs".format(mode, cfg.dataset, len(roidbs)))
def roidb_reader(roidb, mode):
im, im_scales = data_utils.get_image_blob(roidb, mode)
im_id = roidb['id']
im_height = np.round(roidb['height'] * im_scales)
im_width = np.round(roidb['width'] * im_scales)
im_info = np.array([im_height, im_width, im_scales], dtype=np.float32)
if mode == 'test' or mode == 'infer':
return im, im_info, im_id
gt_boxes = roidb['gt_boxes'].astype('float32')
gt_classes = roidb['gt_classes'].astype('int32')
is_crowd = roidb['is_crowd'].astype('int32')
return im, gt_boxes, gt_classes, is_crowd, im_info, im_id
def padding_minibatch(batch_data):
if len(batch_data) == 1:
return batch_data
......@@ -93,22 +120,31 @@ def coco(mode,
def reader():
if mode == "train":
roidb_perm = deque(np.random.permutation(roidbs))
if shuffle:
roidb_perm = deque(np.random.permutation(roidbs))
else:
roidb_perm = deque(roidbs)
roidb_cur = 0
count = 0
batch_out = []
while True:
roidb = roidb_perm[0]
roidb_cur += 1
roidb_perm.rotate(-1)
if roidb_cur >= len(roidbs):
roidb_perm = deque(np.random.permutation(roidbs))
if shuffle:
roidb_perm = deque(np.random.permutation(roidbs))
else:
roidb_perm = deque(roidbs)
roidb_cur = 0
im, gt_boxes, gt_classes, is_crowd, im_info, im_id = roidb_reader(
roidb, mode)
if gt_boxes.shape[0] == 0:
# im, gt_boxes, gt_classes, is_crowd, im_info, im_id, gt_masks
datas = roidb_reader(roidb, mode)
if datas[1].shape[0] == 0:
continue
batch_out.append(
(im, gt_boxes, gt_classes, is_crowd, im_info, im_id))
if cfg.MASK_ON:
if len(datas[-1]) != datas[1].shape[0]:
continue
batch_out.append(datas)
if not padding_total:
if len(batch_out) == batch_size:
yield padding_minibatch(batch_out)
......@@ -124,7 +160,9 @@ def coco(mode,
yield sub_batch_out
sub_batch_out = []
batch_out = []
count += 1
if count >= cfg.max_iter + 1:
return
elif mode == "test":
batch_out = []
for roidb in roidbs:
......
......@@ -36,6 +36,7 @@ import matplotlib
matplotlib.use('Agg')
from pycocotools.coco import COCO
import box_utils
import segm_utils
from config import cfg
logger = logging.getLogger(__name__)
......@@ -91,8 +92,9 @@ class JsonDataset(object):
end_time = time.time()
print('_add_gt_annotations took {:.3f}s'.format(end_time -
start_time))
print('Appending horizontally-flipped training examples...')
self._extend_with_flipped_entries(roidb)
if cfg.TRAIN.use_flipped:
print('Appending horizontally-flipped training examples...')
self._extend_with_flipped_entries(roidb)
print('Loaded dataset: {:s}'.format(self.name))
print('{:d} roidb entries'.format(len(roidb)))
if self.is_train:
......@@ -111,6 +113,7 @@ class JsonDataset(object):
entry['gt_classes'] = np.empty((0), dtype=np.int32)
entry['gt_id'] = np.empty((0), dtype=np.int32)
entry['is_crowd'] = np.empty((0), dtype=np.bool)
entry['segms'] = []
# Remove unwanted fields that come from the json file (if they exist)
for k in ['date_captured', 'url', 'license', 'file_name']:
if k in entry:
......@@ -126,9 +129,15 @@ class JsonDataset(object):
objs = self.COCO.loadAnns(ann_ids)
# Sanitize bboxes -- some are invalid
valid_objs = []
valid_segms = []
width = entry['width']
height = entry['height']
for obj in objs:
if isinstance(obj['segmentation'], list):
# Valid polygons have >= 3 points, so require >= 6 coordinates
obj['segmentation'] = [
p for p in obj['segmentation'] if len(p) >= 6
]
if obj['area'] < cfg.TRAIN.gt_min_area:
continue
if 'ignore' in obj and obj['ignore'] == 1:
......@@ -141,6 +150,8 @@ class JsonDataset(object):
if obj['area'] > 0 and x2 > x1 and y2 > y1:
obj['clean_bbox'] = [x1, y1, x2, y2]
valid_objs.append(obj)
valid_segms.append(obj['segmentation'])
num_valid_objs = len(valid_objs)
gt_boxes = np.zeros((num_valid_objs, 4), dtype=entry['gt_boxes'].dtype)
......@@ -158,6 +169,7 @@ class JsonDataset(object):
entry['gt_classes'] = np.append(entry['gt_classes'], gt_classes)
entry['gt_id'] = np.append(entry['gt_id'], gt_id)
entry['is_crowd'] = np.append(entry['is_crowd'], is_crowd)
entry['segms'].extend(valid_segms)
def _extend_with_flipped_entries(self, roidb):
"""Flip each entry in the given roidb and return a new roidb that is the
......@@ -175,11 +187,13 @@ class JsonDataset(object):
gt_boxes[:, 2] = width - oldx1 - 1
assert (gt_boxes[:, 2] >= gt_boxes[:, 0]).all()
flipped_entry = {}
dont_copy = ('gt_boxes', 'flipped')
dont_copy = ('gt_boxes', 'flipped', 'segms')
for k, v in entry.items():
if k not in dont_copy:
flipped_entry[k] = v
flipped_entry['gt_boxes'] = gt_boxes
flipped_entry['segms'] = segm_utils.flip_segms(
entry['segms'], entry['height'], entry['width'])
flipped_entry['flipped'] = True
flipped_roidb.append(flipped_entry)
roidb.extend(flipped_roidb)
......
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
model=$1 # faster_rcnn, mask_rcnn
if [ "$model" = "faster_rcnn" ]; then
mask_on="--MASK_ON False"
elif [ "$model" = "mask_rcnn" ]; then
mask_on="--MASK_ON True"
else
echo "Invalid model provided. Please use one of {faster_rcnn, mask_rcnn}"
exit 1
fi
python -u ../eval_coco_map.py \
$mask_on \
--pretrained_model=../output/model_iter179999 \
--data_dir=../dataset/coco/ \
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
model=$1 # faster_rcnn, mask_rcnn
if [ "$model" = "faster_rcnn" ]; then
mask_on="--MASK_ON False"
elif [ "$model" = "mask_rcnn" ]; then
mask_on="--MASK_ON True"
else
echo "Invalid model provided. Please use one of {faster_rcnn, mask_rcnn}"
exit 1
fi
python -u ../infer.py \
$mask_on \
--pretrained_model=../output/model_iter179999 \
--image_path=../dataset/coco/val2017/ \
--image_name=000000000139.jpg \
--draw_threshold=0.6
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
model=$1 # faster_rcnn, mask_rcnn
if [ "$model" = "faster_rcnn" ]; then
mask_on="--MASK_ON False"
elif [ "$model" = "mask_rcnn" ]; then
mask_on="--MASK_ON True"
else
echo "Invalid model provided. Please use one of {faster_rcnn, mask_rcnn}"
exit 1
fi
python -u ../train.py \
$mask_on \
--model_save_dir=../output/ \
--pretrained_model=../imagenet_resnet50_fusebn/ \
--data_dir=../dataset/coco/ \
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://w_idxw.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on:
# --------------------------------------------------------
# Detectron
# Copyright (c) 2017-present, Facebook, Inc.
# Licensed under the Apache License, Version 2.0;
# Written by Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import pycocotools.mask as mask_util
import cv2
def is_poly(segm):
"""Determine if segm is a polygon. Valid segm expected (polygon or RLE)."""
assert isinstance(segm, (list, dict)), \
'Invalid segm type: {}'.format(type(segm))
return isinstance(segm, list)
def segms_to_rle(segms, height, width):
rle = segms
if isinstance(segms, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = mask_util.frPyObjects(segms, height, width)
rle = mask_util.merge(rles)
elif isinstance(segms['counts'], list):
# uncompressed RLE
rle = mask_util.frPyObjects(segms, height, width)
return rle
def segms_to_mask(segms, iscrowd, height, width):
print('segms: ', segms)
if iscrowd:
return [[0 for i in range(width)] for j in range(height)]
rle = segms_to_rle(segms, height, width)
mask = mask_util.decode(rle)
return mask
def flip_segms(segms, height, width):
"""Left/right flip each mask in a list of masks."""
def _flip_poly(poly, width):
flipped_poly = np.array(poly)
flipped_poly[0::2] = width - np.array(poly[0::2]) - 1
return flipped_poly.tolist()
def _flip_rle(rle, height, width):
if 'counts' in rle and type(rle['counts']) == list:
# Magic RLE format handling painfully discovered by looking at the
# COCO API showAnns function.
rle = mask_util.frPyObjects([rle], height, width)
mask = mask_util.decode(rle)
mask = mask[:, ::-1, :]
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
flipped_segms = []
for segm in segms:
if is_poly(segm):
# Polygon format
flipped_segms.append([_flip_poly(poly, width) for poly in segm])
else:
# RLE format
flipped_segms.append(_flip_rle(segm, height, width))
return flipped_segms
......@@ -20,7 +20,8 @@ import sys
import numpy as np
import time
import shutil
from utility import parse_args, print_arguments, SmoothedValue
from utility import parse_args, print_arguments, SmoothedValue, TrainingStats, now_time
import collections
import paddle
import paddle.fluid as fluid
......@@ -55,30 +56,30 @@ def train():
use_pyreader=cfg.use_pyreader,
use_random=use_random)
model.build_model(image_shape)
loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss = model.loss()
loss_cls.persistable = True
loss_bbox.persistable = True
rpn_cls_loss.persistable = True
rpn_reg_loss.persistable = True
loss = loss_cls + loss_bbox + rpn_cls_loss + rpn_reg_loss
losses, keys = model.loss()
loss = losses[0]
fetch_list = losses
boundaries = cfg.lr_steps
gamma = cfg.lr_gamma
step_num = len(cfg.lr_steps)
values = [learning_rate * (gamma**i) for i in range(step_num + 1)]
lr = exponential_with_warmup_decay(
learning_rate=learning_rate,
boundaries=boundaries,
values=values,
warmup_iter=cfg.warm_up_iter,
warmup_factor=cfg.warm_up_factor)
optimizer = fluid.optimizer.Momentum(
learning_rate=exponential_with_warmup_decay(
learning_rate=learning_rate,
boundaries=boundaries,
values=values,
warmup_iter=cfg.warm_up_iter,
warmup_factor=cfg.warm_up_factor),
learning_rate=lr,
regularization=fluid.regularizer.L2Decay(cfg.weight_decay),
momentum=cfg.momentum)
optimizer.minimize(loss)
fetch_list = fetch_list + [lr]
fluid.memory_optimize(fluid.default_main_program())
fluid.memory_optimize(
fluid.default_main_program(), skip_opt_set=set(fetch_list))
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -107,7 +108,8 @@ def train():
py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader)
else:
train_reader = reader.train(batch_size=total_batch_size, shuffle=shuffle)
train_reader = reader.train(
batch_size=total_batch_size, shuffle=shuffle)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
def save_model(postfix):
......@@ -116,88 +118,72 @@ def train():
shutil.rmtree(model_path)
fluid.io.save_persistables(exe, model_path)
fetch_list = [loss, rpn_cls_loss, rpn_reg_loss, loss_cls, loss_bbox]
def train_loop_pyreader():
py_reader.start()
smoothed_loss = SmoothedValue(cfg.log_window)
train_stats = TrainingStats(cfg.log_window, keys)
try:
start_time = time.time()
prev_start_time = start_time
total_time = 0
last_loss = 0
every_pass_loss = []
for iter_id in range(cfg.max_iter):
prev_start_time = start_time
start_time = time.time()
losses = train_exe.run(fetch_list=[v.name for v in fetch_list])
every_pass_loss.append(np.mean(np.array(losses[0])))
smoothed_loss.add_value(np.mean(np.array(losses[0])))
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
print("Iter {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format(
iter_id, lr[0],
smoothed_loss.get_median_value(
), start_time - prev_start_time))
end_time = time.time()
total_time += end_time - start_time
last_loss = np.mean(np.array(losses[0]))
outs = train_exe.run(fetch_list=[v.name for v in fetch_list])
stats = {k: np.array(v).mean() for k, v in zip(keys, outs[:-1])}
train_stats.update(stats)
logs = train_stats.log()
strs = '{}, lr: {:.5f}, {}, time: {:.3f}'.format(
now_time(),
np.mean(outs[-1]), logs, start_time - prev_start_time)
print(strs)
sys.stdout.flush()
if (iter_id + 1) % cfg.TRAIN.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id))
# only for ce
end_time = time.time()
total_time = end_time - start_time
last_loss = np.array(outs[0]).mean()
if cfg.enable_ce:
gpu_num = devices_num
epoch_idx = iter_id + 1
loss = last_loss
print("kpis\teach_pass_duration_card%s\t%s" %
(gpu_num, total_time / epoch_idx))
print("kpis\ttrain_loss_card%s\t%s" %
(gpu_num, loss))
except fluid.core.EOFException:
(gpu_num, total_time / epoch_idx))
print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, loss))
except (StopIteration, fluid.core.EOFException):
py_reader.reset()
return np.mean(every_pass_loss)
def train_loop():
start_time = time.time()
prev_start_time = start_time
start = start_time
total_time = 0
last_loss = 0
every_pass_loss = []
smoothed_loss = SmoothedValue(cfg.log_window)
train_stats = TrainingStats(cfg.log_window, keys)
for iter_id, data in enumerate(train_reader()):
prev_start_time = start_time
start_time = time.time()
losses = train_exe.run(fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data))
loss_v = np.mean(np.array(losses[0]))
every_pass_loss.append(loss_v)
smoothed_loss.add_value(loss_v)
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
end_time = time.time()
total_time += end_time - start_time
last_loss = loss_v
print("Iter {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format(
iter_id, lr[0],
smoothed_loss.get_median_value(), start_time - prev_start_time))
outs = train_exe.run(fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data))
stats = {k: np.array(v).mean() for k, v in zip(keys, outs[:-1])}
train_stats.update(stats)
logs = train_stats.log()
strs = '{}, lr: {:.5f}, {}, time: {:.3f}'.format(
now_time(),
np.mean(outs[-1]), logs, start_time - prev_start_time)
print(strs)
sys.stdout.flush()
if (iter_id + 1) % cfg.TRAIN.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id))
if (iter_id + 1) == cfg.max_iter:
break
end_time = time.time()
total_time = end_time - start_time
last_loss = np.array(outs[0]).mean()
# only for ce
if cfg.enable_ce:
gpu_num = devices_num
epoch_idx = iter_id + 1
loss = last_loss
print("kpis\teach_pass_duration_card%s\t%s" %
(gpu_num, total_time / epoch_idx))
print("kpis\ttrain_loss_card%s\t%s" %
(gpu_num, loss))
(gpu_num, total_time / epoch_idx))
print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, loss))
return np.mean(every_pass_loss)
......
......@@ -22,7 +22,9 @@ import sys
import distutils.util
import numpy as np
import six
import collections
from collections import deque
import datetime
from paddle.fluid import core
import argparse
import functools
......@@ -85,6 +87,37 @@ class SmoothedValue(object):
return np.median(self.deque)
def now_time():
return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
class TrainingStats(object):
def __init__(self, window_size, stats_keys):
self.smoothed_losses_and_metrics = {
key: SmoothedValue(window_size)
for key in stats_keys
}
def update(self, stats):
for k, v in self.smoothed_losses_and_metrics.items():
v.add_value(stats[k])
def get(self, extras=None):
stats = collections.OrderedDict()
if extras:
for k, v in extras.items():
stats[k] = v
for k, v in self.smoothed_losses_and_metrics.items():
stats[k] = round(v.get_median_value(), 3)
return stats
def log(self, extras=None):
d = self.get(extras)
strs = ', '.join(str(dict({x: y})).strip('{}') for x, y in d.items())
return strs
def parse_args():
"""return all args
"""
......@@ -108,7 +141,7 @@ def parse_args():
add_arg('learning_rate', float, 0.01, "Learning rate.")
add_arg('max_iter', int, 180000, "Iter number.")
add_arg('log_window', int, 20, "Log smooth window, set 1 for debug, set 20 for train.")
# FAST RCNN
# RCNN
# RPN
add_arg('anchor_sizes', int, [32,64,128,256,512], "The size of anchors.")
add_arg('aspect_ratios', float, [0.5,1.0,2.0], "The ratio of anchors.")
......@@ -116,6 +149,7 @@ def parse_args():
add_arg('rpn_stride', float, [16.,16.], "Stride of the feature map that RPN is attached.")
add_arg('rpn_nms_thresh', float, 0.7, "NMS threshold used on RPN proposals")
# TRAIN TEST INFER
add_arg('MASK_ON', bool, False, "Option for different models. If False, choose faster_rcnn. If True, choose mask_rcnn")
add_arg('im_per_batch', int, 1, "Minibatch size.")
add_arg('max_size', int, 1333, "The resized image height.")
add_arg('scales', int, [800], "The resized image height.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册