未验证 提交 8a878423 编写于 作者: W wangguanzhong 提交者: GitHub

add infer (#1640)

上级 b0706ef6
......@@ -97,10 +97,7 @@ class BBoxPostProcessYOLO(object):
scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
yolo_boxes = paddle.concat(boxes_list, axis=1)
yolo_scores = paddle.concat(scores_list, axis=2)
bbox = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
# TODO: parse the lod of nmsed_bbox
# default batch size is 1
bbox_num = np.array([int(bbox.shape[0])], dtype=np.int32)
bbox, bbox_num = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
return bbox, bbox_num
......
......@@ -136,8 +136,6 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
k = 0
for i in range(len(bbox_nums)):
image_id = int(image_id[i][0])
if bboxes.shape == (1, 1):
continue
det_nums = bbox_nums[i]
for j in range(det_nums):
dt = bboxes[k]
......
......@@ -47,7 +47,7 @@ def check_gpu(use_gpu):
pass
def check_version(version='1.7.0'):
def check_version(version='2.0'):
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
......
......@@ -3,6 +3,11 @@ from __future__ import division
from __future__ import print_function
import os
import sys
import json
from ppdet.py_op.post_process import get_det_res, get_seg_res
import logging
logger = logging.getLogger(__name__)
def json_eval_results(metric, json_directory=None, dataset=None):
......@@ -28,49 +33,62 @@ def json_eval_results(metric, json_directory=None, dataset=None):
logger.info("{} not exists!".format(v_json))
def coco_eval_results(outs_res=None, include_mask=False, dataset=None):
print("start evaluate bbox using coco api")
import io
import six
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from ppdet.py_op.post_process import get_det_res, get_seg_res
anno_file = os.path.join(dataset.dataset_dir, dataset.anno_path)
cocoGt = COCO(anno_file)
catid = {
i + dataset.with_background: v
for i, v in enumerate(cocoGt.getCatIds())
}
if outs_res is not None and len(outs_res) > 0:
det_res = []
for outs in outs_res:
det_res += get_det_res(outs['bbox'], outs['bbox_num'],
outs['im_id'], catid)
def get_infer_results(outs_res, eval_type, catid):
"""
Get result at the stage of inference.
The output format is dictionary containing bbox or mask result.
with io.open("bbox.json", 'w') as outfile:
encode_func = unicode if six.PY2 else str
outfile.write(encode_func(json.dumps(det_res)))
For example, bbox result is a list and each element contains
image_id, category_id, bbox and score.
"""
if outs_res is None or len(outs_res) == 0:
raise ValueError(
'The number of valid detection result if zero. Please use reasonable model and check input data.'
)
infer_res = {}
cocoDt = cocoGt.loadRes("bbox.json")
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
if 'bbox' in eval_type:
box_res = []
for outs in outs_res:
box_res += get_det_res(outs['bbox'], outs['bbox_num'],
outs['im_id'], catid)
infer_res['bbox'] = box_res
if outs_res is not None and len(outs_res) > 0 and include_mask:
if 'mask' in eval_type:
seg_res = []
for outs in outs_res:
seg_res += get_seg_res(outs['mask'], outs['bbox_num'],
outs['im_id'], catid)
infer_res['mask'] = seg_res
return infer_res
def eval_results(res, metric, anno_file):
"""
Evalute the inference result
"""
eval_res = []
if metric == 'COCO':
from ppdet.utils.coco_eval import cocoapi_eval
if 'bbox' in res:
with open("bbox.json", 'w') as f:
json.dump(res['bbox'], f)
logger.info('The bbox result is saved to bbox.json.')
bbox_stats = cocoapi_eval('bbox.json', 'bbox', anno_file=anno_file)
eval_res.append(bbox_stats)
sys.stdout.flush()
if 'mask' in res:
with open("mask.json", 'w') as f:
json.dump(res['mask'], f)
logger.info('The mask result is saved to mask.json.')
with io.open("mask.json", 'w') as outfile:
encode_func = unicode if six.PY2 else str
outfile.write(encode_func(json.dumps(seg_res)))
seg_stats = cocoapi_eval('mask.json', 'mask', anno_file=anno_file)
eval_res.append(seg_stats)
sys.stdout.flush()
else:
raise NotImplemented("Only COCO metric is supported now.")
cocoSg = cocoGt.loadRes("mask.json")
cocoEval = COCOeval(cocoGt, cocoSg, 'segm')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
return eval_res
......@@ -26,18 +26,18 @@ __all__ = ['visualize_results']
def visualize_results(image,
bbox_res,
mask_res,
im_id,
catid2name,
threshold=0.5,
bbox_results=None,
mask_results=None):
threshold=0.5):
"""
Visualize bbox and mask results
"""
if mask_results:
image = draw_mask(image, im_id, mask_results, threshold)
if bbox_results:
image = draw_bbox(image, im_id, catid2name, bbox_results, threshold)
if bbox_res is not None:
image = draw_bbox(image, im_id, catid2name, bbox_res, threshold)
if mask_res is not None:
image = draw_mask(image, im_id, mask_res, threshold)
return image
......
......@@ -18,7 +18,7 @@ from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.eval_utils import coco_eval_results
from ppdet.utils.eval_utils import get_infer_results, eval_results
from ppdet.data.reader import create_reader
from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
import logging
......@@ -79,11 +79,22 @@ def run(FLAGS, cfg):
cost_time = time.time() - start_time
logger.info('Total sample number: {}, averge FPS: {}'.format(
sample_num, sample_num / cost_time))
eval_type = ['bbox']
if getattr(cfg, 'MaskHead', None):
eval_type.append('mask')
# Metric
coco_eval_results(
outs_res,
include_mask=True if getattr(cfg, 'MaskHead', None) else False,
dataset=cfg['EvalReader']['dataset'])
# TODO: support other metric
dataset = cfg.EvalReader['dataset']
from ppdet.utils.coco_eval import get_category_info
anno_file = dataset.get_anno()
with_background = dataset.with_background
use_default_label = dataset.use_default_label
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
infer_res = get_infer_results(outs_res, eval_type, clsid2catid)
eval_results(infer_res, cfg.metric, anno_file)
def main():
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, sys
# Make the repository root importable: walk two levels up from this file
# (file -> containing directory -> its parent) and register that directory
# on sys.path unless it is already present.
parent_path = os.path.abspath(os.path.join(__file__, '..', '..'))
if parent_path not in sys.path:
    sys.path.append(parent_path)

# Originally added to hide numba warnings; note that the filter below has no
# category or module restriction, so it silences every Python warning.
import warnings
warnings.filterwarnings('ignore')
import glob
import numpy as np
from PIL import Image
import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.visualizer import visualize_results
from ppdet.utils.cli import ArgsParser
from ppdet.data.reader import create_reader
from ppdet.utils.checkpoint import load_dygraph_ckpt
from ppdet.utils.eval_utils import get_infer_results
import logging
# Log line layout: "<timestamp>-<level name>: <message>".
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
# Configure the root logger once at import time so INFO records are emitted.
logging.basicConfig(level=logging.INFO, format=FORMAT)
# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)
def _str2bool(value):
    """
    Convert a command-line token into a real boolean.

    Passing ``type=bool`` to ``add_argument`` does not parse booleans:
    the converter receives the raw string and ``bool("False")`` is ``True``
    because every non-empty string is truthy. This helper maps the usual
    spellings explicitly instead.

    Raises:
        ValueError: if the token is not a recognised boolean spelling
            (argparse-style parsers report this as an invalid argument).
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string stays falsy, as it was with the plain bool() converter.
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise ValueError("expected a boolean value, got {!r}".format(value))


def parse_args():
    """
    Build the command-line parser for the inference script and parse argv.

    On top of whatever options ``ArgsParser`` itself defines, this registers
    the image source (``--infer_dir`` / ``--infer_img``), the output
    directory, the drawing threshold and the VisualDL options.

    Returns:
        The namespace object produced by ``parser.parse_args()``.
    """
    parser = ArgsParser()
    parser.add_argument(
        "--infer_dir",
        type=str,
        default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--infer_img",
        type=str,
        default=None,
        help="Image path, has higher priority over --infer_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory for storing the output visualization files.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "--use_vdl",
        # NOTE: type=bool would turn "--use_vdl False" into True.
        type=_str2bool,
        default=False,
        help="whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/image",
        help='VisualDL logging directory for image.')
    args = parser.parse_args()
    return args
def get_save_image_name(output_dir, image_path):
    """
    Build the path under which the visualized copy of an image is saved.

    The output directory is created when it does not exist yet. The result
    keeps the source file's base name and extension and places it inside
    ``output_dir``.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Only the last path component of the source image is kept.
    stem, suffix = os.path.splitext(os.path.basename(image_path))
    target = os.path.join(output_dir, stem)
    return target + suffix
def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode.

    Args:
        infer_dir: directory to scan for images, or None.
        infer_img: path of a single image, or None. When it names an
            existing file it takes priority and ``infer_dir`` is ignored.

    Returns:
        ``[infer_img]`` when a single image is given; otherwise the sorted
        list of files directly inside ``infer_dir`` whose extension is
        jpg/jpeg/png/bmp (all-lowercase or all-uppercase spelling).

    Raises:
        AssertionError: if neither argument is set, if a given path is not
            a file / directory, or if the directory contains no image.
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)

    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]

    # A set removes duplicates when two patterns match the same file.
    found = set()
    for ext in exts:
        found.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))

    # Set iteration order is arbitrary and can change between runs, so sort
    # to make the inference order reproducible.
    images = sorted(found)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    logger.info("Found {} inference images in total.".format(len(images)))

    return images
def run(FLAGS, cfg):
    """
    Run inference on the requested images and save the visualized results.

    Steps, in order:
      1. create the model named by ``cfg.architecture``;
      2. collect the test images and register them on the test dataset;
      3. build the class-id / category mappings from the dataset annotation;
      4. load the weights from ``cfg.weights`` and create the test reader;
      5. for every batch, run the model in 'infer' mode, convert the outputs
         with ``get_infer_results`` and, for every image of the batch, draw
         the results and save the image under ``FLAGS.output_dir``.

    When ``FLAGS.use_vdl`` is set, the original and the annotated image are
    additionally logged to VisualDL under ``FLAGS.vdl_log_dir``.

    Args:
        FLAGS: parsed command-line arguments (see ``parse_args``).
        cfg: loaded configuration object.
    """
    # Model
    model = create(cfg.architecture)

    dataset = cfg.TestReader['dataset']
    test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
    dataset.set_images(test_images)

    # TODO: support other metrics
    imid2path = dataset.get_imid2path()

    from ppdet.utils.coco_eval import get_category_info
    anno_file = dataset.get_anno()
    with_background = dataset.with_background
    use_default_label = dataset.use_default_label
    clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                use_default_label)

    # Init Model
    model = load_dygraph_ckpt(model, ckpt=cfg.weights)

    # Data Reader
    test_reader = create_reader(cfg.TestReader)

    # VisualDL state. These names were previously used in the loop below
    # without ever being initialised, so --use_vdl crashed on the first
    # image. The import is local so VisualDL is only required on request.
    vdl_writer = None
    vdl_image_step = 0
    vdl_image_frame = 0  # the step counter wraps every 10 images per frame
    if FLAGS.use_vdl:
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)

    # Run Infer
    for iter_id, data in enumerate(test_reader()):
        # forward
        model.eval()
        outs = model(data, cfg.TestReader['inputs_def']['fields'], 'infer')

        batch_res = get_infer_results([outs], outs.keys(), clsid2catid)
        logger.info('Infer iter {}'.format(iter_id))

        bbox_res = None
        mask_res = None

        im_ids = outs['im_id']
        bbox_num = outs['bbox_num']
        # batch_res holds the results of the whole batch back to back;
        # bbox_num[i] is the number of entries belonging to image i.
        start = 0
        for i, im_id in enumerate(im_ids):
            image_path = imid2path[int(im_id)]
            image = Image.open(image_path).convert('RGB')
            end = start + bbox_num[i]

            # use VisualDL to log original image
            if FLAGS.use_vdl:
                original_image_np = np.array(image)
                vdl_writer.add_image(
                    "original/frame_{}".format(vdl_image_frame),
                    original_image_np, vdl_image_step)

            if 'bbox' in batch_res:
                bbox_res = batch_res['bbox'][start:end]
            if 'mask' in batch_res:
                mask_res = batch_res['mask'][start:end]

            image = visualize_results(image, bbox_res, mask_res,
                                      int(im_id), catid2name,
                                      FLAGS.draw_threshold)

            # use VisualDL to log image with bbox
            if FLAGS.use_vdl:
                infer_image_np = np.array(image)
                vdl_writer.add_image("bbox/frame_{}".format(vdl_image_frame),
                                     infer_image_np, vdl_image_step)
                vdl_image_step += 1
                if vdl_image_step % 10 == 0:
                    vdl_image_step = 0
                    vdl_image_frame += 1

            # save image with detection
            save_name = get_save_image_name(FLAGS.output_dir, image_path)
            logger.info("Detection bbox results save in {}".format(save_name))
            image.save(save_name, quality=95)
            start = end

    # Flush and release the VisualDL log file, if one was opened.
    if vdl_writer is not None:
        vdl_writer.close()
def main():
    """
    Script entry point.

    Parses the command line, loads the configuration file and merges the
    command-line overrides, runs the config / GPU / version checks, then
    starts inference.
    """
    flags = parse_args()

    config = load_config(flags.config)
    merge_config(flags.opt)

    # Environment and configuration sanity checks, in the original order.
    check_config(config)
    check_gpu(config.use_gpu)
    check_version()

    run(flags, config)


if __name__ == '__main__':
    main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册