diff --git a/fluid/PaddleCV/faster_rcnn/README.md b/fluid/PaddleCV/faster_rcnn/README.md
index 0a5f68c34adda54ba0e27f44f16c18cafe057830..76b82180b201a92b86e8dadef5dd0c6084d67664 100644
--- a/fluid/PaddleCV/faster_rcnn/README.md
+++ b/fluid/PaddleCV/faster_rcnn/README.md
@@ -1,4 +1,4 @@
-# Faster RCNN Objective Detection
+# RCNN Object Detection
---
## Table of Contents
@@ -17,17 +17,20 @@ Running sample code in this directory requires PaddelPaddle Fluid v.1.0.0 and la
## Introduction
-[Faster Rcnn](https://arxiv.org/abs/1506.01497) is a typical two stage detector. The total framework of network can be divided into four parts, as shown below:
-[figure: Faster RCNN model (image/Faster_RCNN.jpg)]
+Region Convolutional Neural Network (RCNN) models are two-stage detectors: they first generate region proposals and extract features, then classify each proposal and refine its location.
+The RCNN series currently contains two typical models: Faster RCNN and Mask RCNN.
+
+The network of [Faster RCNN](https://arxiv.org/abs/1506.01497) can be divided into four parts:
1. Base conv layer. As a CNN object detector, Faster RCNN extracts feature maps using a basic convolutional network. The feature maps can then be shared by the RPN and the fc layers. This sample uses [ResNet-50](https://arxiv.org/abs/1512.03385) as its base conv layer.
2. Region Proposal Network (RPN). The RPN generates proposals for detection. This block generates anchors from a set of sizes and ratios and classifies anchors into fore-ground and back-ground by softmax. The anchors are then refined by box regression to obtain more precise proposals.
3. RoI Align. This layer takes feature maps and proposals as input. The proposals are mapped to the feature maps and pooled to the same size. The outputs are sent to fc layers for classification and regression. Either RoIPool or RoIAlign can be used in this layer; the choice is set by roi\_func in config.py.
4. Detection layer. Using the output of roi pooling, two fc layers compute the class and location of each proposal.
+[Mask RCNN](https://arxiv.org/abs/1703.06870) is a classical instance segmentation model and an extension of Faster RCNN.
+
+Mask RCNN is a two-stage model as well. The first stage generates proposals from the input image; the second stage outputs the class result, the bbox, and a mask produced by a segmentation branch added on top of the original Faster RCNN model. This branch decouples mask prediction from classification.
+
## Data preparation
Train the model on [MS-COCO dataset](http://cocodataset.org/#download), download dataset as below:
@@ -64,10 +67,12 @@ After data preparation, one can start the training step by:
    python train.py \
       --model_save_dir=output/ \
-       --pretrained_model=${path_to_pretrain_model}
-       --data_dir=${path_to_data}
+       --pretrained_model=${path_to_pretrain_model} \
+       --data_dir=${path_to_data} \
+       --MASK_ON=False
- Set ```export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7``` to specify 8 GPUs for training.
+- Set ```MASK_ON``` to choose between the Faster RCNN and Mask RCNN models.
- For more help on arguments:
    python train.py --help
@@ -93,7 +98,6 @@ After data preparation, one can start the training step by:
* In the first 500 iterations, the learning rate increases linearly from 0.00333 to 0.01. The lr is then decayed at iterations 120000 and 160000 with multipliers 0.1 and 0.01. The maximum iteration is 180000. We also released a 2x model, which is trained for 360000 iterations with the lr decayed at 240000 and 320000. These configurations can be set through max_iter and lr_steps in config.py.
* In non-basic convolutional layers, the learning rate of the bias is set to twice the global lr.
* In the basic convolutional layers, the parameters of the affine layers and the res body are not updated.
-* Use Nvidia Tesla V100 8GPU, total time for training is about 40 hours.
## Evaluation
@@ -109,6 +113,8 @@ Evaluation is to evaluate the performance of a trained model.
This sample provides `eval_coco_map.py`, which evaluates the trained model with the COCO API.
Evaluation results are shown below:
+Faster RCNN:
+
| Model | RoI function | Batch size | Max iteration | mAP |
| :--------------- | :--------: | :------------: | :------------------: |------: |
| [Fluid RoIPool minibatch padding](http://paddlemodels.bj.bcebos.com/faster_rcnn/model_pool_minibatch_padding.tar.gz) | RoIPool | 8 | 180000 | 0.316 |
@@ -121,6 +127,14 @@ Evaluation results are shown below:
* Fluid RoIAlign no padding: Images without padding.
* Fluid RoIAlign no padding 2x: Images without padding, trained for 360000 iterations, learning rate is decayed at 240000, 320000.
+Mask RCNN:
+
+| Model | Batch size | Max iteration | box mAP | mask mAP |
+| :--------------- | :--------: | :------------: | :--------: |------: |
+| [Fluid mask no padding](https://paddlemodels.bj.bcebos.com/faster_rcnn/Fluid_mask_no_padding.tar.gz) | 8 | 180000 | 0.359 | 0.314 |
+
+* Fluid mask no padding: Uses RoIAlign. Images without padding.
+
## Inference and Visualization
Inference is used to get prediction scores or image features based on trained models. `infer.py` is the main executor for inference; one can start the infer step by:
@@ -135,8 +149,12 @@ Inference is used to get prediction score or image features based on trained mod
Visualization of the infer results is shown below:

[figure: Faster RCNN Visualization Examples]
+[figure: Mask RCNN Visualization Examples (image/000000000139_mask.jpg, image/000000127517_mask.jpg)]
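The warm-up plus step-decay schedule described in the training notes above (built in train.py via `exponential_with_warmup_decay`) can be sketched in plain Python. This is a minimal illustration under the documented defaults; the function name `lr_at` is ours, not part of the repo:

```python
def lr_at(it, base_lr=0.01, warm_up_iter=500, warm_up_factor=1.0 / 3.0,
          lr_steps=(120000, 160000), lr_gamma=0.1):
    """Piecewise-decay learning rate with linear warm-up, as the README describes."""
    if it < warm_up_iter:
        # Linearly increase from base_lr * warm_up_factor (~0.00333) to base_lr (0.01).
        alpha = it / float(warm_up_iter)
        return base_lr * (warm_up_factor * (1 - alpha) + alpha)
    lr = base_lr
    for step in lr_steps:
        if it >= step:
            lr *= lr_gamma  # multiply by 0.1 at 120000, and again at 160000
    return lr

assert abs(lr_at(0) - 0.00333) < 1e-3        # start of warm-up
assert lr_at(1000) == 0.01                   # after warm-up
assert abs(lr_at(170000) - 0.0001) < 1e-12   # 0.01 * 0.1 * 0.1
```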
diff --git a/fluid/PaddleCV/faster_rcnn/README_cn.md b/fluid/PaddleCV/faster_rcnn/README_cn.md
index 29adfcfd274b82f2ddaba1894be6ad1c7ece1e6a..00bb7e3378b2628923eeea4f5e9fcb870c18a266 100644
--- a/fluid/PaddleCV/faster_rcnn/README_cn.md
+++ b/fluid/PaddleCV/faster_rcnn/README_cn.md
@@ -1,4 +1,4 @@
-# Faster RCNN 目标检测
+# RCNN 系列目标检测
---
## 内容
@@ -16,18 +16,21 @@ 在当前目录下运行样例代码需要PaddlePaddle Fluid的v.1.0.0或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据[安装文档](http://www.paddlepaddle.org/documentation/docs/zh/0.15.0/beginners_guide/install/install_doc.html#paddlepaddle)中的说明来更新PaddlePaddle。
## 简介
+区域卷积神经网络(RCNN)系列模型为两阶段目标检测器。通过对图像生成候选区域,提取特征,判别特征类别并修正候选框位置。
+RCNN系列目前包含两个代表模型:Faster RCNN和Mask RCNN。
-[Faster Rcnn](https://arxiv.org/abs/1506.01497) 是典型的两阶段目标检测器。如下图所示,整体网络可以分为4个主要内容:
-[图:Faster RCNN 目标检测模型(image/Faster_RCNN.jpg)]
+[Faster RCNN](https://arxiv.org/abs/1506.01497) 整体网络可以分为4个主要内容:
1. 基础卷积层。作为一种卷积神经网络目标检测方法,Faster RCNN首先使用一组基础的卷积网络提取图像的特征图。特征图被后续RPN层和全连接层共享。本示例采用[ResNet-50](https://arxiv.org/abs/1512.03385)作为基础卷积层。
2. 区域生成网络(RPN)。RPN网络用于生成候选区域(proposals)。该层通过一组固定的尺寸和比例得到一组锚点(anchors), 通过softmax判断锚点属于前景或者背景,再利用区域回归修正锚点从而获得精确的候选区域。
3. RoI Align。该层收集输入的特征图和候选区域,将候选区域映射到特征图中并池化为统一大小的区域特征图,送入全连接层判定目标类别, 该层可选用RoIPool和RoIAlign两种方式,在config.py中设置roi\_func。
4. 检测层。利用区域特征图计算候选区域的类别,同时再次通过区域回归获得检测框最终的精确位置。
+[Mask RCNN](https://arxiv.org/abs/1703.06870) 扩展自Faster RCNN,是经典的实例分割模型。
+
+Mask RCNN同样为两阶段框架:第一阶段扫描图像生成候选框;第二阶段根据候选框得到分类结果、边界框,同时在原有Faster RCNN模型基础上添加分割分支,得到掩码结果,实现了掩码和类别预测关系的解耦。
+
+
## 数据准备
在[MS-COCO数据集](http://cocodataset.org/#download)上进行训练,通过如下方式下载数据集。
@@ -63,10 +66,12 @@ Faster RCNN 目标检测模型
    python train.py \
       --model_save_dir=output/ \
-       --pretrained_model=${path_to_pretrain_model}
-       --data_dir=${path_to_data}
+       --pretrained_model=${path_to_pretrain_model} \
+       --data_dir=${path_to_data} \
+       --MASK_ON=False
- 通过设置export CUDA\_VISIBLE\_DEVICES=0,1,2,3,4,5,6,7指定8卡GPU训练。
+- 通过设置MASK\_ON选择Faster RCNN和Mask RCNN模型。
- 可选参数见:
    python train.py --help
@@ -83,11 +88,10 @@ Faster RCNN 目标检测模型
**训练策略:**
-* 采用momentum优化算法训练Faster RCNN,momentum=0.9。
+* 采用momentum优化算法训练,momentum=0.9。
* 权重衰减系数为0.0001,前500轮学习率从0.00333线性增加至0.01。在120000,160000轮时使用0.1,0.01乘子进行学习率衰减,最大训练180000轮。同时我们也提供了2x模型,该模型采用更多的迭代轮数进行训练,训练360000轮,学习率在240000,320000轮衰减,其他参数不变,训练最大轮数和学习率策略可以在config.py中对max_iter和lr_steps进行设置。
* 非基础卷积层卷积bias学习率为整体学习率2倍。
* 基础卷积层中,affine_layers参数不更新,res2层参数不更新。
-* 使用Nvidia Tesla V100 8卡并行,总共训练时长大约40小时。
## 模型评估
@@ -103,6 +107,8 @@ Faster RCNN 目标检测模型
下表为模型评估结果:
+Faster RCNN
+
| 模型 | RoI处理方式 | 批量大小 | 迭代次数 | mAP |
| :--------------- | :--------: | :------------: | :------------------: |------: |
| [Fluid RoIPool minibatch padding](http://paddlemodels.bj.bcebos.com/faster_rcnn/model_pool_minibatch_padding.tar.gz) | RoIPool | 8 | 180000 | 0.316 |
@@ -117,6 +123,13 @@ Faster RCNN 目标检测模型
* Fluid RoIAlign no padding: 使用RoIAlign,不对图像做填充处理。
* Fluid RoIAlign no padding 2x: 使用RoIAlign,不对图像做填充处理。训练360000轮,学习率在240000,320000轮衰减。
+Mask RCNN
+
+| 模型 | 批量大小 | 迭代次数 | box mAP | mask mAP |
+| :--------------- | :--------: | :------------: | :--------: |------: |
+| [Fluid mask no padding](https://paddlemodels.bj.bcebos.com/faster_rcnn/Fluid_mask_no_padding.tar.gz) | 8 | 180000 | 0.359 | 0.314 |
+
+* Fluid mask no padding: 使用RoIAlign,不对图像做填充处理。
+
## 模型推断及可视化
模型推断可以获取图像中的物体及其对应的类别,`infer.py`是主要执行程序,调用示例如下:
@@ -131,8 +144,12 @@ Faster RCNN 目标检测模型
下图为模型可视化预测结果:

[图:Faster RCNN 预测可视化]
+[图:Mask RCNN 预测可视化(image/000000000139_mask.jpg、image/000000127517_mask.jpg)]
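Both READMEs describe evaluation as computing COCO mAP over the detection results (and, with MASK_ON, over the segmentation results). Stripped of the model plumbing, the flow that `eval_coco_map.py` in this diff implements reduces to the pycocotools calls below; the annotation path is illustrative, while the result file names are the ones the script writes:

```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# Ground truth: COCO-format annotations (path is an assumption for illustration).
coco_gt = COCO('dataset/coco/annotations/instances_val2017.json')

# Detections: a list of {'image_id', 'category_id', 'bbox' [x, y, w, h], 'score'}
# dicts, as dumped by eval_coco_map.py into detection_bbox_result.json.
coco_dt = coco_gt.loadRes('detection_bbox_result.json')

eval_bbox = COCOeval(coco_gt, coco_dt, 'bbox')
eval_bbox.evaluate()
eval_bbox.accumulate()
eval_bbox.summarize()  # prints the box mAP table

# With MASK_ON, the same three calls are repeated with iouType 'segm' on the
# RLE-encoded masks in detection_segms_result.json to get the mask mAP.
```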
diff --git a/fluid/PaddleCV/faster_rcnn/box_utils.py b/fluid/PaddleCV/faster_rcnn/box_utils.py
index 64d7d96948b856f4ae5c28594e9fb19a3a18480e..bb3fe9c8f0cb261004578abba651ad7210518a22 100644
--- a/fluid/PaddleCV/faster_rcnn/box_utils.py
+++ b/fluid/PaddleCV/faster_rcnn/box_utils.py
@@ -69,6 +69,7 @@ def clip_xyxy_to_image(x1, y1, x2, y2, height, width):
     y2 = np.minimum(height - 1., np.maximum(0., y2))
     return x1, y1, x2, y2

+
 def nms(dets, thresh):
     """Apply classic DPM-style greedy NMS."""
     if dets.shape[0] == 0:
@@ -123,3 +124,21 @@ def nms(dets, thresh):

     return np.where(suppressed == 0)[0]
+
+
+def expand_boxes(boxes, scale):
+    """Expand an array of boxes by a given scale."""
+    w_half = (boxes[:, 2] - boxes[:, 0]) * .5
+    h_half = (boxes[:, 3] - boxes[:, 1]) * .5
+    x_c = (boxes[:, 2] + boxes[:, 0]) * .5
+    y_c = (boxes[:, 3] + boxes[:, 1]) * .5
+
+    w_half *= scale
+    h_half *= scale
+
+    boxes_exp = np.zeros(boxes.shape)
+    boxes_exp[:, 0] = x_c - w_half
+    boxes_exp[:, 2] = x_c + w_half
+    boxes_exp[:, 1] = y_c - h_half
+    boxes_exp[:, 3] = y_c + h_half
+
+    return boxes_exp
diff --git a/fluid/PaddleCV/faster_rcnn/colormap.py b/fluid/PaddleCV/faster_rcnn/colormap.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c2447794fc2e9841b30c2cdf11e8fc70d20d764
--- /dev/null
+++ b/fluid/PaddleCV/faster_rcnn/colormap.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+#
+# Based on:
+# --------------------------------------------------------
+# Detectron
+# Copyright (c) 2017-present, Facebook, Inc.
+# Licensed under the Apache License, Version 2.0; +# Written by Ross Girshick +# -------------------------------------------------------- + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np + + +def colormap(rgb=False): + color_list = np.array([ + 0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494, + 0.184, 0.556, 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078, + 0.184, 0.300, 0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000, + 0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, 0.333, 0.667, + 0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000, + 1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000, + 0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500, 0.667, + 0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333, + 0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333, + 0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000, + 1.000, 0.667, 0.333, 1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, 1.000, 0.333, 1.000, 1.000, 0.667, 1.000, 0.167, + 0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.286, + 0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, 0.714, 0.714, + 0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000 + ]).astype(np.float32) + color_list = color_list.reshape((-1, 3)) * 255 + if not rgb: + color_list = color_list[:, ::-1] + return color_list diff --git a/fluid/PaddleCV/faster_rcnn/config.py b/fluid/PaddleCV/faster_rcnn/config.py index 44b35f7509eeb1adf316e3e725aef8a729bf6499..f535ad705b7578f3c27a7ac8837baf9d791a7612 100644 --- a/fluid/PaddleCV/faster_rcnn/config.py +++ b/fluid/PaddleCV/faster_rcnn/config.py @@ -90,6 +90,9 @@ _C.TRAIN.freeze_at = 2 # min area of ground truth box _C.TRAIN.gt_min_area = -1 +# Use horizontally-flipped images during training? +_C.TRAIN.use_flipped = True + # # Inference options # @@ -120,7 +123,7 @@ _C.TEST.rpn_post_nms_top_n = 1000 _C.TEST.rpn_min_size = 0.0 # max number of detections -_C.TEST.detectiions_per_im = 100 +_C.TEST.detections_per_im = 100 # NMS threshold used on RPN proposals _C.TEST.rpn_nms_thresh = 0.7 @@ -129,6 +132,9 @@ _C.TEST.rpn_nms_thresh = 0.7 # Model options # +# Whether use mask rcnn head +_C.MASK_ON = True + # weight for bbox regression targets _C.bbox_reg_weights = [0.1, 0.1, 0.2, 0.2] @@ -156,6 +162,15 @@ _C.roi_resolution = 14 # spatial scale _C.spatial_scale = 1. / 16. 
+# resolution to represent mask labels +_C.resolution = 14 + +# Number of channels in the mask head +_C.dim_reduced = 256 + +# Threshold for converting soft masks to hard masks +_C.mrcnn_thresh_binarize = 0.5 + # # SOLVER options # diff --git a/fluid/PaddleCV/faster_rcnn/eval_coco_map.py b/fluid/PaddleCV/faster_rcnn/eval_coco_map.py index f8c755a3d0f880a47791f1c43aa161cfa0e5ff98..b9a18b5b3d211172995ef7bbf2a1e4ec96c76bd4 100644 --- a/fluid/PaddleCV/faster_rcnn/eval_coco_map.py +++ b/fluid/PaddleCV/faster_rcnn/eval_coco_map.py @@ -18,8 +18,7 @@ from __future__ import print_function import os import time import numpy as np -from eval_helper import get_nmsed_box -from eval_helper import get_dt_res +from eval_helper import * import paddle import paddle.fluid as fluid import reader @@ -44,7 +43,7 @@ def eval(): devices_num = len(devices.split(",")) total_batch_size = devices_num * cfg.TRAIN.im_per_batch cocoGt = COCO(os.path.join(cfg.data_dir, test_list)) - numId_to_catId_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())} + num_id_to_cat_id_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())} category_ids = cocoGt.getCatIds() label_list = { item['id']: item['name'] @@ -58,45 +57,76 @@ def eval(): use_pyreader=False, is_train=False) model.build_model(image_shape) - rpn_rois, confs, locs = model.eval_out() + rpn_rois, confs, locs = model.eval_bbox_out() + pred_boxes = model.eval() + if cfg.MASK_ON: + masks = model.eval_mask_out() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) # yapf: disable if cfg.pretrained_model: def if_exist(var): return os.path.exists(os.path.join(cfg.pretrained_model, var.name)) fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist) + # yapf: enable test_reader = reader.test(total_batch_size) feeder = fluid.DataFeeder(place=place, feed_list=model.feeds()) dts_res = [] - fetch_list = [rpn_rois, confs, locs] + segms_res = [] + if cfg.MASK_ON: + fetch_list = [rpn_rois, confs, locs, pred_boxes, masks] + else: + fetch_list = [rpn_rois, confs, locs] for batch_id, batch_data in enumerate(test_reader()): start = time.time() im_info = [] for data in batch_data: im_info.append(data[1]) - rpn_rois_v, confs_v, locs_v = exe.run( - fetch_list=[v.name for v in fetch_list], - feed=feeder.feed(batch_data), - return_numpy=False) - new_lod, nmsed_out = get_nmsed_box(rpn_rois_v, confs_v, locs_v, - class_nums, im_info, - numId_to_catId_map) + result = exe.run(fetch_list=[v.name for v in fetch_list], + feed=feeder.feed(batch_data), + return_numpy=False) + + rpn_rois_v = result[0] + confs_v = result[1] + locs_v = result[2] + if cfg.MASK_ON: + pred_boxes_v = result[3] + masks_v = result[4] + + new_lod = pred_boxes_v.lod() + nmsed_out = pred_boxes_v - dts_res += get_dt_res(total_batch_size, new_lod, nmsed_out, batch_data) + dts_res += get_dt_res(total_batch_size, new_lod[0], nmsed_out, + batch_data, num_id_to_cat_id_map) + + if cfg.MASK_ON and np.array(masks_v).shape != (1, 1): + segms_out = segm_results(nmsed_out, masks_v, im_info) + segms_res += get_segms_res(total_batch_size, new_lod[0], segms_out, + batch_data, num_id_to_cat_id_map) end = time.time() print('batch id: {}, time: {}'.format(batch_id, end - start)) - with open("detection_result.json", 'w') as outfile: + with open("detection_bbox_result.json", 'w') as outfile: json.dump(dts_res, outfile) - print("start evaluate using coco api") - cocoDt = cocoGt.loadRes("detection_result.json") + print("start evaluate bbox using coco api") 
+ cocoDt = cocoGt.loadRes("detection_bbox_result.json") cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') cocoEval.evaluate() cocoEval.accumulate() cocoEval.summarize() + if cfg.MASK_ON: + with open("detection_segms_result.json", 'w') as outfile: + json.dump(segms_res, outfile) + print("start evaluate mask using coco api") + cocoDt = cocoGt.loadRes("detection_segms_result.json") + cocoEval = COCOeval(cocoGt, cocoDt, 'segm') + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + if __name__ == '__main__': args = parse_args() diff --git a/fluid/PaddleCV/faster_rcnn/eval_helper.py b/fluid/PaddleCV/faster_rcnn/eval_helper.py index 852b52955915bf268f930ce3b0fa35de5734b1ea..5e1b970900527afbbba7b964b69850bdb4eaf4e0 100644 --- a/fluid/PaddleCV/faster_rcnn/eval_helper.py +++ b/fluid/PaddleCV/faster_rcnn/eval_helper.py @@ -21,6 +21,10 @@ from PIL import Image from PIL import ImageDraw from PIL import ImageFont from config import cfg +import pycocotools.mask as mask_util +import six +from colormap import colormap +import cv2 def box_decoder(deltas, boxes, weights): @@ -80,8 +84,7 @@ def clip_tiled_boxes(boxes, im_shape): return boxes -def get_nmsed_box(rpn_rois, confs, locs, class_nums, im_info, - numId_to_catId_map): +def get_nmsed_box(rpn_rois, confs, locs, class_nums, im_info): lod = rpn_rois.lod()[0] rpn_rois_v = np.array(rpn_rois) variance_v = np.array(cfg.bbox_reg_weights) @@ -106,38 +109,41 @@ def get_nmsed_box(rpn_rois, confs, locs, class_nums, im_info, inds = np.where(scores_n[:, j] > cfg.TEST.score_thresh)[0] scores_j = scores_n[inds, j] rois_j = rois_n[inds, j * 4:(j + 1) * 4] - dets_j = np.hstack((rois_j, scores_j[:, np.newaxis])).astype( + dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype( np.float32, copy=False) keep = box_utils.nms(dets_j, cfg.TEST.nms_thresh) nms_dets = dets_j[keep, :] #add labels - cat_id = numId_to_catId_map[j] - label = np.array([cat_id for _ in range(len(keep))]) + label = np.array([j for _ in range(len(keep))]) nms_dets = np.hstack((nms_dets, label[:, np.newaxis])).astype( np.float32, copy=False) cls_boxes[j] = nms_dets # Limit to max_per_image detections **over all classes** image_scores = np.hstack( - [cls_boxes[j][:, -2] for j in range(1, class_nums)]) - if len(image_scores) > cfg.TEST.detectiions_per_im: - image_thresh = np.sort(image_scores)[-cfg.TEST.detectiions_per_im] + [cls_boxes[j][:, 1] for j in range(1, class_nums)]) + if len(image_scores) > cfg.TEST.detections_per_im: + image_thresh = np.sort(image_scores)[-cfg.TEST.detections_per_im] for j in range(1, class_nums): - keep = np.where(cls_boxes[j][:, -2] >= image_thresh)[0] + keep = np.where(cls_boxes[j][:, 1] >= image_thresh)[0] cls_boxes[j] = cls_boxes[j][keep, :] im_results_n = np.vstack([cls_boxes[j] for j in range(1, class_nums)]) im_results[i] = im_results_n new_lod.append(len(im_results_n) + new_lod[-1]) - boxes = im_results_n[:, :-2] - scores = im_results_n[:, -2] - labels = im_results_n[:, -1] + boxes = im_results_n[:, 2:] + scores = im_results_n[:, 1] + labels = im_results_n[:, 0] im_results = np.vstack([im_results[k] for k in range(len(lod) - 1)]) return new_lod, im_results -def get_dt_res(batch_size, lod, nmsed_out, data): +def get_dt_res(batch_size, lod, nmsed_out, data, num_id_to_cat_id_map): dts_res = [] nmsed_out_v = np.array(nmsed_out) + if nmsed_out_v.shape == ( + 1, + 1, ): + return dts_res assert (len(lod) == batch_size + 1), \ "Error Lod Tensor offset dimension. Lod({}) vs. 
batch_size({})"\
        .format(len(lod), batch_size)
@@ -150,7 +156,8 @@ def get_dt_res(batch_size, lod, nmsed_out, data):
         for j in range(dt_num_this_img):
             dt = nmsed_out_v[k]
             k = k + 1
-            xmin, ymin, xmax, ymax, score, category_id = dt.tolist()
+            num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
+            category_id = num_id_to_cat_id_map[num_id]
             w = xmax - xmin + 1
             h = ymax - ymin + 1
             bbox = [xmin, ymin, w, h]
@@ -164,24 +171,131 @@ def get_dt_res(batch_size, lod, nmsed_out, data):
     return dts_res


-def draw_bounding_box_on_image(image_path, nms_out, draw_threshold, label_list):
-    image = Image.open(image_path)
+def get_segms_res(batch_size, lod, segms_out, data, num_id_to_cat_id_map):
+    segms_res = []
+    segms_out_v = np.array(segms_out)
+    k = 0
+    for i in range(batch_size):
+        dt_num_this_img = lod[i + 1] - lod[i]
+        image_id = int(data[i][-1])
+        for j in range(dt_num_this_img):
+            dt = segms_out_v[k]
+            k = k + 1
+            segm, num_id, score = dt.tolist()
+            cat_id = num_id_to_cat_id_map[num_id]
+            if six.PY3:
+                if 'counts' in segm:
+                    segm['counts'] = segm['counts'].decode("utf8")
+            segm_res = {
+                'image_id': image_id,
+                'category_id': cat_id,
+                'segmentation': segm,
+                'score': score
+            }
+            segms_res.append(segm_res)
+    return segms_res
+
+
+def draw_bounding_box_on_image(image_path,
+                               nms_out,
+                               draw_threshold,
+                               label_list,
+                               num_id_to_cat_id_map,
+                               image=None):
+    if image is None:
+        image = Image.open(image_path)
     draw = ImageDraw.Draw(image)
     im_width, im_height = image.size

-    for dt in nms_out:
-        xmin, ymin, xmax, ymax, score, category_id = dt.tolist()
+    for dt in np.array(nms_out):
+        num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
+        category_id = num_id_to_cat_id_map[num_id]
         if score < draw_threshold:
             continue
-        bbox = dt[:4]
-        xmin, ymin, xmax, ymax = bbox
         draw.line(
             [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
              (xmin, ymin)],
-            width=4,
+            width=2,
             fill='red')
         if image.mode == 'RGB':
             draw.text((xmin, ymin), label_list[int(category_id)], (255, 255, 0))
     image_name = image_path.split('/')[-1]
     print("image with bbox drawn saved as {}".format(image_name))
     image.save(image_name)
+
+
+def draw_mask_on_image(image_path, segms_out, draw_threshold, alpha=0.7):
+    image = Image.open(image_path)
+    draw = ImageDraw.Draw(image)
+    im_width, im_height = image.size
+    mask_color_id = 0
+    w_ratio = .4
+    image = np.array(image).astype('float32')
+    for dt in np.array(segms_out):
+        segm, num_id, score = dt.tolist()
+        if score < draw_threshold:
+            continue
+        mask = mask_util.decode(segm) * 255
+        color_list = colormap(rgb=True)
+        color_mask = color_list[mask_color_id % len(color_list), 0:3]
+        mask_color_id += 1
+        for c in range(3):
+            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+        idx = np.nonzero(mask)
+        image[idx[0], idx[1], :] *= 1.0 - alpha
+        image[idx[0], idx[1], :] += alpha * color_mask
+    image = Image.fromarray(image.astype('uint8'))
+    return image
+
+
+def segm_results(im_results, masks, im_info):
+    im_results = np.array(im_results)
+    class_num = cfg.class_num
+    M = cfg.resolution
+    scale = (M + 2.0) / M
+    lod = masks.lod()[0]
+    masks_v = np.array(masks)
+    boxes = im_results[:, 2:]
+    labels = im_results[:, 0]
+    segms_results = [[] for _ in range(len(lod) - 1)]
+    sum = 0
+    for i in range(len(lod) - 1):
+        im_results_n = im_results[lod[i]:lod[i + 1]]
+        cls_segms = []
+        masks_n = masks_v[lod[i]:lod[i + 1]]
+        boxes_n = boxes[lod[i]:lod[i + 1]]
+        labels_n = labels[lod[i]:lod[i + 1]]
+        im_h = int(round(im_info[i][0] / im_info[i][2]))
+        im_w = int(round(im_info[i][1] / im_info[i][2]))
+        boxes_n =
box_utils.expand_boxes(boxes_n, scale) + boxes_n = boxes_n.astype(np.int32) + padded_mask = np.zeros((M + 2, M + 2), dtype=np.float32) + for j in range(len(im_results_n)): + class_id = int(labels_n[j]) + padded_mask[1:-1, 1:-1] = masks_n[j, class_id, :, :] + + ref_box = boxes_n[j, :] + w = ref_box[2] - ref_box[0] + 1 + h = ref_box[3] - ref_box[1] + 1 + w = np.maximum(w, 1) + h = np.maximum(h, 1) + + mask = cv2.resize(padded_mask, (w, h)) + mask = np.array(mask > cfg.mrcnn_thresh_binarize, dtype=np.uint8) + im_mask = np.zeros((im_h, im_w), dtype=np.uint8) + + x_0 = max(ref_box[0], 0) + x_1 = min(ref_box[2] + 1, im_w) + y_0 = max(ref_box[1], 0) + y_1 = min(ref_box[3] + 1, im_h) + im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - ref_box[1]):(y_1 - ref_box[ + 1]), (x_0 - ref_box[0]):(x_1 - ref_box[0])] + sum += im_mask.sum() + rle = mask_util.encode( + np.array( + im_mask[:, :, np.newaxis], order='F'))[0] + cls_segms.append(rle) + segms_results[i] = np.array(cls_segms)[:, np.newaxis] + segms_results = np.vstack([segms_results[k] for k in range(len(lod) - 1)]) + im_results = np.hstack([segms_results, im_results]) + return im_results[:, :3] diff --git a/fluid/PaddleCV/faster_rcnn/image/000000000139_mask.jpg b/fluid/PaddleCV/faster_rcnn/image/000000000139_mask.jpg new file mode 100644 index 0000000000000000000000000000000000000000..47dfa9a435bf81c8585e8100413cfc0d6719754c Binary files /dev/null and b/fluid/PaddleCV/faster_rcnn/image/000000000139_mask.jpg differ diff --git a/fluid/PaddleCV/faster_rcnn/image/000000127517_mask.jpg b/fluid/PaddleCV/faster_rcnn/image/000000127517_mask.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c0284591deadf6010bf780acf16124231c42d677 Binary files /dev/null and b/fluid/PaddleCV/faster_rcnn/image/000000127517_mask.jpg differ diff --git a/fluid/PaddleCV/faster_rcnn/image/Faster_RCNN.jpg b/fluid/PaddleCV/faster_rcnn/image/Faster_RCNN.jpg deleted file mode 100644 index c2ab8085c914979eb23a59734d54797b6580e956..0000000000000000000000000000000000000000 Binary files a/fluid/PaddleCV/faster_rcnn/image/Faster_RCNN.jpg and /dev/null differ diff --git a/fluid/PaddleCV/faster_rcnn/infer.py b/fluid/PaddleCV/faster_rcnn/infer.py index 3c7200f9de57bbd8d42df9dcb7d72c8fdca7e253..73fc7882567926ed629bbbcad892a3f225ee72e7 100644 --- a/fluid/PaddleCV/faster_rcnn/infer.py +++ b/fluid/PaddleCV/faster_rcnn/infer.py @@ -1,9 +1,7 @@ import os import time import numpy as np -from eval_helper import get_nmsed_box -from eval_helper import get_dt_res -from eval_helper import draw_bounding_box_on_image +from eval_helper import * import paddle import paddle.fluid as fluid import reader @@ -24,7 +22,7 @@ def infer(): test_list = 'annotations/instances_val2017.json' cocoGt = COCO(os.path.join(cfg.data_dir, test_list)) - numId_to_catId_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())} + num_id_to_cat_id_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())} category_ids = cocoGt.getCatIds() label_list = { item['id']: item['name'] @@ -40,7 +38,10 @@ def infer(): use_pyreader=False, is_train=False) model.build_model(image_shape) - rpn_rois, confs, locs = model.eval_out() + rpn_rois, confs, locs = model.eval_bbox_out() + pred_boxes = model.eval() + if cfg.MASK_ON: + masks = model.eval_mask_out() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) # yapf: disable @@ -53,17 +54,32 @@ def infer(): feeder = fluid.DataFeeder(place=place, feed_list=model.feeds()) dts_res = [] - fetch_list = [rpn_rois, confs, locs] + segms_res = [] + if 
cfg.MASK_ON:
+        fetch_list = [rpn_rois, confs, locs, pred_boxes, masks]
+    else:
+        fetch_list = [rpn_rois, confs, locs]

     data = next(infer_reader())
     im_info = [data[0][1]]

-    rpn_rois_v, confs_v, locs_v = exe.run(
-        fetch_list=[v.name for v in fetch_list],
-        feed=feeder.feed(data),
-        return_numpy=False)
-    new_lod, nmsed_out = get_nmsed_box(rpn_rois_v, confs_v, locs_v, class_nums,
-                                       im_info, numId_to_catId_map)
+    result = exe.run(fetch_list=[v.name for v in fetch_list],
+                     feed=feeder.feed(data),
+                     return_numpy=False)
+    rpn_rois_v = result[0]
+    confs_v = result[1]
+    locs_v = result[2]
+    if cfg.MASK_ON:
+        pred_boxes_v = result[3]
+        masks_v = result[4]
+    new_lod = pred_boxes_v.lod()
+    nmsed_out = pred_boxes_v

     path = os.path.join(cfg.image_path, cfg.image_name)
-    draw_bounding_box_on_image(path, nmsed_out, cfg.draw_threshold, label_list)
+    image = None
+    if cfg.MASK_ON:
+        segms_out = segm_results(nmsed_out, masks_v, im_info)
+        image = draw_mask_on_image(path, segms_out, cfg.draw_threshold)
+
+    draw_bounding_box_on_image(path, nmsed_out, cfg.draw_threshold, label_list,
+                               num_id_to_cat_id_map, image)


 if __name__ == '__main__':
diff --git a/fluid/PaddleCV/faster_rcnn/models/model_builder.py b/fluid/PaddleCV/faster_rcnn/models/model_builder.py
index 9be2f330a62081107d57566962aadc32e1ac687a..eea0e16a3ca8058d7f5ba9c713d189ceb9866d65 100644
--- a/fluid/PaddleCV/faster_rcnn/models/model_builder.py
+++ b/fluid/PaddleCV/faster_rcnn/models/model_builder.py
@@ -16,8 +16,11 @@ import paddle.fluid as fluid
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.initializer import Constant
 from paddle.fluid.initializer import Normal
+from paddle.fluid.initializer import MSRA
 from paddle.fluid.regularizer import L2Decay
 from config import cfg
+import numpy as np


 class FasterRCNN(object):
@@ -32,7 +35,6 @@ class FasterRCNN(object):
         self.is_train = is_train
         self.use_pyreader = use_pyreader
         self.use_random = use_random
-        #self.py_reader = None

     def build_model(self, image_shape):
         self.build_input(image_shape)
@@ -41,31 +43,64 @@ class FasterRCNN(object):
         self.rpn_heads(body_conv)
         # Fast RCNN
         self.fast_rcnn_heads(body_conv)
+        # Mask RCNN
+        if cfg.MASK_ON:
+            self.mask_rcnn_heads(body_conv)

     def loss(self):
+        losses = []
         # Fast RCNN loss
         loss_cls, loss_bbox = self.fast_rcnn_loss()
         # RPN loss
         rpn_cls_loss, rpn_reg_loss = self.rpn_loss()
-        return loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss,
+        losses = [loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss]
+        rkeys = ['loss', 'loss_cls', 'loss_bbox', \
+                 'loss_rpn_cls', 'loss_rpn_bbox',]
+        if cfg.MASK_ON:
+            loss_mask = self.mask_rcnn_loss()
+            losses = losses + [loss_mask]
+            rkeys = rkeys + ["loss_mask"]
+        loss = fluid.layers.sum(losses)
+        rloss = [loss] + losses
+        return rloss, rkeys

-    def eval_out(self):
+    def eval_bbox_out(self):
         cls_prob = fluid.layers.softmax(self.cls_score, use_cudnn=False)
         return [self.rpn_rois, cls_prob, self.bbox_pred]

+    def eval_mask_out(self):
+        return self.mask_fcn_logits
+
+    def eval(self):
+        return self.pred_result
+
     def build_input(self, image_shape):
         if self.use_pyreader:
+            in_shapes = [[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1],
+                         [-1, 3], [-1, 1]]
+            lod_levels = [0, 1, 1, 1, 0, 0]
+            dtypes = [
+                'float32', 'float32', 'int32', 'int32', 'float32', 'int32'
+            ]
+            if cfg.MASK_ON:
+                in_shapes.append([-1, 2])
+                lod_levels.append(3)
+                dtypes.append('float32')
             self.py_reader = fluid.layers.py_reader(
                 capacity=64,
-                shapes=[[-1] + image_shape, [-1, 4], [-1, 1], [-1, 1], [-1, 3],
-                        [-1, 1]],
-                lod_levels=[0, 1, 1, 1, 0,
0], - dtypes=[ - "float32", "float32", "int32", "int32", "float32", "int32" - ], + shapes=in_shapes, + lod_levels=lod_levels, + dtypes=dtypes, use_double_buffer=True) - self.image, self.gt_box, self.gt_label, self.is_crowd, \ - self.im_info, self.im_id = fluid.layers.read_file(self.py_reader) + ins = fluid.layers.read_file(self.py_reader) + self.image = ins[0] + self.gt_box = ins[1] + self.gt_label = ins[2] + self.is_crowd = ins[3] + self.im_info = ins[4] + self.im_id = ins[5] + if cfg.MASK_ON: + self.gt_masks = ins[6] else: self.image = fluid.layers.data( name='image', shape=image_shape, dtype='float32') @@ -74,22 +109,26 @@ class FasterRCNN(object): self.gt_label = fluid.layers.data( name='gt_label', shape=[1], dtype='int32', lod_level=1) self.is_crowd = fluid.layers.data( - name='is_crowd', - shape=[-1], - dtype='int32', - lod_level=1, - append_batch_size=False) + name='is_crowd', shape=[1], dtype='int32', lod_level=1) self.im_info = fluid.layers.data( name='im_info', shape=[3], dtype='float32') self.im_id = fluid.layers.data( name='im_id', shape=[1], dtype='int32') + if cfg.MASK_ON: + self.gt_masks = fluid.layers.data( + name='gt_masks', shape=[2], dtype='float32', lod_level=3) def feeds(self): if not self.is_train: return [self.image, self.im_info, self.im_id] + if not cfg.MASK_ON: + return [ + self.image, self.gt_box, self.gt_label, self.is_crowd, + self.im_info, self.im_id + ] return [ self.image, self.gt_box, self.gt_label, self.is_crowd, self.im_info, - self.im_id + self.im_id, self.gt_masks ] def rpn_heads(self, rpn_input): @@ -157,7 +196,7 @@ class FasterRCNN(object): nms_thresh = param_obj.rpn_nms_thresh min_size = param_obj.rpn_min_size eta = param_obj.rpn_eta - rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals( + self.rpn_rois, self.rpn_roi_probs = fluid.layers.generate_proposals( scores=rpn_cls_score_prob, bbox_deltas=self.rpn_bbox_pred, im_info=self.im_info, @@ -168,10 +207,9 @@ class FasterRCNN(object): nms_thresh=nms_thresh, min_size=min_size, eta=eta) - self.rpn_rois = rpn_rois if self.is_train: outs = fluid.layers.generate_proposal_labels( - rpn_rois=rpn_rois, + rpn_rois=self.rpn_rois, gt_classes=self.gt_label, is_crowd=self.is_crowd, gt_boxes=self.gt_box, @@ -191,27 +229,28 @@ class FasterRCNN(object): self.bbox_inside_weights = outs[3] self.bbox_outside_weights = outs[4] + if cfg.MASK_ON: + mask_out = fluid.layers.generate_mask_labels( + im_info=self.im_info, + gt_classes=self.gt_label, + is_crowd=self.is_crowd, + gt_segms=self.gt_masks, + rois=self.rois, + labels_int32=self.labels_int32, + num_classes=cfg.class_num, + resolution=cfg.resolution) + self.mask_rois = mask_out[0] + self.roi_has_mask_int32 = mask_out[1] + self.mask_int32 = mask_out[2] + def fast_rcnn_heads(self, roi_input): if self.is_train: pool_rois = self.rois else: pool_rois = self.rpn_rois - if cfg.roi_func == 'RoIPool': - pool = fluid.layers.roi_pool( - input=roi_input, - rois=pool_rois, - pooled_height=cfg.roi_resolution, - pooled_width=cfg.roi_resolution, - spatial_scale=cfg.spatial_scale) - elif cfg.roi_func == 'RoIAlign': - pool = fluid.layers.roi_align( - input=roi_input, - rois=pool_rois, - pooled_height=cfg.roi_resolution, - pooled_width=cfg.roi_resolution, - spatial_scale=cfg.spatial_scale, - sampling_ratio=cfg.sampling_ratio) - rcnn_out = self.add_roi_box_head_func(pool) + self.res5_2_sum = self.add_roi_box_head_func(roi_input, pool_rois) + rcnn_out = fluid.layers.pool2d( + self.res5_2_sum, pool_type='avg', pool_size=7, name='res5_pool') self.cls_score = 
fluid.layers.fc(input=rcnn_out, size=cfg.class_num, act=None, @@ -237,15 +276,110 @@ class FasterRCNN(object): learning_rate=2., regularizer=L2Decay(0.))) + def SuffixNet(self, conv5): + mask_out = fluid.layers.conv2d_transpose( + input=conv5, + num_filters=cfg.dim_reduced, + filter_size=2, + stride=2, + act='relu', + param_attr=ParamAttr( + name='conv5_mask_w', initializer=MSRA(uniform=False)), + bias_attr=ParamAttr( + name='conv5_mask_b', learning_rate=2., regularizer=L2Decay(0.))) + act_func = None + if not self.is_train: + act_func = 'sigmoid' + mask_fcn_logits = fluid.layers.conv2d( + input=mask_out, + num_filters=cfg.class_num, + filter_size=1, + act=act_func, + param_attr=ParamAttr( + name='mask_fcn_logits_w', initializer=MSRA(uniform=False)), + bias_attr=ParamAttr( + name="mask_fcn_logits_b", + learning_rate=2., + regularizer=L2Decay(0.))) + + if not self.is_train: + mask_fcn_logits = fluid.layers.lod_reset(mask_fcn_logits, + self.pred_result) + return mask_fcn_logits + + def mask_rcnn_heads(self, mask_input): + if self.is_train: + conv5 = fluid.layers.gather(self.res5_2_sum, + self.roi_has_mask_int32) + self.mask_fcn_logits = self.SuffixNet(conv5) + else: + im_scale = fluid.layers.slice( + self.im_info, [1], starts=[2], ends=[3]) + im_scale_lod = fluid.layers.sequence_expand(im_scale, self.rpn_rois) + boxes = self.rpn_rois / im_scale_lod + cls_prob = fluid.layers.softmax(self.cls_score, use_cudnn=False) + bbox_pred_reshape = fluid.layers.reshape(self.bbox_pred, + (-1, cfg.class_num, 4)) + decoded_box = fluid.layers.box_coder( + prior_box=boxes, + prior_box_var=cfg.bbox_reg_weights, + target_box=bbox_pred_reshape, + code_type='decode_center_size', + box_normalized=False, + axis=1) + cliped_box = fluid.layers.box_clip( + input=decoded_box, im_info=self.im_info) + self.pred_result = fluid.layers.multiclass_nms( + bboxes=cliped_box, + scores=cls_prob, + score_threshold=cfg.TEST.score_thresh, + nms_top_k=-1, + nms_threshold=cfg.TEST.nms_thresh, + keep_top_k=cfg.TEST.detections_per_im, + normalized=False) + pred_res_shape = fluid.layers.shape(self.pred_result) + shape = fluid.layers.reduce_prod(pred_res_shape) + shape = fluid.layers.reshape(shape, [1, 1]) + ones = fluid.layers.fill_constant([1, 1], value=1, dtype='int32') + cond = fluid.layers.equal(x=shape, y=ones) + ie = fluid.layers.IfElse(cond) + + with ie.true_block(): + pred_res_null = ie.input(self.pred_result) + ie.output(pred_res_null) + with ie.false_block(): + pred_res = ie.input(self.pred_result) + pred_boxes = fluid.layers.slice( + pred_res, [1], starts=[2], ends=[6]) + im_scale_lod = fluid.layers.sequence_expand(im_scale, + pred_boxes) + mask_rois = pred_boxes * im_scale_lod + conv5 = self.add_roi_box_head_func(mask_input, mask_rois) + mask_fcn = self.SuffixNet(conv5) + ie.output(mask_fcn) + self.mask_fcn_logits = ie()[0] + + def mask_rcnn_loss(self): + mask_label = fluid.layers.cast(x=self.mask_int32, dtype='float32') + reshape_dim = cfg.class_num * cfg.resolution * cfg.resolution + mask_fcn_logits_reshape = fluid.layers.reshape(self.mask_fcn_logits, + (-1, reshape_dim)) + + loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits( + x=mask_fcn_logits_reshape, + label=mask_label, + ignore_index=-1, + normalize=True) + loss_mask = fluid.layers.reduce_sum(loss_mask, name='loss_mask') + return loss_mask + def fast_rcnn_loss(self): labels_int64 = fluid.layers.cast(x=self.labels_int32, dtype='int64') labels_int64.stop_gradient = True - #loss_cls = fluid.layers.softmax_with_cross_entropy( - # logits=cls_score, - # 
label=labels_int64 - # ) - cls_prob = fluid.layers.softmax(self.cls_score, use_cudnn=False) - loss_cls = fluid.layers.cross_entropy(cls_prob, labels_int64) + loss_cls = fluid.layers.softmax_with_cross_entropy( + logits=self.cls_score, + label=labels_int64, + numeric_stable_mode=True, ) loss_cls = fluid.layers.reduce_mean(loss_cls) loss_bbox = fluid.layers.smooth_l1( x=self.bbox_pred, @@ -303,5 +437,4 @@ class FasterRCNN(object): norm = fluid.layers.reduce_prod(score_shape) norm.stop_gradient = True rpn_reg_loss = rpn_reg_loss / norm - return rpn_cls_loss, rpn_reg_loss diff --git a/fluid/PaddleCV/faster_rcnn/models/resnet.py b/fluid/PaddleCV/faster_rcnn/models/resnet.py index e868a1506afe4124036d2ecef4acf83676ba02f9..8093470241b3297c44a2e42b5162e25cac1514be 100644 --- a/fluid/PaddleCV/faster_rcnn/models/resnet.py +++ b/fluid/PaddleCV/faster_rcnn/models/resnet.py @@ -160,8 +160,22 @@ def add_ResNet50_conv4_body(body_input): return res4 -def add_ResNet_roi_conv5_head(head_input): - res5 = layer_warp(bottleneck, head_input, 512, 3, 2, name="res5") - res5_pool = fluid.layers.pool2d( - res5, pool_type='avg', pool_size=7, name='res5_pool') - return res5_pool +def add_ResNet_roi_conv5_head(head_input, rois): + if cfg.roi_func == 'RoIPool': + pool = fluid.layers.roi_pool( + input=head_input, + rois=rois, + pooled_height=cfg.roi_resolution, + pooled_width=cfg.roi_resolution, + spatial_scale=cfg.spatial_scale) + elif cfg.roi_func == 'RoIAlign': + pool = fluid.layers.roi_align( + input=head_input, + rois=rois, + pooled_height=cfg.roi_resolution, + pooled_width=cfg.roi_resolution, + spatial_scale=cfg.spatial_scale, + sampling_ratio=cfg.sampling_ratio) + + res5 = layer_warp(bottleneck, pool, 512, 3, 2, name="res5") + return res5 diff --git a/fluid/PaddleCV/faster_rcnn/profile.py b/fluid/PaddleCV/faster_rcnn/profile.py index 73634bd6773ecb1606a43b297f0966e2d55506b3..b19d51b66465bad1b7d090bdc7c8a021c2a640d5 100644 --- a/fluid/PaddleCV/faster_rcnn/profile.py +++ b/fluid/PaddleCV/faster_rcnn/profile.py @@ -43,12 +43,9 @@ def train(): use_pyreader=cfg.use_pyreader, use_random=False) model.build_model(image_shape) - loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss = model.loss() - loss_cls.persistable = True - loss_bbox.persistable = True - rpn_cls_loss.persistable = True - rpn_reg_loss.persistable = True - loss = loss_cls + loss_bbox + rpn_cls_loss + rpn_reg_loss + losses, keys = model.loss() + loss = losses[0] + fetch_list = [loss] boundaries = cfg.lr_steps gamma = cfg.lr_gamma @@ -95,8 +92,6 @@ def train(): train_reader = reader.train(batch_size=total_batch_size, shuffle=False) feeder = fluid.DataFeeder(place=place, feed_list=model.feeds()) - fetch_list = [loss, loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss] - def run(iterations): reader_time = [] run_time = [] @@ -109,20 +104,16 @@ def train(): reader_time.append(end_time - start_time) start_time = time.time() if cfg.parallel: - losses = train_exe.run(fetch_list=[v.name for v in fetch_list], - feed=feeder.feed(data)) + outs = train_exe.run(fetch_list=[v.name for v in fetch_list], + feed=feeder.feed(data)) else: - losses = exe.run(fluid.default_main_program(), - fetch_list=[v.name for v in fetch_list], - feed=feeder.feed(data)) + outs = exe.run(fluid.default_main_program(), + fetch_list=[v.name for v in fetch_list], + feed=feeder.feed(data)) end_time = time.time() run_time.append(end_time - start_time) total_images += len(data) - - lr = np.array(fluid.global_scope().find_var('learning_rate') - .get_tensor()) - print("Batch {:d}, lr {:.6f}, loss {:.6f} 
".format(batch_id, lr[0], - losses[0][0])) + print("Batch {:d}, loss {:.6f} ".format(batch_id, np.mean(outs[0]))) return reader_time, run_time, total_images def run_pyreader(iterations): @@ -135,18 +126,16 @@ def train(): for batch_id in range(iterations): start_time = time.time() if cfg.parallel: - losses = train_exe.run( + outs = train_exe.run( fetch_list=[v.name for v in fetch_list]) else: - losses = exe.run(fluid.default_main_program(), - fetch_list=[v.name for v in fetch_list]) + outs = exe.run(fluid.default_main_program(), + fetch_list=[v.name for v in fetch_list]) end_time = time.time() run_time.append(end_time - start_time) total_images += devices_num - lr = np.array(fluid.global_scope().find_var('learning_rate') - .get_tensor()) - print("Batch {:d}, lr {:.6f}, loss {:.6f} ".format(batch_id, lr[ - 0], losses[0][0])) + print("Batch {:d}, loss {:.6f} ".format(batch_id, + np.mean(outs[0]))) except fluid.core.EOFException: py_reader.reset() diff --git a/fluid/PaddleCV/faster_rcnn/reader.py b/fluid/PaddleCV/faster_rcnn/reader.py index 50b3d88b3995442c49833e6f69c7d6f04ea84064..fcd7234e839769993bce0f2aea73ef0ca58e201f 100644 --- a/fluid/PaddleCV/faster_rcnn/reader.py +++ b/fluid/PaddleCV/faster_rcnn/reader.py @@ -27,6 +27,46 @@ from collections import deque from roidbs import JsonDataset import data_utils from config import cfg +import segm_utils + + +def roidb_reader(roidb, mode): + im, im_scales = data_utils.get_image_blob(roidb, mode) + im_id = roidb['id'] + im_height = np.round(roidb['height'] * im_scales) + im_width = np.round(roidb['width'] * im_scales) + im_info = np.array([im_height, im_width, im_scales], dtype=np.float32) + if mode == 'test' or mode == 'infer': + return im, im_info, im_id + + gt_boxes = roidb['gt_boxes'].astype('float32') + gt_classes = roidb['gt_classes'].astype('int32') + is_crowd = roidb['is_crowd'].astype('int32') + segms = roidb['segms'] + + outs = (im, gt_boxes, gt_classes, is_crowd, im_info, im_id) + + if cfg.MASK_ON: + gt_masks = [] + valid = True + segms = roidb['segms'] + assert len(segms) == is_crowd.shape[0] + for i in range(len(roidb['segms'])): + segm, iscrowd = segms[i], is_crowd[i] + gt_segm = [] + if iscrowd: + gt_segm.append([[0, 0]]) + else: + for poly in segm: + if len(poly) == 0: + valid = False + break + gt_segm.append(np.array(poly).reshape(-1, 2)) + if (not valid) or len(gt_segm) == 0: + break + gt_masks.append(gt_segm) + outs = outs + (gt_masks, ) + return outs def coco(mode, @@ -63,19 +103,6 @@ def coco(mode, print("{} on {} with {} roidbs".format(mode, cfg.dataset, len(roidbs))) - def roidb_reader(roidb, mode): - im, im_scales = data_utils.get_image_blob(roidb, mode) - im_id = roidb['id'] - im_height = np.round(roidb['height'] * im_scales) - im_width = np.round(roidb['width'] * im_scales) - im_info = np.array([im_height, im_width, im_scales], dtype=np.float32) - if mode == 'test' or mode == 'infer': - return im, im_info, im_id - gt_boxes = roidb['gt_boxes'].astype('float32') - gt_classes = roidb['gt_classes'].astype('int32') - is_crowd = roidb['is_crowd'].astype('int32') - return im, gt_boxes, gt_classes, is_crowd, im_info, im_id - def padding_minibatch(batch_data): if len(batch_data) == 1: return batch_data @@ -93,22 +120,31 @@ def coco(mode, def reader(): if mode == "train": - roidb_perm = deque(np.random.permutation(roidbs)) + if shuffle: + roidb_perm = deque(np.random.permutation(roidbs)) + else: + roidb_perm = deque(roidbs) roidb_cur = 0 + count = 0 batch_out = [] while True: roidb = roidb_perm[0] roidb_cur += 1 
roidb_perm.rotate(-1) if roidb_cur >= len(roidbs): - roidb_perm = deque(np.random.permutation(roidbs)) + if shuffle: + roidb_perm = deque(np.random.permutation(roidbs)) + else: + roidb_perm = deque(roidbs) roidb_cur = 0 - im, gt_boxes, gt_classes, is_crowd, im_info, im_id = roidb_reader( - roidb, mode) - if gt_boxes.shape[0] == 0: + # im, gt_boxes, gt_classes, is_crowd, im_info, im_id, gt_masks + datas = roidb_reader(roidb, mode) + if datas[1].shape[0] == 0: continue - batch_out.append( - (im, gt_boxes, gt_classes, is_crowd, im_info, im_id)) + if cfg.MASK_ON: + if len(datas[-1]) != datas[1].shape[0]: + continue + batch_out.append(datas) if not padding_total: if len(batch_out) == batch_size: yield padding_minibatch(batch_out) @@ -124,7 +160,9 @@ def coco(mode, yield sub_batch_out sub_batch_out = [] batch_out = [] - + count += 1 + if count >= cfg.max_iter + 1: + return elif mode == "test": batch_out = [] for roidb in roidbs: diff --git a/fluid/PaddleCV/faster_rcnn/roidbs.py b/fluid/PaddleCV/faster_rcnn/roidbs.py index b21dc9ed1fb01275aa57b158b0151a56ae297dc7..accc4f615b94399cc0204a60bf04cd6a413ed75a 100644 --- a/fluid/PaddleCV/faster_rcnn/roidbs.py +++ b/fluid/PaddleCV/faster_rcnn/roidbs.py @@ -36,6 +36,7 @@ import matplotlib matplotlib.use('Agg') from pycocotools.coco import COCO import box_utils +import segm_utils from config import cfg logger = logging.getLogger(__name__) @@ -91,8 +92,9 @@ class JsonDataset(object): end_time = time.time() print('_add_gt_annotations took {:.3f}s'.format(end_time - start_time)) - print('Appending horizontally-flipped training examples...') - self._extend_with_flipped_entries(roidb) + if cfg.TRAIN.use_flipped: + print('Appending horizontally-flipped training examples...') + self._extend_with_flipped_entries(roidb) print('Loaded dataset: {:s}'.format(self.name)) print('{:d} roidb entries'.format(len(roidb))) if self.is_train: @@ -111,6 +113,7 @@ class JsonDataset(object): entry['gt_classes'] = np.empty((0), dtype=np.int32) entry['gt_id'] = np.empty((0), dtype=np.int32) entry['is_crowd'] = np.empty((0), dtype=np.bool) + entry['segms'] = [] # Remove unwanted fields that come from the json file (if they exist) for k in ['date_captured', 'url', 'license', 'file_name']: if k in entry: @@ -126,9 +129,15 @@ class JsonDataset(object): objs = self.COCO.loadAnns(ann_ids) # Sanitize bboxes -- some are invalid valid_objs = [] + valid_segms = [] width = entry['width'] height = entry['height'] for obj in objs: + if isinstance(obj['segmentation'], list): + # Valid polygons have >= 3 points, so require >= 6 coordinates + obj['segmentation'] = [ + p for p in obj['segmentation'] if len(p) >= 6 + ] if obj['area'] < cfg.TRAIN.gt_min_area: continue if 'ignore' in obj and obj['ignore'] == 1: @@ -141,6 +150,8 @@ class JsonDataset(object): if obj['area'] > 0 and x2 > x1 and y2 > y1: obj['clean_bbox'] = [x1, y1, x2, y2] valid_objs.append(obj) + valid_segms.append(obj['segmentation']) + num_valid_objs = len(valid_objs) gt_boxes = np.zeros((num_valid_objs, 4), dtype=entry['gt_boxes'].dtype) @@ -158,6 +169,7 @@ class JsonDataset(object): entry['gt_classes'] = np.append(entry['gt_classes'], gt_classes) entry['gt_id'] = np.append(entry['gt_id'], gt_id) entry['is_crowd'] = np.append(entry['is_crowd'], is_crowd) + entry['segms'].extend(valid_segms) def _extend_with_flipped_entries(self, roidb): """Flip each entry in the given roidb and return a new roidb that is the @@ -175,11 +187,13 @@ class JsonDataset(object): gt_boxes[:, 2] = width - oldx1 - 1 assert (gt_boxes[:, 2] >= gt_boxes[:, 
0]).all()
             flipped_entry = {}
-            dont_copy = ('gt_boxes', 'flipped')
+            dont_copy = ('gt_boxes', 'flipped', 'segms')
             for k, v in entry.items():
                 if k not in dont_copy:
                     flipped_entry[k] = v
             flipped_entry['gt_boxes'] = gt_boxes
+            flipped_entry['segms'] = segm_utils.flip_segms(
+                entry['segms'], entry['height'], entry['width'])
             flipped_entry['flipped'] = True
             flipped_roidb.append(flipped_entry)
         roidb.extend(flipped_roidb)
diff --git a/fluid/PaddleCV/faster_rcnn/scripts/eval.sh b/fluid/PaddleCV/faster_rcnn/scripts/eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..922380acf52e594931506e791990319d152d9260
--- /dev/null
+++ b/fluid/PaddleCV/faster_rcnn/scripts/eval.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=0
+
+model=$1 # faster_rcnn, mask_rcnn
+if [ "$model" = "faster_rcnn" ]; then
+    mask_on="--MASK_ON False"
+elif [ "$model" = "mask_rcnn" ]; then
+    mask_on="--MASK_ON True"
+else
+    echo "Invalid model provided. Please use one of {faster_rcnn, mask_rcnn}"
+    exit 1
+fi
+
+python -u ../eval_coco_map.py \
+    $mask_on \
+    --pretrained_model=../output/model_iter179999 \
+    --data_dir=../dataset/coco/
diff --git a/fluid/PaddleCV/faster_rcnn/scripts/infer.sh b/fluid/PaddleCV/faster_rcnn/scripts/infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6f0e02730b9db07568c31a280825f75e321eab64
--- /dev/null
+++ b/fluid/PaddleCV/faster_rcnn/scripts/infer.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=0
+
+model=$1 # faster_rcnn, mask_rcnn
+if [ "$model" = "faster_rcnn" ]; then
+    mask_on="--MASK_ON False"
+elif [ "$model" = "mask_rcnn" ]; then
+    mask_on="--MASK_ON True"
+else
+    echo "Invalid model provided. Please use one of {faster_rcnn, mask_rcnn}"
+    exit 1
+fi
+
+python -u ../infer.py \
+    $mask_on \
+    --pretrained_model=../output/model_iter179999 \
+    --image_path=../dataset/coco/val2017/ \
+    --image_name=000000000139.jpg \
+    --draw_threshold=0.6
diff --git a/fluid/PaddleCV/faster_rcnn/scripts/train.sh b/fluid/PaddleCV/faster_rcnn/scripts/train.sh
new file mode 100755
index 0000000000000000000000000000000000000000..83c67e6c39121c0fecec5cd7c037d14ab53c619d
--- /dev/null
+++ b/fluid/PaddleCV/faster_rcnn/scripts/train.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+model=$1 # faster_rcnn, mask_rcnn
+if [ "$model" = "faster_rcnn" ]; then
+    mask_on="--MASK_ON False"
+elif [ "$model" = "mask_rcnn" ]; then
+    mask_on="--MASK_ON True"
+else
+    echo "Invalid model provided. Please use one of {faster_rcnn, mask_rcnn}"
+    exit 1
+fi
+
+python -u ../train.py \
+    $mask_on \
+    --model_save_dir=../output/ \
+    --pretrained_model=../imagenet_resnet50_fusebn/ \
+    --data_dir=../dataset/coco/
+
diff --git a/fluid/PaddleCV/faster_rcnn/segm_utils.py b/fluid/PaddleCV/faster_rcnn/segm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..17b72228bc4284dc5936d4a3fda5c2422c4aa958
--- /dev/null
+++ b/fluid/PaddleCV/faster_rcnn/segm_utils.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Based on:
+# --------------------------------------------------------
+# Detectron
+# Copyright (c) 2017-present, Facebook, Inc.
+# Licensed under the Apache License, Version 2.0;
+# Written by Ross Girshick
+# --------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import pycocotools.mask as mask_util
+import cv2
+
+
+def is_poly(segm):
+    """Determine if segm is a polygon. Valid segm expected (polygon or RLE)."""
+    assert isinstance(segm, (list, dict)), \
+        'Invalid segm type: {}'.format(type(segm))
+    return isinstance(segm, list)
+
+
+def segms_to_rle(segms, height, width):
+    rle = segms
+    if isinstance(segms, list):
+        # polygon -- a single object might consist of multiple parts
+        # we merge all parts into one mask rle code
+        rles = mask_util.frPyObjects(segms, height, width)
+        rle = mask_util.merge(rles)
+    elif isinstance(segms['counts'], list):
+        # uncompressed RLE
+        rle = mask_util.frPyObjects(segms, height, width)
+    return rle
+
+
+def segms_to_mask(segms, iscrowd, height, width):
+    if iscrowd:
+        return [[0 for i in range(width)] for j in range(height)]
+    rle = segms_to_rle(segms, height, width)
+    mask = mask_util.decode(rle)
+    return mask
+
+
+def flip_segms(segms, height, width):
+    """Left/right flip each mask in a list of masks."""
+
+    def _flip_poly(poly, width):
+        flipped_poly = np.array(poly)
+        flipped_poly[0::2] = width - np.array(poly[0::2]) - 1
+        return flipped_poly.tolist()
+
+    def _flip_rle(rle, height, width):
+        if 'counts' in rle and type(rle['counts']) == list:
+            # Magic RLE format handling painfully discovered by looking at the
+            # COCO API showAnns function.
+ rle = mask_util.frPyObjects([rle], height, width) + mask = mask_util.decode(rle) + mask = mask[:, ::-1, :] + rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8)) + return rle + + flipped_segms = [] + for segm in segms: + if is_poly(segm): + # Polygon format + flipped_segms.append([_flip_poly(poly, width) for poly in segm]) + else: + # RLE format + flipped_segms.append(_flip_rle(segm, height, width)) + return flipped_segms diff --git a/fluid/PaddleCV/faster_rcnn/train.py b/fluid/PaddleCV/faster_rcnn/train.py index b840d2855c09e1df91601d30df1503a6003aeef5..46ee4c40d83800c0b08414ac29c8c93413c54cb2 100644 --- a/fluid/PaddleCV/faster_rcnn/train.py +++ b/fluid/PaddleCV/faster_rcnn/train.py @@ -20,7 +20,8 @@ import sys import numpy as np import time import shutil -from utility import parse_args, print_arguments, SmoothedValue +from utility import parse_args, print_arguments, SmoothedValue, TrainingStats, now_time +import collections import paddle import paddle.fluid as fluid @@ -55,30 +56,30 @@ def train(): use_pyreader=cfg.use_pyreader, use_random=use_random) model.build_model(image_shape) - loss_cls, loss_bbox, rpn_cls_loss, rpn_reg_loss = model.loss() - loss_cls.persistable = True - loss_bbox.persistable = True - rpn_cls_loss.persistable = True - rpn_reg_loss.persistable = True - loss = loss_cls + loss_bbox + rpn_cls_loss + rpn_reg_loss + losses, keys = model.loss() + loss = losses[0] + fetch_list = losses boundaries = cfg.lr_steps gamma = cfg.lr_gamma step_num = len(cfg.lr_steps) values = [learning_rate * (gamma**i) for i in range(step_num + 1)] + lr = exponential_with_warmup_decay( + learning_rate=learning_rate, + boundaries=boundaries, + values=values, + warmup_iter=cfg.warm_up_iter, + warmup_factor=cfg.warm_up_factor) optimizer = fluid.optimizer.Momentum( - learning_rate=exponential_with_warmup_decay( - learning_rate=learning_rate, - boundaries=boundaries, - values=values, - warmup_iter=cfg.warm_up_iter, - warmup_factor=cfg.warm_up_factor), + learning_rate=lr, regularization=fluid.regularizer.L2Decay(cfg.weight_decay), momentum=cfg.momentum) optimizer.minimize(loss) + fetch_list = fetch_list + [lr] - fluid.memory_optimize(fluid.default_main_program()) + fluid.memory_optimize( + fluid.default_main_program(), skip_opt_set=set(fetch_list)) place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) @@ -107,7 +108,8 @@ def train(): py_reader = model.py_reader py_reader.decorate_paddle_reader(train_reader) else: - train_reader = reader.train(batch_size=total_batch_size, shuffle=shuffle) + train_reader = reader.train( + batch_size=total_batch_size, shuffle=shuffle) feeder = fluid.DataFeeder(place=place, feed_list=model.feeds()) def save_model(postfix): @@ -116,88 +118,72 @@ def train(): shutil.rmtree(model_path) fluid.io.save_persistables(exe, model_path) - fetch_list = [loss, rpn_cls_loss, rpn_reg_loss, loss_cls, loss_bbox] - def train_loop_pyreader(): py_reader.start() - smoothed_loss = SmoothedValue(cfg.log_window) + train_stats = TrainingStats(cfg.log_window, keys) try: start_time = time.time() prev_start_time = start_time - total_time = 0 - last_loss = 0 - every_pass_loss = [] for iter_id in range(cfg.max_iter): prev_start_time = start_time start_time = time.time() - losses = train_exe.run(fetch_list=[v.name for v in fetch_list]) - every_pass_loss.append(np.mean(np.array(losses[0]))) - smoothed_loss.add_value(np.mean(np.array(losses[0]))) - lr = np.array(fluid.global_scope().find_var('learning_rate') - .get_tensor()) - print("Iter {:d}, lr 
{:.6f}, loss {:.6f}, time {:.5f}".format(
-                    iter_id, lr[0],
-                    smoothed_loss.get_median_value(
-                    ), start_time - prev_start_time))
-                end_time = time.time()
-                total_time += end_time - start_time
-                last_loss = np.mean(np.array(losses[0]))
-
+                outs = train_exe.run(fetch_list=[v.name for v in fetch_list])
+                stats = {k: np.array(v).mean() for k, v in zip(keys, outs[:-1])}
+                train_stats.update(stats)
+                logs = train_stats.log()
+                strs = '{}, lr: {:.5f}, {}, time: {:.3f}'.format(
+                    now_time(),
+                    np.mean(outs[-1]), logs, start_time - prev_start_time)
+                print(strs)
                 sys.stdout.flush()
                 if (iter_id + 1) % cfg.TRAIN.snapshot_iter == 0:
                     save_model("model_iter{}".format(iter_id))
-                # only for ce
+            end_time = time.time()
+            total_time = end_time - start_time
+            last_loss = np.array(outs[0]).mean()
             if cfg.enable_ce:
                 gpu_num = devices_num
                 epoch_idx = iter_id + 1
                 loss = last_loss
                 print("kpis\teach_pass_duration_card%s\t%s" %
-                      (gpu_num, total_time / epoch_idx))
-                print("kpis\ttrain_loss_card%s\t%s" %
-                      (gpu_num, loss))
-
-        except fluid.core.EOFException:
+                      (gpu_num, total_time / epoch_idx))
+                print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, loss))
+        except (StopIteration, fluid.core.EOFException):
             py_reader.reset()
-        return np.mean(every_pass_loss)

     def train_loop():
         start_time = time.time()
         prev_start_time = start_time
         start = start_time
-        total_time = 0
-        last_loss = 0
-        every_pass_loss = []
-        smoothed_loss = SmoothedValue(cfg.log_window)
+        train_stats = TrainingStats(cfg.log_window, keys)
         for iter_id, data in enumerate(train_reader()):
             prev_start_time = start_time
             start_time = time.time()
-            losses = train_exe.run(fetch_list=[v.name for v in fetch_list],
-                                   feed=feeder.feed(data))
-            loss_v = np.mean(np.array(losses[0]))
-            every_pass_loss.append(loss_v)
-            smoothed_loss.add_value(loss_v)
-            lr = np.array(fluid.global_scope().find_var('learning_rate')
-                          .get_tensor())
-            end_time = time.time()
-            total_time += end_time - start_time
-            last_loss = loss_v
-            print("Iter {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format(
-                iter_id, lr[0],
-                smoothed_loss.get_median_value(), start_time - prev_start_time))
+            outs = train_exe.run(fetch_list=[v.name for v in fetch_list],
+                                 feed=feeder.feed(data))
+            stats = {k: np.array(v).mean() for k, v in zip(keys, outs[:-1])}
+            train_stats.update(stats)
+            logs = train_stats.log()
+            strs = '{}, lr: {:.5f}, {}, time: {:.3f}'.format(
+                now_time(),
+                np.mean(outs[-1]), logs, start_time - prev_start_time)
+            print(strs)
             sys.stdout.flush()
             if (iter_id + 1) % cfg.TRAIN.snapshot_iter == 0:
                 save_model("model_iter{}".format(iter_id))
             if (iter_id + 1) == cfg.max_iter:
                 break
+        end_time = time.time()
+        total_time = end_time - start_time
+        last_loss = np.array(outs[0]).mean()
         # only for ce
         if cfg.enable_ce:
             gpu_num = devices_num
             epoch_idx = iter_id + 1
             loss = last_loss
             print("kpis\teach_pass_duration_card%s\t%s" %
-                  (gpu_num, total_time / epoch_idx))
-            print("kpis\ttrain_loss_card%s\t%s" %
-                  (gpu_num, loss))
+                  (gpu_num, total_time / epoch_idx))
+            print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, loss))
-        return np.mean(every_pass_loss)
+        return last_loss
diff --git a/fluid/PaddleCV/faster_rcnn/utility.py b/fluid/PaddleCV/faster_rcnn/utility.py
index f428de4c17ac9a6bd1600f52267d6718426adc78..2dbe74f6bfa62fd2ec6e533355684036c7b9fe8b 100644
--- a/fluid/PaddleCV/faster_rcnn/utility.py
+++ b/fluid/PaddleCV/faster_rcnn/utility.py
@@ -22,7 +22,9 @@ import sys
 import distutils.util
 import numpy as np
 import six
+import collections
 from collections import deque
+import datetime
 from paddle.fluid import core
 import argparse
 import functools
@@ -85,6 +87,37 @@ class SmoothedValue(object):
         return np.median(self.deque)


+def now_time():
+    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+
+
+class TrainingStats(object):
+    def __init__(self, window_size, stats_keys):
+        self.smoothed_losses_and_metrics = {
+            key: SmoothedValue(window_size)
+            for key in stats_keys
+        }
+
+    def update(self, stats):
+        for k, v in self.smoothed_losses_and_metrics.items():
+            v.add_value(stats[k])
+
+    def get(self, extras=None):
+        stats = collections.OrderedDict()
+        if extras:
+            for k, v in extras.items():
+                stats[k] = v
+        for k, v in self.smoothed_losses_and_metrics.items():
+            stats[k] = round(v.get_median_value(), 3)
+
+        return stats
+
+    def log(self, extras=None):
+        d = self.get(extras)
+        strs = ', '.join(str(dict({x: y})).strip('{}') for x, y in d.items())
+        return strs
+
+
 def parse_args():
     """return all args
     """
@@ -108,7 +141,7 @@ def parse_args():
     add_arg('learning_rate', float, 0.01, "Learning rate.")
     add_arg('max_iter', int, 180000, "Iter number.")
     add_arg('log_window', int, 20, "Log smooth window, set 1 for debug, set 20 for train.")
-    # FAST RCNN
+    # RCNN
     # RPN
     add_arg('anchor_sizes', int, [32,64,128,256,512], "The size of anchors.")
     add_arg('aspect_ratios', float, [0.5,1.0,2.0], "The ratio of anchors.")
@@ -116,6 +149,7 @@ def parse_args():
     add_arg('rpn_stride', float, [16.,16.], "Stride of the feature map that RPN is attached.")
     add_arg('rpn_nms_thresh', float, 0.7, "NMS threshold used on RPN proposals")
     # TRAIN TEST INFER
+    add_arg('MASK_ON', bool, False, "Option for different models. If False, use Faster RCNN; if True, use Mask RCNN.")
     add_arg('im_per_batch', int, 1, "Minibatch size.")
     add_arg('max_size', int, 1333, "The max size of the resized image.")
     add_arg('scales', int, [800], "The resized image height.")
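For reference, the new `TrainingStats` helper added to utility.py is driven the way train.py does in this diff: update it once per iteration with a dict of scalars keyed by the loss names, then print the windowed medians. A small usage sketch with made-up values:

```python
import numpy as np
from utility import TrainingStats

keys = ['loss', 'loss_cls', 'loss_bbox', 'loss_rpn_cls', 'loss_rpn_bbox']
stats = TrainingStats(window_size=20, stats_keys=keys)

# Each iteration: feed the latest scalar for every key (random here) ...
stats.update({k: np.random.rand() for k in keys})
# ... then log the window-median values, optionally with extras prepended.
print(stats.log())                      # e.g. 'loss': 0.513, 'loss_cls': 0.207, ...
print(stats.log(extras={'iter': 100}))  # extras appear first in the line
```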