未验证 提交 8fd6f220 编写于 作者: W wangguanzhong 提交者: GitHub

Cherry pick fix mask eval (#499)

* refine mask eval

* fix mask eval
上级 ce3c663a
......@@ -105,7 +105,36 @@ def mask_eval(results, anno_file, outfile, resolution, thresh_binarize=0.5):
coco_gt = COCO(anno_file)
clsid2catid = {i + 1: v for i, v in enumerate(coco_gt.getCatIds())}
segm_results = mask2out(results, clsid2catid, resolution, thresh_binarize)
segm_results = []
for t in results:
im_ids = np.array(t['im_id'][0])
bboxes = t['bbox'][0]
lengths = t['bbox'][1][0]
masks = t['mask']
if bboxes.shape == (1, 1) or bboxes is None:
continue
if len(bboxes.tolist()) == 0:
continue
s = 0
for i in range(len(lengths)):
num = lengths[i]
im_id = int(im_ids[i][0])
clsid_scores = bboxes[s:s + num][:, 0:2]
mask = masks[s:s + num]
s += num
for j in range(num):
clsid, score = clsid_scores[j].tolist()
catid = int(clsid2catid[clsid])
segm = mask[j]
segm['counts'] = segm['counts'].decode('utf8')
coco_res = {
'image_id': im_id,
'category_id': int(catid),
'segmentation': segm,
'score': score
}
segm_results.append(coco_res)
if len(segm_results) == 0:
logger.warning("The number of valid mask detected is zero.\n \
Please use reasonable model and check input data.")
......
......@@ -103,7 +103,8 @@ def eval_run(exe,
cfg=None,
sub_prog=None,
sub_keys=None,
sub_values=None):
sub_values=None,
resolution=None):
"""
Run evaluation program, return program outputs.
"""
......@@ -152,6 +153,9 @@ def eval_run(exe,
if multi_scale_test:
res = clean_res(
res, ['im_info', 'bbox', 'im_id', 'im_shape', 'mask'])
if 'mask' in res:
from ppdet.utils.post_process import mask_encode
res['mask'] = mask_encode(res, resolution)
results.append(res)
if iter_id % 100 == 0:
logger.info('Test iter {}'.format(iter_id))
......
......@@ -18,7 +18,7 @@ from __future__ import print_function
import logging
import numpy as np
import cv2
import paddle.fluid as fluid
__all__ = ['nms']
......@@ -210,3 +210,64 @@ def mstest_mask_post_process(result, cfg):
mask_pred = np.mean(mask_list, axis=0)
return {'mask': (mask_pred, [[len(mask_pred)]])}
def mask_encode(results, resolution, thresh_binarize=0.5):
import pycocotools.mask as mask_util
from ppdet.utils.coco_eval import expand_boxes
scale = (resolution + 2.0) / resolution
bboxes = results['bbox'][0]
masks = results['mask'][0]
lengths = results['mask'][1][0]
im_shapes = results['im_shape'][0]
segms = []
if bboxes.shape == (1, 1) or bboxes is None:
return segms
if len(bboxes.tolist()) == 0:
return segms
s = 0
# for each sample
for i in range(len(lengths)):
num = lengths[i]
im_shape = im_shapes[i]
bbox = bboxes[s:s + num][:, 2:]
clsid_scores = bboxes[s:s + num][:, 0:2]
mask = masks[s:s + num]
s += num
im_h = int(im_shape[0])
im_w = int(im_shape[1])
expand_bbox = expand_boxes(bbox, scale)
expand_bbox = expand_bbox.astype(np.int32)
padded_mask = np.zeros(
(resolution + 2, resolution + 2), dtype=np.float32)
for j in range(num):
xmin, ymin, xmax, ymax = expand_bbox[j].tolist()
clsid, score = clsid_scores[j].tolist()
clsid = int(clsid)
padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :]
w = xmax - xmin + 1
h = ymax - ymin + 1
w = np.maximum(w, 1)
h = np.maximum(h, 1)
resized_mask = cv2.resize(padded_mask, (w, h))
resized_mask = np.array(
resized_mask > thresh_binarize, dtype=np.uint8)
im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
x0 = min(max(xmin, 0), im_w)
x1 = min(max(xmax + 1, 0), im_w)
y0 = min(max(ymin, 0), im_h)
y1 = min(max(ymax + 1, 0), im_h)
im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
x0 - xmin):(x1 - xmin)]
segm = mask_util.encode(
np.array(
im_mask[:, :, np.newaxis], order='F'))[0]
segms.append(segm)
return segms
......@@ -152,14 +152,14 @@ def main():
if 'weights' in cfg:
checkpoint.load_params(exe, startup_prog, cfg.weights)
resolution = None
if 'Mask' in cfg.architecture:
resolution = model.mask_head.resolution
results = eval_run(exe, compile_program, loader, keys, values, cls, cfg,
sub_eval_prog, sub_keys, sub_values)
sub_eval_prog, sub_keys, sub_values, resolution)
#print(cfg['EvalReader']['dataset'].__dict__)
# evaluation
resolution = None
if 'mask' in results[0]:
resolution = model.mask_head.resolution
# if map_type not set, use default 11point, only use in VOC eval
map_type = cfg.map_type if 'map_type' in cfg else '11point'
eval_results(
......
......@@ -257,11 +257,17 @@ def main():
if FLAGS.eval:
# evaluation
results = eval_run(exe, compiled_eval_prog, eval_loader,
eval_keys, eval_values, eval_cls)
resolution = None
if 'mask' in results[0]:
if 'Mask' in cfg.architecture:
resolution = model.mask_head.resolution
results = eval_run(
exe,
compiled_eval_prog,
eval_loader,
eval_keys,
eval_values,
eval_cls,
resolution=resolution)
box_ap_stats = eval_results(
results, cfg.metric, cfg.num_classes, resolution,
is_bbox_normalized, FLAGS.output_eval, map_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册