未验证 提交 fc6abdd2 编写于 作者: K Kaipeng Deng 提交者: GitHub

refine infer with args (#2501)

* refine infer with args

* remove samples

* fix as review

* refine code

* refine args

* move get_test_images to infer.py

* add visualize log

* fix images = []

* fix args

* refine infer.py

* fix yolov3_r34.yml
上级 6c9a86fc
...@@ -136,7 +136,6 @@ FasterRCNNTestFeed: ...@@ -136,7 +136,6 @@ FasterRCNNTestFeed:
dataset: dataset:
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
drop_last: false drop_last: false
num_workers: 2 num_workers: 2
shuffle: false shuffle: false
...@@ -115,5 +115,3 @@ FasterRCNNTestFeed: ...@@ -115,5 +115,3 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
...@@ -138,6 +138,5 @@ FasterRCNNTestFeed: ...@@ -138,6 +138,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -138,6 +138,5 @@ FasterRCNNTestFeed: ...@@ -138,6 +138,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -139,6 +139,5 @@ FasterRCNNTestFeed: ...@@ -139,6 +139,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -139,6 +139,5 @@ FasterRCNNTestFeed: ...@@ -139,6 +139,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -115,5 +115,3 @@ FasterRCNNTestFeed: ...@@ -115,5 +115,3 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
...@@ -115,5 +115,3 @@ FasterRCNNTestFeed: ...@@ -115,5 +115,3 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
...@@ -136,7 +136,6 @@ FasterRCNNTestFeed: ...@@ -136,7 +136,6 @@ FasterRCNNTestFeed:
dataset: dataset:
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
drop_last: false drop_last: false
num_workers: 2 num_workers: 2
shuffle: false shuffle: false
...@@ -117,5 +117,3 @@ FasterRCNNTestFeed: ...@@ -117,5 +117,3 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
...@@ -139,6 +139,5 @@ FasterRCNNTestFeed: ...@@ -139,6 +139,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -124,6 +124,5 @@ FasterRCNNTestFeed: ...@@ -124,6 +124,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -141,6 +141,5 @@ FasterRCNNTestFeed: ...@@ -141,6 +141,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -141,6 +141,5 @@ FasterRCNNTestFeed: ...@@ -141,6 +141,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -140,6 +140,5 @@ FasterRCNNTestFeed: ...@@ -140,6 +140,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -140,6 +140,5 @@ FasterRCNNTestFeed: ...@@ -140,6 +140,5 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
shuffle: False shuffle: False
...@@ -147,7 +147,5 @@ MaskRCNNTestFeed: ...@@ -147,7 +147,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
num_workers: 2 num_workers: 2
use_padded_im_info: True use_padded_im_info: True
...@@ -147,7 +147,5 @@ MaskRCNNTestFeed: ...@@ -147,7 +147,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
num_workers: 2 num_workers: 2
use_padded_im_info: True use_padded_im_info: True
...@@ -128,5 +128,3 @@ MaskRCNNTestFeed: ...@@ -128,5 +128,3 @@ MaskRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
...@@ -129,5 +129,3 @@ MaskRCNNTestFeed: ...@@ -129,5 +129,3 @@ MaskRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
...@@ -147,7 +147,5 @@ MaskRCNNTestFeed: ...@@ -147,7 +147,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
num_workers: 2 num_workers: 2
use_padded_im_info: True use_padded_im_info: True
...@@ -147,7 +147,5 @@ MaskRCNNTestFeed: ...@@ -147,7 +147,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
samples: 5
num_workers: 2 num_workers: 2
use_padded_im_info: True use_padded_im_info: True
...@@ -150,6 +150,5 @@ MaskRCNNTestFeed: ...@@ -150,6 +150,5 @@ MaskRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
num_workers: 2 num_workers: 2
use_padded_im_info: True use_padded_im_info: True
...@@ -151,7 +151,6 @@ FasterRCNNTestFeed: ...@@ -151,7 +151,6 @@ FasterRCNNTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: val2017.txt
drop_last: false drop_last: false
image_shape: [3, 1333, 800] image_shape: [3, 1333, 800]
num_workers: 2 num_workers: 2
......
...@@ -84,4 +84,3 @@ SSDTestFeed: ...@@ -84,4 +84,3 @@ SSDTestFeed:
image_dir: VOCdevkit/VOC_all/JPEGImages image_dir: VOCdevkit/VOC_all/JPEGImages
use_default_label: false use_default_label: false
drop_last: false drop_last: false
test_file: data/voc/VOCdevkit/VOC_all/ImageSets/Main/test.txt
...@@ -80,5 +80,3 @@ YoloTestFeed: ...@@ -80,5 +80,3 @@ YoloTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: ../val2017.txt
samples: 5
...@@ -81,5 +81,3 @@ YoloTestFeed: ...@@ -81,5 +81,3 @@ YoloTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: ../val2017.txt
samples: 5
...@@ -17,6 +17,8 @@ YOLOv3: ...@@ -17,6 +17,8 @@ YOLOv3:
ResNet: ResNet:
norm_type: sync_bn norm_type: sync_bn
freeze_at: 0
freeze_norm: False
norm_decay: 0. norm_decay: 0.
depth: 34 depth: 34
feature_maps: [3, 4, 5] feature_maps: [3, 4, 5]
...@@ -81,5 +83,3 @@ YoloTestFeed: ...@@ -81,5 +83,3 @@ YoloTestFeed:
dataset_dir: data/coco dataset_dir: data/coco
annotation: annotations/instances_val2017.json annotation: annotations/instances_val2017.json
image_dir: val2017 image_dir: val2017
test_file: ../val2017.txt
samples: 5
...@@ -81,9 +81,9 @@ def create_reader(feed, max_iter=0): ...@@ -81,9 +81,9 @@ def create_reader(feed, max_iter=0):
'TYPE': type(feed.dataset).__source__ 'TYPE': type(feed.dataset).__source__
} }
} }
if mode == 'TEST': if len(getattr(feed.dataset, 'images', [])) > 0:
data_config[mode]['TEST_FILE'] = feed.test_file data_config[mode]['IMAGES'] = feed.dataset.images
transform_config = { transform_config = {
'WORKER_CONF': { 'WORKER_CONF': {
...@@ -244,12 +244,16 @@ class SimpleDataSet(DataSet): ...@@ -244,12 +244,16 @@ class SimpleDataSet(DataSet):
__source__ = 'SimpleSource' __source__ = 'SimpleSource'
def __init__(self, def __init__(self,
dataset_dir=VOC_DATASET_DIR, dataset_dir=None,
annotation=VOC_TEST_ANNOTATION, annotation=None,
image_dir=VOC_IMAGE_DIR, image_dir=None,
use_default_label=VOC_USE_DEFAULT_LABEL): use_default_label=None):
super(SimpleDataSet, self).__init__( super(SimpleDataSet, self).__init__(
dataset_dir=dataset_dir, annotation=annotation, image_dir=image_dir) dataset_dir=dataset_dir, annotation=annotation, image_dir=image_dir)
self.images = []
def add_images(self, images):
self.images.extend(images)
@serializable @serializable
...@@ -281,7 +285,6 @@ class DataFeed(object): ...@@ -281,7 +285,6 @@ class DataFeed(object):
samples=-1, samples=-1,
drop_last=False, drop_last=False,
with_background=True, with_background=True,
test_file=None,
num_workers=2, num_workers=2,
bufsize=10, bufsize=10,
use_process=False, use_process=False,
...@@ -296,7 +299,6 @@ class DataFeed(object): ...@@ -296,7 +299,6 @@ class DataFeed(object):
self.samples = samples self.samples = samples
self.drop_last = drop_last self.drop_last = drop_last
self.with_background = with_background self.with_background = with_background
self.test_file = test_file
self.num_workers = num_workers self.num_workers = num_workers
self.bufsize = bufsize self.bufsize = bufsize
self.use_process = use_process self.use_process = use_process
...@@ -385,7 +387,6 @@ class TestFeed(DataFeed): ...@@ -385,7 +387,6 @@ class TestFeed(DataFeed):
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
with_background=True, with_background=True,
test_file=None,
num_workers=2): num_workers=2):
super(TestFeed, self).__init__( super(TestFeed, self).__init__(
dataset, dataset,
...@@ -397,7 +398,6 @@ class TestFeed(DataFeed): ...@@ -397,7 +398,6 @@ class TestFeed(DataFeed):
shuffle=shuffle, shuffle=shuffle,
drop_last=drop_last, drop_last=drop_last,
with_background=with_background, with_background=with_background,
test_file=test_file,
num_workers=num_workers) num_workers=num_workers)
...@@ -522,7 +522,6 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -522,7 +522,6 @@ class FasterRCNNEvalFeed(DataFeed):
shuffle=False, shuffle=False,
samples=-1, samples=-1,
drop_last=False, drop_last=False,
test_file=None,
num_workers=2, num_workers=2,
use_padded_im_info=True): use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN()) sample_transforms.append(ArrangeTestRCNN())
...@@ -536,7 +535,6 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -536,7 +535,6 @@ class FasterRCNNEvalFeed(DataFeed):
shuffle=shuffle, shuffle=shuffle,
samples=samples, samples=samples,
drop_last=drop_last, drop_last=drop_last,
test_file=test_file,
num_workers=num_workers, num_workers=num_workers,
use_padded_im_info=use_padded_im_info) use_padded_im_info=use_padded_im_info)
self.mode = 'VAL' self.mode = 'VAL'
...@@ -564,7 +562,6 @@ class FasterRCNNTestFeed(DataFeed): ...@@ -564,7 +562,6 @@ class FasterRCNNTestFeed(DataFeed):
shuffle=False, shuffle=False,
samples=-1, samples=-1,
drop_last=False, drop_last=False,
test_file=None,
num_workers=2, num_workers=2,
use_padded_im_info=True): use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN()) sample_transforms.append(ArrangeTestRCNN())
...@@ -580,7 +577,6 @@ class FasterRCNNTestFeed(DataFeed): ...@@ -580,7 +577,6 @@ class FasterRCNNTestFeed(DataFeed):
shuffle=shuffle, shuffle=shuffle,
samples=samples, samples=samples,
drop_last=drop_last, drop_last=drop_last,
test_file=test_file,
num_workers=num_workers, num_workers=num_workers,
use_padded_im_info=use_padded_im_info) use_padded_im_info=use_padded_im_info)
self.mode = 'TEST' self.mode = 'TEST'
...@@ -612,7 +608,6 @@ class MaskRCNNEvalFeed(DataFeed): ...@@ -612,7 +608,6 @@ class MaskRCNNEvalFeed(DataFeed):
shuffle=False, shuffle=False,
samples=-1, samples=-1,
drop_last=False, drop_last=False,
test_file=None,
num_workers=2, num_workers=2,
use_process=False, use_process=False,
use_padded_im_info=True): use_padded_im_info=True):
...@@ -627,7 +622,6 @@ class MaskRCNNEvalFeed(DataFeed): ...@@ -627,7 +622,6 @@ class MaskRCNNEvalFeed(DataFeed):
shuffle=shuffle, shuffle=shuffle,
samples=samples, samples=samples,
drop_last=drop_last, drop_last=drop_last,
test_file=test_file,
num_workers=num_workers, num_workers=num_workers,
use_process=use_process, use_process=use_process,
use_padded_im_info=use_padded_im_info) use_padded_im_info=use_padded_im_info)
...@@ -657,7 +651,6 @@ class MaskRCNNTestFeed(DataFeed): ...@@ -657,7 +651,6 @@ class MaskRCNNTestFeed(DataFeed):
shuffle=False, shuffle=False,
samples=-1, samples=-1,
drop_last=False, drop_last=False,
test_file=None,
num_workers=2, num_workers=2,
use_process=False, use_process=False,
use_padded_im_info=True): use_padded_im_info=True):
...@@ -674,7 +667,6 @@ class MaskRCNNTestFeed(DataFeed): ...@@ -674,7 +667,6 @@ class MaskRCNNTestFeed(DataFeed):
shuffle=shuffle, shuffle=shuffle,
samples=samples, samples=samples,
drop_last=drop_last, drop_last=drop_last,
test_file=test_file,
num_workers=num_workers, num_workers=num_workers,
use_process=use_process, use_process=use_process,
use_padded_im_info=use_padded_im_info) use_padded_im_info=use_padded_im_info)
...@@ -805,7 +797,6 @@ class SSDTestFeed(DataFeed): ...@@ -805,7 +797,6 @@ class SSDTestFeed(DataFeed):
shuffle=False, shuffle=False,
samples=-1, samples=-1,
drop_last=False, drop_last=False,
test_file=None,
num_workers=8, num_workers=8,
bufsize=10, bufsize=10,
use_process=False): use_process=False):
...@@ -822,7 +813,6 @@ class SSDTestFeed(DataFeed): ...@@ -822,7 +813,6 @@ class SSDTestFeed(DataFeed):
shuffle=shuffle, shuffle=shuffle,
samples=samples, samples=samples,
drop_last=drop_last, drop_last=drop_last,
test_file=test_file,
num_workers=num_workers) num_workers=num_workers)
self.mode = 'TEST' self.mode = 'TEST'
...@@ -961,10 +951,9 @@ class YoloTestFeed(DataFeed): ...@@ -961,10 +951,9 @@ class YoloTestFeed(DataFeed):
batch_transforms=[], batch_transforms=[],
batch_size=1, batch_size=1,
shuffle=False, shuffle=False,
samples=1, samples=-1,
drop_last=False, drop_last=False,
with_background=False, with_background=False,
test_file=None,
num_workers=8, num_workers=8,
num_max_boxes=50, num_max_boxes=50,
use_process=False): use_process=False):
...@@ -982,7 +971,6 @@ class YoloTestFeed(DataFeed): ...@@ -982,7 +971,6 @@ class YoloTestFeed(DataFeed):
samples=samples, samples=samples,
drop_last=drop_last, drop_last=drop_last,
with_background=with_background, with_background=with_background,
test_file=test_file,
num_workers=num_workers, num_workers=num_workers,
use_process=use_process) use_process=use_process)
self.num_max_boxes = num_max_boxes self.num_max_boxes = num_max_boxes
......
...@@ -30,26 +30,23 @@ class SimpleSource(Dataset): ...@@ -30,26 +30,23 @@ class SimpleSource(Dataset):
Load image files for testing purpose Load image files for testing purpose
Args: Args:
test_file (str): list of image file names, relative to `image_dir` images (list): list of path of images
image_dir (str): root dir for images
samples (int): number of samples to load, -1 means all samples (int): number of samples to load, -1 means all
load_img (bool): should images be loaded load_img (bool): should images be loaded
""" """
def __init__(self, def __init__(self,
test_file='', images=[],
image_dir=None,
samples=-1, samples=-1,
load_img=True, load_img=True,
**kwargs): **kwargs):
super(SimpleSource, self).__init__() super(SimpleSource, self).__init__()
self._epoch = -1 self._epoch = -1
assert test_file != '' and os.path.isfile(test_file), \ for image in images:
"test file not found: " + test_file assert image != '' and os.path.isfile(image), \
self._fname = test_file "Image {} not found".format(image)
self._image_dir = image_dir self._images = images
assert image_dir is not None and os.path.isdir(image_dir), \ self._fname = None
"image directory not found: " + image_dir
self._simple = None self._simple = None
self._pos = -1 self._pos = -1
self._drained = False self._drained = False
...@@ -68,32 +65,25 @@ class SimpleSource(Dataset): ...@@ -68,32 +65,25 @@ class SimpleSource(Dataset):
sample = copy.deepcopy(self._simple[self._pos]) sample = copy.deepcopy(self._simple[self._pos])
if self._load_img: if self._load_img:
sample['image'] = self._load_image(sample['im_file']) sample['image'] = self._load_image(sample['im_file'])
else:
sample['im_file'] = os.path.join(self._image_dir,
sample['im_file'])
self._pos += 1 self._pos += 1
return sample return sample
def _load(self): def _load(self):
assert os.path.isfile(self._fname) and self._fname.endswith('.txt'), \
"invalid test file path"
ct = 0 ct = 0
records = [] records = []
with open(self._fname, 'r') as fr: for image in self._images:
while True: if self._samples > 0 and ct >= self._samples:
line = fr.readline().strip() break
if not line or (self._samples > 0 and ct >= self._samples): rec = {'im_id': np.array([ct]), 'im_file': image}
break self._imid2path[ct] = image
rec = {'im_id': np.array([ct]), 'im_file': line} ct += 1
self._imid2path[ct] = line records.append(rec)
ct += 1 assert len(records) > 0, "no image file found"
records.append(rec)
assert len(records) > 0, "no image file found in " + self._fname
return records return records
def _load_image(self, where): def _load_image(self, where):
fn = os.path.join(self._image_dir, where) with open(where, 'rb') as f:
with open(fn, 'rb') as f:
return f.read() return f.read()
def reset(self): def reset(self):
......
...@@ -61,6 +61,16 @@ def parse_args(): ...@@ -61,6 +61,16 @@ def parse_args():
action='store_true', action='store_true',
default=False, default=False,
help="Whether perform evaluation in train") help="Whether perform evaluation in train")
parser.add_argument(
"--infer_dir",
type=str,
default=None,
help="Image directory path to perform inference.")
parser.add_argument(
"--infer_img",
type=str,
default=None,
help="Image path to perform inference, --infer-img has a higher priority than --image-dir")
parser.add_argument( parser.add_argument(
"-o", "--opt", nargs=REMAINDER, help="set configuration options") "-o", "--opt", nargs=REMAINDER, help="set configuration options")
args = parser.parse_args() args = parser.parse_args()
......
...@@ -38,22 +38,27 @@ def visualize_results(image_path, ...@@ -38,22 +38,27 @@ def visualize_results(image_path,
bbox_results=None, bbox_results=None,
mask_results=None): mask_results=None):
""" """
TODO(dengkaipeng): add more comments
Visualize bbox and mask results Visualize bbox and mask results
""" """
image = None if not os.path.exists(SAVE_HOME):
os.makedirs(SAVE_HOME)
logger.info("Image {} detect: ".format(image_path))
image = Image.open(image_path)
if mask_results: if mask_results:
image = draw_mask(image_path, mask_results, threshold) image = draw_mask(image, mask_results, threshold)
if bbox_results: if bbox_results:
draw_bbox(image_path, catid2name, bbox_results, threshold, image) image = draw_bbox(image, catid2name, bbox_results, threshold)
save_name = get_save_image_name(image_path)
logger.info("Detection results save in {}\n".format(save_name))
image.save(save_name)
def draw_mask(image_path, segms, threshold, alpha=0.7, save_image=False): def draw_mask(image, segms, threshold, alpha=0.7):
""" """
TODO(dengkaipeng): add more comments
Draw mask on image Draw mask on image
""" """
image = Image.open(image_path)
im_width, im_height = image.size im_width, im_height = image.size
mask_color_id = 0 mask_color_id = 0
w_ratio = .4 w_ratio = .4
...@@ -72,23 +77,13 @@ def draw_mask(image_path, segms, threshold, alpha=0.7, save_image=False): ...@@ -72,23 +77,13 @@ def draw_mask(image_path, segms, threshold, alpha=0.7, save_image=False):
image[idx[0], idx[1], :] *= 1.0 - alpha image[idx[0], idx[1], :] *= 1.0 - alpha
image[idx[0], idx[1], :] += alpha * color_mask image[idx[0], idx[1], :] += alpha * color_mask
image = Image.fromarray(image.astype('uint8')) image = Image.fromarray(image.astype('uint8'))
if not os.path.exists(SAVE_HOME):
os.makedirs(SAVE_HOME)
if save_image:
save_name = get_save_image_name(image_path)
logger.info("Detection mask results save in {}".format(save_name))
image.save(save_name)
return image return image
def draw_bbox(image_path, catid2name, bboxes, threshold, image=None): def draw_bbox(image, catid2name, bboxes, threshold):
""" """
TODO(dengkaipeng): add more comments
Draw bbox on image Draw bbox on image
""" """
if image is None:
image = Image.open(image_path)
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
im_width, im_height = image.size im_width, im_height = image.size
...@@ -106,13 +101,12 @@ def draw_bbox(image_path, catid2name, bboxes, threshold, image=None): ...@@ -106,13 +101,12 @@ def draw_bbox(image_path, catid2name, bboxes, threshold, image=None):
fill='red') fill='red')
if image.mode == 'RGB': if image.mode == 'RGB':
draw.text((xmin, ymin), catid2name[catid], (255, 255, 0)) draw.text((xmin, ymin), catid2name[catid], (255, 255, 0))
logger.info("\t {:15s} at {:25} score: {:.5f}".format(
catid2name[catid],
str(list(map(int, [xmin, ymin, xmax, ymax]))),
score))
if not os.path.exists(SAVE_HOME): return image
os.makedirs(SAVE_HOME)
save_name = get_save_image_name(image_path)
logger.info("Detection bbox results save in {}".format(save_name))
image.save(save_name)
def get_save_image_name(image_path): def get_save_image_name(image_path):
""" """
......
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import glob
import numpy as np import numpy as np
...@@ -37,6 +38,32 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) ...@@ -37,6 +38,32 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_test_images(infer_dir, infer_img):
"""
Get image path list in TEST mode
"""
assert infer_img is not None or infer_dir is not None, \
"--infer-img or --infer-dir should be set"
images = []
# infer_img has a higher priority
if infer_img and os.path.isfile(infer_img):
images.append(infer_img)
return images
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir)
for fmt in ['jpg', 'jpeg', 'png', 'bmp']:
images.extend(glob.glob('{}/*.{}'.format(infer_dir, fmt)))
assert len(images) > 0, "no image found in {} with " \
"extension {}".format(infer_dir, image_ext)
logger.info("Found {} inference images in total.".format(len(images)))
return images
def main(): def main():
args = parse_args() args = parse_args()
cfg = load_config(args.config) cfg = load_config(args.config)
...@@ -53,6 +80,9 @@ def main(): ...@@ -53,6 +80,9 @@ def main():
else: else:
test_feed = create(cfg['test_feed']) test_feed = create(cfg['test_feed'])
test_images = get_test_images(args.infer_dir, args.infer_img)
test_feed.dataset.add_images(test_images)
place = fluid.CUDAPlace(0) if cfg['use_gpu'] else fluid.CPUPlace() place = fluid.CUDAPlace(0) if cfg['use_gpu'] else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -103,7 +133,7 @@ def main(): ...@@ -103,7 +133,7 @@ def main():
logger.info('Infer iter {}'.format(iter_id)) logger.info('Infer iter {}'.format(iter_id))
im_id = int(res['im_id'][0]) im_id = int(res['im_id'][0])
image_path = os.path.join(test_feed.dataset.image_dir, imid2path[im_id]) image_path = imid2path[im_id]
if cfg['metric'] == 'COCO': if cfg['metric'] == 'COCO':
bbox_results = None bbox_results = None
mask_results = None mask_results = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册