Unverified commit 2c65b659 authored by Q qingqing01, committed by GitHub

Improve inference for RCNN (#1864)

* Improve inference
* Update reader
* Update doc
Parent 71c2aa9e
......@@ -51,9 +51,8 @@ Please make sure that pretrained_model is downloaded and loaded correctly, other
To train the model, [cocoapi](https://github.com/cocodataset/cocoapi) is needed. Install the cocoapi:
# COCOAPI=/path/to/clone/cocoapi
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
cd $COCOAPI/PythonAPI
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
# if cython is not installed
pip install Cython
# Install into global site-packages
......@@ -66,25 +65,29 @@ After data preparation, one can start the training step by:
- Faster RCNN
```
python train.py \
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model} \
--data_dir=${path_to_data} \
--MASK_ON=False
```
- Mask RCNN
```
python train.py \
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model} \
--data_dir=${path_to_data} \
--MASK_ON=True
```
- Set ```export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7``` to specify 8 GPUs for training.
- Set ```MASK_ON``` to choose Faster RCNN or Mask RCNN model.
- For more help on arguments:
- Set ```export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7``` to specify 8 GPUs for training.
- Set ```MASK_ON``` to choose Faster RCNN or Mask RCNN model.
- For more help on arguments:
python train.py --help
python train.py --help
**data reader introduction:**
......@@ -116,20 +119,25 @@ Evaluation is to evaluate the performance of a trained model. This sample provid
- Faster RCNN
```
python eval_coco_map.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model} \
--pretrained_model=${path_to_trained_model} \
--MASK_ON=False
```
- Mask RCNN
```
python eval_coco_map.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model} \
--pretrained_model=${path_to_trained_model} \
--MASK_ON=True
```
- Set ```export CUDA_VISIBLE_DEVICES=0``` to specify one GPU for evaluation.
- Set ```MASK_ON``` to choose Faster RCNN or Mask RCNN model.
- Set ```--pretrained_model=${path_to_trained_model}``` to specify the trained model, not the initialized model.
- Set ```export CUDA_VISIBLE_DEVICES=0``` to specify one GPU for evaluation.
- Set ```MASK_ON``` to choose Faster RCNN or Mask RCNN model.
Evaluation results are shown below:
......@@ -159,12 +167,14 @@ Mask RCNN:
Inference is used to get prediction scores or image features from a trained model. `infer.py` is the main executor for inference; one can start the inference step by:
python infer.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model} \
--image_path=dataset/coco/val2017/ \
--image_name=000000000139.jpg \
--draw_threshold=0.6
```
python infer.py \
--pretrained_model=${path_to_trained_model} \
--image_path=dataset/coco/val2017/000000000139.jpg \
--draw_threshold=0.6
```
Please set the model path and image path correctly.
Visualization of the inference result is shown below:
<p align="center">
......
......@@ -50,9 +50,8 @@ Mask RCNN is also a two-stage framework: the first stage scans the image and generates proposals;
Before training, first download [cocoapi](https://github.com/cocodataset/cocoapi):
# COCOAPI=/path/to/clone/cocoapi
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
cd $COCOAPI/PythonAPI
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI
# if cython is not installed
pip install Cython
# Install into global site-packages
......@@ -65,25 +64,29 @@ Mask RCNN is also a two-stage framework: the first stage scans the image and generates proposals;
- Faster RCNN
```
python train.py \
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model} \
--data_dir=${path_to_data} \
--MASK_ON=False
```
- Mask RCNN
```
python train.py \
--model_save_dir=output/ \
--pretrained_model=${path_to_pretrain_model} \
--data_dir=${path_to_data} \
--MASK_ON=True
```
- Set export CUDA\_VISIBLE\_DEVICES=0,1,2,3,4,5,6,7 to specify 8 GPUs for training.
- Set ```MASK_ON``` to choose Faster RCNN or Mask RCNN model.
- For more help on arguments:
- Set export CUDA\_VISIBLE\_DEVICES=0,1,2,3,4,5,6,7 to specify 8 GPUs for training.
- Set ```MASK_ON``` to choose Faster RCNN or Mask RCNN model.
- For more help on arguments:
python train.py --help
python train.py --help
**Data reader notes:** The data reader is defined in reader.py. The short side of every image is proportionally rescaled to `scales`; if the long side then exceeds `max_size`, the long side is proportionally rescaled to `max_size` instead. Horizontal flipping is applied during training. Padding the images within one batch to the same size is supported.
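Below is a minimal illustrative sketch of this rescaling rule, added here for clarity only; it is not the exact code in reader.py, and `target_size` / `max_size` stand in for the corresponding config values.
```
import cv2
import numpy as np

def rescale_image(im, target_size, max_size):
    # Scale the short side to target_size; if that would push the long
    # side past max_size, scale by the long side instead.
    short_side = float(min(im.shape[0], im.shape[1]))
    long_side = float(max(im.shape[0], im.shape[1]))
    scale = target_size / short_side
    if np.round(scale * long_side) > max_size:
        scale = max_size / long_side
    im = cv2.resize(im, None, None, fx=scale, fy=scale,
                    interpolation=cv2.INTER_LINEAR)
    return im, scale
```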
......@@ -110,20 +113,25 @@ Mask RCNN is also a two-stage framework: the first stage scans the image and generates proposals;
- Faster RCNN
```
python eval_coco_map.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model} \
--pretrained_model=${path_to_trained_model} \
--MASK_ON=False
```
- Mask RCNN
```
python eval_coco_map.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model} \
--pretrained_model=${path_to_trained_model} \
--MASK_ON=True
```
- Set export CUDA\_VISIBLE\_DEVICES=0 to specify one GPU for evaluation.
- Set ```MASK_ON``` to choose Faster RCNN or Mask RCNN model.
- Set `--pretrained_model=${path_to_trained_model}` to specify the trained model, not the initialized model.
- Set `export CUDA\_VISIBLE\_DEVICES=0` to specify one GPU for evaluation.
- Set ```MASK_ON``` to choose Faster RCNN or Mask RCNN model.
The table below shows the model evaluation results:
......@@ -155,12 +163,14 @@ Mask RCNN:
Inference obtains the objects in an image together with their corresponding categories. `infer.py` is the main executor; an example invocation is:
python infer.py \
--dataset=coco2017 \
--pretrained_model=${path_to_pretrain_model} \
--image_path=dataset/coco/val2017/ \
--image_name=000000000139.jpg \
--draw_threshold=0.6
```
python infer.py \
--pretrained_model=${path_to_trained_model} \
--image_path=dataset/coco/val2017/000000000139.jpg \
--draw_threshold=0.6
```
Note: please set the `${path_to_trained_model}` model path and the path of the image to predict correctly.
The figure below visualizes the model's prediction result:
<p align="center">
......
......@@ -31,6 +31,23 @@ from config import cfg
import os
class DatasetPath(object):
def __init__(self, mode):
self.mode = mode
mode_name = 'train' if mode == 'train' else 'val'
if cfg.dataset != 'coco2014' and cfg.dataset != 'coco2017':
raise NotImplementedError('Dataset {} not supported'.format(
cfg.dataset))
self.sub_name = mode_name + cfg.dataset[-4:]
def get_data_dir(self):
return os.path.join(cfg.data_dir, self.sub_name)
def get_file_list(self):
sfile_list = 'annotations/instances_' + self.sub_name + '.json'
return os.path.join(cfg.data_dir, sfile_list)
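# Illustrative example (an assumption, not part of the original code): with
# cfg.dataset = 'coco2017' and cfg.data_dir = 'dataset/coco',
#   DatasetPath('val').get_data_dir()  -> 'dataset/coco/val2017'
#   DatasetPath('val').get_file_list() -> 'dataset/coco/annotations/instances_val2017.json'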
def get_image_blob(roidb, mode):
"""Builds an input blob from the images in the roidb at the specified
scales.
......
......@@ -29,7 +29,7 @@ import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config import cfg
from roidbs import DatasetPath
from data_utils import DatasetPath
def eval():
......
......@@ -199,8 +199,7 @@ def get_segms_res(batch_size, lod, segms_out, data, num_id_to_cat_id_map):
def draw_bounding_box_on_image(image_path,
nms_out,
draw_threshold,
label_list,
num_id_to_cat_id_map,
labels_map,
image=None):
if image is None:
image = Image.open(image_path)
......@@ -209,7 +208,6 @@ def draw_bounding_box_on_image(image_path,
for dt in np.array(nms_out):
num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
category_id = num_id_to_cat_id_map[num_id]
if score < draw_threshold:
continue
draw.line(
......@@ -218,7 +216,7 @@ def draw_bounding_box_on_image(image_path,
width=2,
fill='red')
if image.mode == 'RGB':
draw.text((xmin, ymin), label_list[int(category_id)], (255, 255, 0))
draw.text((xmin, ymin), labels_map[num_id], (255, 255, 0))
image_name = image_path.split('/')[-1]
print("image with bbox drawed saved as {}".format(image_name))
image.save(image_name)
......@@ -299,3 +297,90 @@ def segm_results(im_results, masks, im_info):
segms_results = np.vstack([segms_results[k] for k in range(len(lod) - 1)])
im_results = np.hstack([segms_results, im_results])
return im_results[:, :3]
def coco17_labels():
labels_map = {
0: 'background',
1: 'person',
2: 'bicycle',
3: 'car',
4: 'motorcycle',
5: 'airplane',
6: 'bus',
7: 'train',
8: 'truck',
9: 'boat',
10: 'traffic light',
11: 'fire hydrant',
12: 'stop sign',
13: 'parking meter',
14: 'bench',
15: 'bird',
16: 'cat',
17: 'dog',
18: 'horse',
19: 'sheep',
20: 'cow',
21: 'elephant',
22: 'bear',
23: 'zebra',
24: 'giraffe',
25: 'backpack',
26: 'umbrella',
27: 'handbag',
28: 'tie',
29: 'suitcase',
30: 'frisbee',
31: 'skis',
32: 'snowboard',
33: 'sports ball',
34: 'kite',
35: 'baseball bat',
36: 'baseball glove',
37: 'skateboard',
38: 'surfboard',
39: 'tennis racket',
40: 'bottle',
41: 'wine glass',
42: 'cup',
43: 'fork',
44: 'knife',
45: 'spoon',
46: 'bowl',
47: 'banana',
48: 'apple',
49: 'sandwich',
50: 'orange',
51: 'broccoli',
52: 'carrot',
53: 'hot dog',
54: 'pizza',
55: 'donut',
56: 'cake',
57: 'chair',
58: 'couch',
59: 'potted plant',
60: 'bed',
61: 'dining table',
62: 'toilet',
63: 'tv',
64: 'laptop',
65: 'mouse',
66: 'remote',
67: 'keyboard',
68: 'cell phone',
69: 'microwave',
70: 'oven',
71: 'toaster',
72: 'sink',
73: 'refrigerator',
74: 'book',
75: 'clock',
76: 'vase',
77: 'scissors',
78: 'teddy bear',
79: 'hair drier',
80: 'toothbrush'
}
return labels_map
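# Illustrative usage (an assumption, mirroring how infer.py consumes the map
# built from COCO annotations): predicted class ids index directly into this dict.
#   labels_map = coco17_labels()
#   labels_map[0]  -> 'background'
#   labels_map[1]  -> 'person'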
......@@ -8,26 +8,36 @@ import reader
from utility import print_arguments, parse_args
import models.model_builder as model_builder
import models.resnet as resnet
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config import cfg
from roidbs import DatasetPath
from data_utils import DatasetPath
def infer():
data_path = DatasetPath('val')
test_list = data_path.get_file_list()
try:
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
data_path = DatasetPath('val')
test_list = data_path.get_file_list()
coco_api = COCO(test_list)
cid = coco_api.getCatIds()
cat_id_to_num_id_map = {
v: i + 1
for i, v in enumerate(coco_api.getCatIds())
}
category_ids = coco_api.getCatIds()
labels_map = {
cat_id_to_num_id_map[item['id']]: item['name']
for item in coco_api.loadCats(category_ids)
}
labels_map[0] = 'background'
except:
print("The COCO dataset is not exist, use the defalut mapping of class "
"index and real category name on COCO17.")
assert cfg.dataset == 'coco2017'
labels_map = coco17_labels()
cocoGt = COCO(test_list)
num_id_to_cat_id_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())}
category_ids = cocoGt.getCatIds()
label_list = {
item['id']: item['name']
for item in cocoGt.loadCats(category_ids)
}
label_list[0] = ['background']
image_shape = [3, cfg.TEST.max_size, cfg.TEST.max_size]
class_nums = cfg.class_num
......@@ -43,12 +53,14 @@ def infer():
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# yapf: disable
if cfg.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)
if not os.path.exists(cfg.pretrained_model):
raise ValueError("Model path [%s] does not exist." % (cfg.pretrained_model))
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)
# yapf: enable
infer_reader = reader.infer()
infer_reader = reader.infer(cfg.image_path)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
dts_res = []
......@@ -67,14 +79,14 @@ def infer():
masks_v = result[1]
new_lod = pred_boxes_v.lod()
nmsed_out = pred_boxes_v
path = os.path.join(cfg.image_path, cfg.image_name)
image = None
if cfg.MASK_ON:
segms_out = segm_results(nmsed_out, masks_v, im_info)
image = draw_mask_on_image(path, segms_out, cfg.draw_threshold)
image = draw_mask_on_image(cfg.image_path, segms_out,
cfg.draw_threshold)
draw_bounding_box_on_image(path, nmsed_out, cfg.draw_threshold, label_list,
num_id_to_cat_id_map, image)
draw_bounding_box_on_image(cfg.image_path, nmsed_out, cfg.draw_threshold,
labels_map, image)
if __name__ == '__main__':
......
......@@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.utils.image_util import *
import random
from PIL import Image
from PIL import ImageDraw
import numpy as np
import xml.etree.ElementTree
import os
import time
import copy
import six
import cv2
from collections import deque
from roidbs import JsonDataset
......@@ -36,8 +34,6 @@ def roidb_reader(roidb, mode):
im_height = np.round(roidb['height'] * im_scales)
im_width = np.round(roidb['width'] * im_scales)
im_info = np.array([im_height, im_width, im_scales], dtype=np.float32)
if mode == 'infer':
return im, im_info
if mode == 'val':
return im, im_info, im_id
......@@ -76,11 +72,8 @@ def coco(mode,
total_batch_size=None,
padding_total=False,
shuffle=False):
cfg.mean_value = np.array(cfg.pixel_means)[np.newaxis,
np.newaxis, :].astype('float32')
total_batch_size = total_batch_size if total_batch_size else batch_size
if mode != 'infer':
assert total_batch_size % batch_size == 0
assert total_batch_size % batch_size == 0
json_dataset = JsonDataset(mode)
roidbs = json_dataset.get_roidb()
......@@ -160,14 +153,6 @@ def coco(mode,
if len(batch_out) != 0:
yield batch_out
else:
for roidb in roidbs:
if cfg.image_name not in roidb['image']:
continue
im, im_info = roidb_reader(roidb, mode)
batch_out = [(im, im_info)]
yield batch_out
return reader
......@@ -180,5 +165,17 @@ def test(batch_size, total_batch_size=None, padding_total=False):
return coco('val', batch_size, total_batch_size, shuffle=False)
def infer():
return coco('infer')
def infer(file_path):
def reader():
if not os.path.exists(file_path):
raise ValueError("Image path [%s] does not exist." % (file_path))
im = cv2.imread(file_path)
im = im.astype(np.float32, copy=False)
im -= cfg.pixel_means
im_height, im_width, channel = im.shape
channel_swap = (2, 0, 1) #(channel, height, width)
im = im.transpose(channel_swap)
im_info = np.array([im_height, im_width, 1.0], dtype=np.float32)
yield [(im, im_info)]
return reader
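# Usage sketch (an assumption, matching the call in infer.py):
#   infer_reader = reader.infer(cfg.image_path)
#   for batch in infer_reader():
#       # batch is [(im, im_info)]: im is CHW float32 with pixel means
#       # subtracted, im_info is [height, width, scale=1.0]
#       ...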
......@@ -38,27 +38,11 @@ from pycocotools.coco import COCO
import box_utils
import segm_utils
from config import cfg
from data_utils import DatasetPath
logger = logging.getLogger(__name__)
class DatasetPath(object):
def __init__(self, mode):
self.mode = mode
mode_name = 'train' if mode == 'train' else 'val'
if cfg.dataset != 'coco2014' and cfg.dataset != 'coco2017':
raise NotImplementedError('Dataset {} not supported'.format(
cfg.dataset))
self.sub_name = mode_name + cfg.dataset[-4:]
def get_data_dir(self):
return os.path.join(cfg.data_dir, self.sub_name)
def get_file_list(self):
sfile_list = 'annotations/instances_' + self.sub_name + '.json'
return os.path.join(cfg.data_dir, sfile_list)
class JsonDataset(object):
"""A class representing a COCO json dataset."""
......
......@@ -161,7 +161,6 @@ def parse_args():
# SINGLE EVAL AND DRAW
add_arg('draw_threshold', float, 0.8, "Confidence threshold to draw bbox.")
add_arg('image_path', str, 'dataset/coco/val2017', "The image path used for inference and visualization.")
add_arg('image_name', str, '', "The single image used to inference and visualize.")
# ce
parser.add_argument(
'--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.')
......