未验证 提交 fc6abdd2 编写于 作者: K Kaipeng Deng 提交者: GitHub

refine infer with args (#2501)

* refine infer with args

* remove samples

* fix as review

* refine code

* refine args

* move get_test_images to infer.py

* add visualize log

* fix images = []

* fix args

* refine infer.py

* fix yolov3_r34.yml
上级 6c9a86fc
......@@ -136,7 +136,6 @@ FasterRCNNTestFeed:
dataset:
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
drop_last: false
num_workers: 2
shuffle: false
......@@ -115,5 +115,3 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
......@@ -138,6 +138,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -138,6 +138,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -139,6 +139,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -139,6 +139,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -115,5 +115,3 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
......@@ -115,5 +115,3 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
......@@ -136,7 +136,6 @@ FasterRCNNTestFeed:
dataset:
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
drop_last: false
num_workers: 2
shuffle: false
......@@ -117,5 +117,3 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
......@@ -139,6 +139,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -124,6 +124,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -141,6 +141,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -141,6 +141,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -140,6 +140,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -140,6 +140,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
shuffle: False
......@@ -147,7 +147,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
num_workers: 2
use_padded_im_info: True
......@@ -147,7 +147,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
num_workers: 2
use_padded_im_info: True
......@@ -128,5 +128,3 @@ MaskRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
......@@ -129,5 +129,3 @@ MaskRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
......@@ -147,7 +147,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
num_workers: 2
use_padded_im_info: True
......@@ -147,7 +147,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
samples: 5
num_workers: 2
use_padded_im_info: True
......@@ -150,6 +150,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
num_workers: 2
use_padded_im_info: True
......@@ -151,7 +151,6 @@ FasterRCNNTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: val2017.txt
drop_last: false
image_shape: [3, 1333, 800]
num_workers: 2
......
......@@ -84,4 +84,3 @@ SSDTestFeed:
image_dir: VOCdevkit/VOC_all/JPEGImages
use_default_label: false
drop_last: false
test_file: data/voc/VOCdevkit/VOC_all/ImageSets/Main/test.txt
......@@ -80,5 +80,3 @@ YoloTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: ../val2017.txt
samples: 5
......@@ -81,5 +81,3 @@ YoloTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: ../val2017.txt
samples: 5
......@@ -17,6 +17,8 @@ YOLOv3:
ResNet:
norm_type: sync_bn
freeze_at: 0
freeze_norm: False
norm_decay: 0.
depth: 34
feature_maps: [3, 4, 5]
......@@ -81,5 +83,3 @@ YoloTestFeed:
dataset_dir: data/coco
annotation: annotations/instances_val2017.json
image_dir: val2017
test_file: ../val2017.txt
samples: 5
......@@ -81,9 +81,9 @@ def create_reader(feed, max_iter=0):
'TYPE': type(feed.dataset).__source__
}
}
if mode == 'TEST':
data_config[mode]['TEST_FILE'] = feed.test_file
if len(getattr(feed.dataset, 'images', [])) > 0:
data_config[mode]['IMAGES'] = feed.dataset.images
transform_config = {
'WORKER_CONF': {
......@@ -244,12 +244,16 @@ class SimpleDataSet(DataSet):
__source__ = 'SimpleSource'
def __init__(self,
dataset_dir=VOC_DATASET_DIR,
annotation=VOC_TEST_ANNOTATION,
image_dir=VOC_IMAGE_DIR,
use_default_label=VOC_USE_DEFAULT_LABEL):
dataset_dir=None,
annotation=None,
image_dir=None,
use_default_label=None):
super(SimpleDataSet, self).__init__(
dataset_dir=dataset_dir, annotation=annotation, image_dir=image_dir)
self.images = []
def add_images(self, images):
self.images.extend(images)
@serializable
......@@ -281,7 +285,6 @@ class DataFeed(object):
samples=-1,
drop_last=False,
with_background=True,
test_file=None,
num_workers=2,
bufsize=10,
use_process=False,
......@@ -296,7 +299,6 @@ class DataFeed(object):
self.samples = samples
self.drop_last = drop_last
self.with_background = with_background
self.test_file = test_file
self.num_workers = num_workers
self.bufsize = bufsize
self.use_process = use_process
......@@ -385,7 +387,6 @@ class TestFeed(DataFeed):
shuffle=False,
drop_last=False,
with_background=True,
test_file=None,
num_workers=2):
super(TestFeed, self).__init__(
dataset,
......@@ -397,7 +398,6 @@ class TestFeed(DataFeed):
shuffle=shuffle,
drop_last=drop_last,
with_background=with_background,
test_file=test_file,
num_workers=num_workers)
......@@ -522,7 +522,6 @@ class FasterRCNNEvalFeed(DataFeed):
shuffle=False,
samples=-1,
drop_last=False,
test_file=None,
num_workers=2,
use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN())
......@@ -536,7 +535,6 @@ class FasterRCNNEvalFeed(DataFeed):
shuffle=shuffle,
samples=samples,
drop_last=drop_last,
test_file=test_file,
num_workers=num_workers,
use_padded_im_info=use_padded_im_info)
self.mode = 'VAL'
......@@ -564,7 +562,6 @@ class FasterRCNNTestFeed(DataFeed):
shuffle=False,
samples=-1,
drop_last=False,
test_file=None,
num_workers=2,
use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN())
......@@ -580,7 +577,6 @@ class FasterRCNNTestFeed(DataFeed):
shuffle=shuffle,
samples=samples,
drop_last=drop_last,
test_file=test_file,
num_workers=num_workers,
use_padded_im_info=use_padded_im_info)
self.mode = 'TEST'
......@@ -612,7 +608,6 @@ class MaskRCNNEvalFeed(DataFeed):
shuffle=False,
samples=-1,
drop_last=False,
test_file=None,
num_workers=2,
use_process=False,
use_padded_im_info=True):
......@@ -627,7 +622,6 @@ class MaskRCNNEvalFeed(DataFeed):
shuffle=shuffle,
samples=samples,
drop_last=drop_last,
test_file=test_file,
num_workers=num_workers,
use_process=use_process,
use_padded_im_info=use_padded_im_info)
......@@ -657,7 +651,6 @@ class MaskRCNNTestFeed(DataFeed):
shuffle=False,
samples=-1,
drop_last=False,
test_file=None,
num_workers=2,
use_process=False,
use_padded_im_info=True):
......@@ -674,7 +667,6 @@ class MaskRCNNTestFeed(DataFeed):
shuffle=shuffle,
samples=samples,
drop_last=drop_last,
test_file=test_file,
num_workers=num_workers,
use_process=use_process,
use_padded_im_info=use_padded_im_info)
......@@ -805,7 +797,6 @@ class SSDTestFeed(DataFeed):
shuffle=False,
samples=-1,
drop_last=False,
test_file=None,
num_workers=8,
bufsize=10,
use_process=False):
......@@ -822,7 +813,6 @@ class SSDTestFeed(DataFeed):
shuffle=shuffle,
samples=samples,
drop_last=drop_last,
test_file=test_file,
num_workers=num_workers)
self.mode = 'TEST'
......@@ -961,10 +951,9 @@ class YoloTestFeed(DataFeed):
batch_transforms=[],
batch_size=1,
shuffle=False,
samples=1,
samples=-1,
drop_last=False,
with_background=False,
test_file=None,
num_workers=8,
num_max_boxes=50,
use_process=False):
......@@ -982,7 +971,6 @@ class YoloTestFeed(DataFeed):
samples=samples,
drop_last=drop_last,
with_background=with_background,
test_file=test_file,
num_workers=num_workers,
use_process=use_process)
self.num_max_boxes = num_max_boxes
......
......@@ -30,26 +30,23 @@ class SimpleSource(Dataset):
Load image files for testing purpose
Args:
test_file (str): list of image file names, relative to `image_dir`
image_dir (str): root dir for images
images (list): list of path of images
samples (int): number of samples to load, -1 means all
load_img (bool): should images be loaded
"""
def __init__(self,
test_file='',
image_dir=None,
images=[],
samples=-1,
load_img=True,
**kwargs):
super(SimpleSource, self).__init__()
self._epoch = -1
assert test_file != '' and os.path.isfile(test_file), \
"test file not found: " + test_file
self._fname = test_file
self._image_dir = image_dir
assert image_dir is not None and os.path.isdir(image_dir), \
"image directory not found: " + image_dir
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
self._images = images
self._fname = None
self._simple = None
self._pos = -1
self._drained = False
......@@ -68,32 +65,25 @@ class SimpleSource(Dataset):
sample = copy.deepcopy(self._simple[self._pos])
if self._load_img:
sample['image'] = self._load_image(sample['im_file'])
else:
sample['im_file'] = os.path.join(self._image_dir,
sample['im_file'])
self._pos += 1
return sample
def _load(self):
assert os.path.isfile(self._fname) and self._fname.endswith('.txt'), \
"invalid test file path"
ct = 0
records = []
with open(self._fname, 'r') as fr:
while True:
line = fr.readline().strip()
if not line or (self._samples > 0 and ct >= self._samples):
break
rec = {'im_id': np.array([ct]), 'im_file': line}
self._imid2path[ct] = line
ct += 1
records.append(rec)
assert len(records) > 0, "no image file found in " + self._fname
for image in self._images:
if self._samples > 0 and ct >= self._samples:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "no image file found"
return records
def _load_image(self, where):
fn = os.path.join(self._image_dir, where)
with open(fn, 'rb') as f:
with open(where, 'rb') as f:
return f.read()
def reset(self):
......
......@@ -61,6 +61,16 @@ def parse_args():
action='store_true',
default=False,
help="Whether perform evaluation in train")
parser.add_argument(
"--infer_dir",
type=str,
default=None,
help="Image directory path to perform inference.")
parser.add_argument(
"--infer_img",
type=str,
default=None,
help="Image path to perform inference, --infer-img has a higher priority than --image-dir")
parser.add_argument(
"-o", "--opt", nargs=REMAINDER, help="set configuration options")
args = parser.parse_args()
......
......@@ -38,22 +38,27 @@ def visualize_results(image_path,
bbox_results=None,
mask_results=None):
"""
TODO(dengkaipeng): add more comments
Visualize bbox and mask results
"""
image = None
if not os.path.exists(SAVE_HOME):
os.makedirs(SAVE_HOME)
logger.info("Image {} detect: ".format(image_path))
image = Image.open(image_path)
if mask_results:
image = draw_mask(image_path, mask_results, threshold)
image = draw_mask(image, mask_results, threshold)
if bbox_results:
draw_bbox(image_path, catid2name, bbox_results, threshold, image)
image = draw_bbox(image, catid2name, bbox_results, threshold)
save_name = get_save_image_name(image_path)
logger.info("Detection results save in {}\n".format(save_name))
image.save(save_name)
def draw_mask(image_path, segms, threshold, alpha=0.7, save_image=False):
def draw_mask(image, segms, threshold, alpha=0.7):
"""
TODO(dengkaipeng): add more comments
Draw mask on image
"""
image = Image.open(image_path)
im_width, im_height = image.size
mask_color_id = 0
w_ratio = .4
......@@ -72,23 +77,13 @@ def draw_mask(image_path, segms, threshold, alpha=0.7, save_image=False):
image[idx[0], idx[1], :] *= 1.0 - alpha
image[idx[0], idx[1], :] += alpha * color_mask
image = Image.fromarray(image.astype('uint8'))
if not os.path.exists(SAVE_HOME):
os.makedirs(SAVE_HOME)
if save_image:
save_name = get_save_image_name(image_path)
logger.info("Detection mask results save in {}".format(save_name))
image.save(save_name)
return image
def draw_bbox(image_path, catid2name, bboxes, threshold, image=None):
def draw_bbox(image, catid2name, bboxes, threshold):
"""
TODO(dengkaipeng): add more comments
Draw bbox on image
"""
if image is None:
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
......@@ -106,13 +101,12 @@ def draw_bbox(image_path, catid2name, bboxes, threshold, image=None):
fill='red')
if image.mode == 'RGB':
draw.text((xmin, ymin), catid2name[catid], (255, 255, 0))
logger.info("\t {:15s} at {:25} score: {:.5f}".format(
catid2name[catid],
str(list(map(int, [xmin, ymin, xmax, ymax]))),
score))
if not os.path.exists(SAVE_HOME):
os.makedirs(SAVE_HOME)
save_name = get_save_image_name(image_path)
logger.info("Detection bbox results save in {}".format(save_name))
image.save(save_name)
return image
def get_save_image_name(image_path):
"""
......
......@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
import os
import glob
import numpy as np
......@@ -37,6 +38,32 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def get_test_images(infer_dir, infer_img):
"""
Get image path list in TEST mode
"""
assert infer_img is not None or infer_dir is not None, \
"--infer-img or --infer-dir should be set"
images = []
# infer_img has a higher priority
if infer_img and os.path.isfile(infer_img):
images.append(infer_img)
return images
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir)
for fmt in ['jpg', 'jpeg', 'png', 'bmp']:
images.extend(glob.glob('{}/*.{}'.format(infer_dir, fmt)))
assert len(images) > 0, "no image found in {} with " \
"extension {}".format(infer_dir, image_ext)
logger.info("Found {} inference images in total.".format(len(images)))
return images
def main():
args = parse_args()
cfg = load_config(args.config)
......@@ -53,6 +80,9 @@ def main():
else:
test_feed = create(cfg['test_feed'])
test_images = get_test_images(args.infer_dir, args.infer_img)
test_feed.dataset.add_images(test_images)
place = fluid.CUDAPlace(0) if cfg['use_gpu'] else fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -103,7 +133,7 @@ def main():
logger.info('Infer iter {}'.format(iter_id))
im_id = int(res['im_id'][0])
image_path = os.path.join(test_feed.dataset.image_dir, imid2path[im_id])
image_path = imid2path[im_id]
if cfg['metric'] == 'COCO':
bbox_results = None
mask_results = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册