未验证 提交 8a878423 编写于 作者: W wangguanzhong 提交者: GitHub

add infer (#1640)

上级 b0706ef6
...@@ -97,10 +97,7 @@ class BBoxPostProcessYOLO(object): ...@@ -97,10 +97,7 @@ class BBoxPostProcessYOLO(object):
scores_list.append(paddle.transpose(scores, perm=[0, 2, 1])) scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
yolo_boxes = paddle.concat(boxes_list, axis=1) yolo_boxes = paddle.concat(boxes_list, axis=1)
yolo_scores = paddle.concat(scores_list, axis=2) yolo_scores = paddle.concat(scores_list, axis=2)
bbox = self.nms(bboxes=yolo_boxes, scores=yolo_scores) bbox, bbox_num = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
# TODO: parse the lod of nmsed_bbox
# default batch size is 1
bbox_num = np.array([int(bbox.shape[0])], dtype=np.int32)
return bbox, bbox_num return bbox, bbox_num
......
...@@ -136,8 +136,6 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map): ...@@ -136,8 +136,6 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
k = 0 k = 0
for i in range(len(bbox_nums)): for i in range(len(bbox_nums)):
image_id = int(image_id[i][0]) image_id = int(image_id[i][0])
if bboxes.shape == (1, 1):
continue
det_nums = bbox_nums[i] det_nums = bbox_nums[i]
for j in range(det_nums): for j in range(det_nums):
dt = bboxes[k] dt = bboxes[k]
......
...@@ -47,7 +47,7 @@ def check_gpu(use_gpu): ...@@ -47,7 +47,7 @@ def check_gpu(use_gpu):
pass pass
def check_version(version='1.7.0'): def check_version(version='2.0'):
""" """
Log error and exit when the installed version of paddlepaddle is Log error and exit when the installed version of paddlepaddle is
not satisfied. not satisfied.
......
...@@ -3,6 +3,11 @@ from __future__ import division ...@@ -3,6 +3,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import sys
import json
from ppdet.py_op.post_process import get_det_res, get_seg_res
import logging
logger = logging.getLogger(__name__)
def json_eval_results(metric, json_directory=None, dataset=None): def json_eval_results(metric, json_directory=None, dataset=None):
...@@ -28,49 +33,62 @@ def json_eval_results(metric, json_directory=None, dataset=None): ...@@ -28,49 +33,62 @@ def json_eval_results(metric, json_directory=None, dataset=None):
logger.info("{} not exists!".format(v_json)) logger.info("{} not exists!".format(v_json))
def coco_eval_results(outs_res=None, include_mask=False, dataset=None): def get_infer_results(outs_res, eval_type, catid):
print("start evaluate bbox using coco api") """
import io Get result at the stage of inference.
import six The output format is dictionary containing bbox or mask result.
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from ppdet.py_op.post_process import get_det_res, get_seg_res
anno_file = os.path.join(dataset.dataset_dir, dataset.anno_path)
cocoGt = COCO(anno_file)
catid = {
i + dataset.with_background: v
for i, v in enumerate(cocoGt.getCatIds())
}
if outs_res is not None and len(outs_res) > 0:
det_res = []
for outs in outs_res:
det_res += get_det_res(outs['bbox'], outs['bbox_num'],
outs['im_id'], catid)
with io.open("bbox.json", 'w') as outfile: For example, bbox result is a list and each element contains
encode_func = unicode if six.PY2 else str image_id, category_id, bbox and score.
outfile.write(encode_func(json.dumps(det_res))) """
if outs_res is None or len(outs_res) == 0:
raise ValueError(
'The number of valid detection result if zero. Please use reasonable model and check input data.'
)
infer_res = {}
cocoDt = cocoGt.loadRes("bbox.json") if 'bbox' in eval_type:
cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') box_res = []
cocoEval.evaluate() for outs in outs_res:
cocoEval.accumulate() box_res += get_det_res(outs['bbox'], outs['bbox_num'],
cocoEval.summarize() outs['im_id'], catid)
infer_res['bbox'] = box_res
if outs_res is not None and len(outs_res) > 0 and include_mask: if 'mask' in eval_type:
seg_res = [] seg_res = []
for outs in outs_res: for outs in outs_res:
seg_res += get_seg_res(outs['mask'], outs['bbox_num'], seg_res += get_seg_res(outs['mask'], outs['bbox_num'],
outs['im_id'], catid) outs['im_id'], catid)
infer_res['mask'] = seg_res
return infer_res
with io.open("mask.json", 'w') as outfile: def eval_results(res, metric, anno_file):
encode_func = unicode if six.PY2 else str """
outfile.write(encode_func(json.dumps(seg_res))) Evalute the inference result
"""
eval_res = []
if metric == 'COCO':
from ppdet.utils.coco_eval import cocoapi_eval
if 'bbox' in res:
with open("bbox.json", 'w') as f:
json.dump(res['bbox'], f)
logger.info('The bbox result is saved to bbox.json.')
bbox_stats = cocoapi_eval('bbox.json', 'bbox', anno_file=anno_file)
eval_res.append(bbox_stats)
sys.stdout.flush()
if 'mask' in res:
with open("mask.json", 'w') as f:
json.dump(res['mask'], f)
logger.info('The mask result is saved to mask.json.')
seg_stats = cocoapi_eval('mask.json', 'mask', anno_file=anno_file)
eval_res.append(seg_stats)
sys.stdout.flush()
else:
raise NotImplemented("Only COCO metric is supported now.")
cocoSg = cocoGt.loadRes("mask.json") return eval_res
cocoEval = COCOeval(cocoGt, cocoSg, 'segm')
cocoEval.evaluate()
cocoEval.accumulate()
cocoEval.summarize()
...@@ -26,18 +26,18 @@ __all__ = ['visualize_results'] ...@@ -26,18 +26,18 @@ __all__ = ['visualize_results']
def visualize_results(image, def visualize_results(image,
bbox_res,
mask_res,
im_id, im_id,
catid2name, catid2name,
threshold=0.5, threshold=0.5):
bbox_results=None,
mask_results=None):
""" """
Visualize bbox and mask results Visualize bbox and mask results
""" """
if mask_results: if bbox_res is not None:
image = draw_mask(image, im_id, mask_results, threshold) image = draw_bbox(image, im_id, catid2name, bbox_res, threshold)
if bbox_results: if mask_res is not None:
image = draw_bbox(image, im_id, catid2name, bbox_results, threshold) image = draw_mask(image, im_id, mask_res, threshold)
return image return image
......
...@@ -18,7 +18,7 @@ from paddle.distributed import ParallelEnv ...@@ -18,7 +18,7 @@ from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config, create from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
from ppdet.utils.eval_utils import coco_eval_results from ppdet.utils.eval_utils import get_infer_results, eval_results
from ppdet.data.reader import create_reader from ppdet.data.reader import create_reader
from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt from ppdet.utils.checkpoint import load_dygraph_ckpt, save_dygraph_ckpt
import logging import logging
...@@ -79,11 +79,22 @@ def run(FLAGS, cfg): ...@@ -79,11 +79,22 @@ def run(FLAGS, cfg):
cost_time = time.time() - start_time cost_time = time.time() - start_time
logger.info('Total sample number: {}, averge FPS: {}'.format( logger.info('Total sample number: {}, averge FPS: {}'.format(
sample_num, sample_num / cost_time)) sample_num, sample_num / cost_time))
eval_type = ['bbox']
if getattr(cfg, 'MaskHead', None):
eval_type.append('mask')
# Metric # Metric
coco_eval_results( # TODO: support other metric
outs_res, dataset = cfg.EvalReader['dataset']
include_mask=True if getattr(cfg, 'MaskHead', None) else False, from ppdet.utils.coco_eval import get_category_info
dataset=cfg['EvalReader']['dataset']) anno_file = dataset.get_anno()
with_background = dataset.with_background
use_default_label = dataset.use_default_label
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
infer_res = get_infer_results(outs_res, eval_type, clsid2catid)
eval_results(infer_res, cfg.metric, anno_file)
def main(): def main():
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, sys
# Add the directory two path components above this file (presumably the
# PaddleDetection repository root -- confirm against the repo layout) to
# sys.path so the `ppdet` imports below resolve when this script is run
# directly.
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
    sys.path.append(parent_path)
# ignore numba warning
import warnings
warnings.filterwarnings('ignore')
import glob
import numpy as np
from PIL import Image
import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.visualizer import visualize_results
from ppdet.utils.cli import ArgsParser
from ppdet.data.reader import create_reader
from ppdet.utils.checkpoint import load_dygraph_ckpt
from ppdet.utils.eval_utils import get_infer_results
import logging
# Configure logging at INFO level; every record is rendered as
# "<timestamp>-<level name>: <message>".
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)
def _str2bool(value):
    """
    Convert a command-line token into a boolean.

    ``type=bool`` must not be used as an argparse converter: ``bool("False")``
    is ``True`` because every non-empty string is truthy, so the flag could
    never be switched off explicitly.

    Args:
        value: the raw command-line string (a ``bool`` is passed through).

    Returns:
        bool: the parsed value. The empty string maps to ``False``.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling;
            argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise ValueError("expected a boolean value, got {!r}".format(value))


def parse_args():
    """
    Parse the command-line arguments of the inference script.

    On top of the options already provided by ``ArgsParser`` this registers
    ``--infer_dir``, ``--infer_img``, ``--output_dir``, ``--draw_threshold``,
    ``--use_vdl`` and ``--vdl_log_dir``.

    Returns:
        The namespace returned by ``ArgsParser.parse_args()``.
    """
    parser = ArgsParser()
    parser.add_argument(
        "--infer_dir",
        type=str,
        default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--infer_img",
        type=str,
        default=None,
        help="Image path, has higher priority over --infer_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory for storing the output visualization files.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "--use_vdl",
        # Not `type=bool`: that would turn "--use_vdl False" into True.
        type=_str2bool,
        default=False,
        help="whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/image",
        help='VisualDL logging directory for image.')
    args = parser.parse_args()
    return args
def get_save_image_name(output_dir, image_path):
    """
    Build the path under ``output_dir`` where a result image is saved.

    ``output_dir`` is created when it does not exist yet. The returned path
    reuses the base name and the extension of ``image_path``.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    stem, suffix = os.path.splitext(os.path.basename(image_path))
    target = os.path.join(output_dir, stem)
    return target + suffix
def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode.

    Args:
        infer_dir (str|None): directory searched (non-recursively) for
            jpg/jpeg/png/bmp files, lower- or upper-case extension.
        infer_img (str|None): path of a single image; takes priority over
            ``infer_dir`` when both are given.

    Returns:
        list[str]: ``[infer_img]`` when ``infer_img`` is set, otherwise the
        images found in ``infer_dir`` in sorted order, so repeated runs
        enumerate the images identically.

    Raises:
        AssertionError: if neither argument is set, a given path has the
            wrong kind, or no image is found in ``infer_dir``.
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)

    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    # A set removes duplicates in case one file matches more than one pattern.
    images = set()
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    # glob and set iteration give no ordering guarantee; sort for a
    # reproducible result.
    images = sorted(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    logger.info("Found {} inference images in total.".format(len(images)))

    return images
def run(FLAGS, cfg):
    """
    Run inference on the requested images and save the visualized results.

    Args:
        FLAGS: parsed command-line arguments; this function reads
            ``infer_dir``, ``infer_img``, ``output_dir``, ``draw_threshold``,
            ``use_vdl`` and ``vdl_log_dir``.
        cfg: loaded configuration; this function reads ``architecture``,
            ``TestReader`` and ``weights``.

    Each visualized image is written to ``FLAGS.output_dir`` under the base
    name of its source image. When ``FLAGS.use_vdl`` is set, the original and
    the annotated images are additionally logged to VisualDL.
    """
    # Model
    model = create(cfg.architecture)

    dataset = cfg.TestReader['dataset']
    test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
    dataset.set_images(test_images)

    # TODO: support other metrics
    imid2path = dataset.get_imid2path()

    from ppdet.utils.coco_eval import get_category_info
    anno_file = dataset.get_anno()
    with_background = dataset.with_background
    use_default_label = dataset.use_default_label
    clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                use_default_label)

    # Init Model
    model = load_dygraph_ckpt(model, ckpt=cfg.weights)
    # Switching to evaluation mode is loop-invariant, so do it once up front.
    model.eval()

    # Data Reader
    test_reader = create_reader(cfg.TestReader)

    # VisualDL state. It must exist before the loop: the loop body reads and
    # increments these names, which previously were never assigned.
    vdl_writer = None
    vdl_image_step = 0
    vdl_image_frame = 0  # a new frame is started after every 10 images
    if FLAGS.use_vdl:
        # Optional dependency, imported lazily so that plain inference does
        # not require VisualDL to be installed.
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)

    # Run Infer
    for iter_id, data in enumerate(test_reader()):
        # forward
        outs = model(data, cfg.TestReader['inputs_def']['fields'], 'infer')

        batch_res = get_infer_results([outs], outs.keys(), clsid2catid)
        logger.info('Infer iter {}'.format(iter_id))

        bbox_res = None
        mask_res = None

        im_ids = outs['im_id']
        bbox_num = outs['bbox_num']
        # batch_res holds the results of every image of the batch back to
        # back; [start, end) is the slice that belongs to the current image.
        start = 0
        for i, im_id in enumerate(im_ids):
            image_path = imid2path[int(im_id)]
            image = Image.open(image_path).convert('RGB')
            end = start + bbox_num[i]

            # use VisualDL to log original image
            if FLAGS.use_vdl:
                original_image_np = np.array(image)
                vdl_writer.add_image(
                    "original/frame_{}".format(vdl_image_frame),
                    original_image_np, vdl_image_step)

            if 'bbox' in batch_res:
                bbox_res = batch_res['bbox'][start:end]
            if 'mask' in batch_res:
                mask_res = batch_res['mask'][start:end]

            image = visualize_results(image, bbox_res, mask_res,
                                      int(im_id), catid2name,
                                      FLAGS.draw_threshold)

            # use VisualDL to log image with bbox
            if FLAGS.use_vdl:
                infer_image_np = np.array(image)
                vdl_writer.add_image("bbox/frame_{}".format(vdl_image_frame),
                                     infer_image_np, vdl_image_step)
                vdl_image_step += 1
                if vdl_image_step % 10 == 0:
                    vdl_image_step = 0
                    vdl_image_frame += 1

            # save image with detection
            save_name = get_save_image_name(FLAGS.output_dir, image_path)
            logger.info("Detection bbox results save in {}".format(save_name))
            image.save(save_name, quality=95)
            start = end
def main():
    """
    Script entry point.

    Parses the command line, loads the configuration and merges the
    command-line overrides into it, runs the config / GPU / version checks
    and finally starts inference.
    """
    args = parse_args()

    config = load_config(args.config)
    merge_config(args.opt)

    check_config(config)
    check_gpu(config.use_gpu)
    check_version()

    run(args, config)


if __name__ == '__main__':
    main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册