未验证 提交 23d24764 编写于 作者: G Guanghua Yu 提交者: GitHub

adapt infer reader (#1727)

上级 7a65af0c
...@@ -39,6 +39,8 @@ TestReader: ...@@ -39,6 +39,8 @@ TestReader:
- NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- ResizeImage: {interp: 1, max_size: 1333, target_size: 800, use_cv2: true} - ResizeImage: {interp: 1, max_size: 1333, target_size: 800, use_cv2: true}
- Permute: {channel_first: true, to_bgr: false} - Permute: {channel_first: true, to_bgr: false}
batch_transforms:
- PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: false}
batch_size: 1 batch_size: 1
shuffle: false shuffle: false
drop_last: false drop_last: false
...@@ -68,6 +68,23 @@ class DetDataset(Dataset): ...@@ -68,6 +68,23 @@ class DetDataset(Dataset):
return os.path.join(self.dataset_dir, self.anno_path) return os.path.join(self.dataset_dir, self.anno_path)
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
return f.lower().endswith(extensions)
def _make_dataset(dir):
dir = os.path.expanduser(dir)
if not os.path.isdir(d):
raise ('{} should be a dir'.format(dir))
images = []
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
images.append(path)
return images
@register @register
@serializable @serializable
class ImageFolder(DetDataset): class ImageFolder(DetDataset):
...@@ -76,11 +93,18 @@ class ImageFolder(DetDataset): ...@@ -76,11 +93,18 @@ class ImageFolder(DetDataset):
image_dir=None, image_dir=None,
anno_path=None, anno_path=None,
sample_num=-1, sample_num=-1,
use_default_label=None,
**kwargs): **kwargs):
super(ImageFolder, self).__init__(dataset_dir, image_dir, anno_path, super(ImageFolder, self).__init__(dataset_dir, image_dir, anno_path,
sample_num) sample_num, use_default_label)
self._imid2path = {}
self.roidbs = None
def parse_dataset(self): def parse_dataset(self, with_background=True):
if not self.roidbs:
self.roidbs = self._load_images()
def _parse(self):
image_dir = self.image_dir image_dir = self.image_dir
if not isinstance(image_dir, Sequence): if not isinstance(image_dir, Sequence):
image_dir = [image_dir] image_dir = [image_dir]
...@@ -91,4 +115,27 @@ class ImageFolder(DetDataset): ...@@ -91,4 +115,27 @@ class ImageFolder(DetDataset):
images.extend(_make_dataset(im_dir)) images.extend(_make_dataset(im_dir))
elif os.path.isfile(im_dir) and _is_valid_file(im_dir): elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
images.append(im_dir) images.append(im_dir)
self.roidbs = images return images
def _load_images(self):
images = self._parse()
ct = 0
records = []
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def get_imid2path(self):
return self._imid2path
def set_images(self, images):
self.image_dir = images
self.roidbs = self._load_images()
...@@ -33,7 +33,6 @@ from ppdet.core.workspace import load_config, merge_config, create ...@@ -33,7 +33,6 @@ from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.visualizer import visualize_results from ppdet.utils.visualizer import visualize_results
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
from ppdet.data.reader import create_reader
from ppdet.utils.checkpoint import load_weight from ppdet.utils.checkpoint import load_weight
from ppdet.utils.eval_utils import get_infer_results from ppdet.utils.eval_utils import get_infer_results
import logging import logging
...@@ -120,22 +119,24 @@ def get_test_images(infer_dir, infer_img): ...@@ -120,22 +119,24 @@ def get_test_images(infer_dir, infer_img):
return images return images
def run(FLAGS, cfg): def run(FLAGS, cfg, place):
# Model # Model
main_arch = cfg.architecture main_arch = cfg.architecture
model = create(cfg.architecture) model = create(cfg.architecture)
dataset = cfg.TestReader['dataset'] # data
dataset = cfg.TestDataset
test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
dataset.set_images(test_images) dataset.set_images(test_images)
test_loader, _ = create('TestReader')(dataset, cfg['worker_num'], place)
# TODO: support other metrics # TODO: support other metrics
imid2path = dataset.get_imid2path() imid2path = dataset.get_imid2path()
from ppdet.utils.coco_eval import get_category_info from ppdet.utils.coco_eval import get_category_info
anno_file = dataset.get_anno() anno_file = dataset.get_anno()
with_background = dataset.with_background with_background = cfg.with_background
use_default_label = dataset.use_default_label use_default_label = dataset.use_default_label
clsid2catid, catid2name = get_category_info(anno_file, with_background, clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label) use_default_label)
...@@ -143,11 +144,8 @@ def run(FLAGS, cfg): ...@@ -143,11 +144,8 @@ def run(FLAGS, cfg):
# Init Model # Init Model
load_weight(model, cfg.weights) load_weight(model, cfg.weights)
# Data Reader
test_reader = create_reader(cfg.TestDataset, cfg.TestReader)
# Run Infer # Run Infer
for iter_id, data in enumerate(test_reader()): for iter_id, data in enumerate(test_loader):
# forward # forward
model.eval() model.eval()
outs = model(data, cfg.TestReader['inputs_def']['fields'], 'infer') outs = model(data, cfg.TestReader['inputs_def']['fields'], 'infer')
...@@ -208,7 +206,9 @@ def main(): ...@@ -208,7 +206,9 @@ def main():
check_gpu(cfg.use_gpu) check_gpu(cfg.use_gpu)
check_version() check_version()
run(FLAGS, cfg) place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
run(FLAGS, cfg, place)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册