diff --git a/configs/_base_/readers/mask_reader.yml b/configs/_base_/readers/mask_reader.yml index 2efbea2774380ce06604f2e2927480a726441e3d..c4908308ff5b9853996653b0dfbba7346a2a2f6a 100644 --- a/configs/_base_/readers/mask_reader.yml +++ b/configs/_base_/readers/mask_reader.yml @@ -39,6 +39,8 @@ TestReader: - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - ResizeImage: {interp: 1, max_size: 1333, target_size: 800, use_cv2: true} - Permute: {channel_first: true, to_bgr: false} + batch_transforms: + - PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: false} batch_size: 1 shuffle: false drop_last: false diff --git a/ppdet/data/source/dataset.py b/ppdet/data/source/dataset.py index 4ee8b11e7c57aabf8bc3c3897cc73cd61d9cee26..c8a2a9599688ccc2b8a442ca7f17750d7abf002d 100644 --- a/ppdet/data/source/dataset.py +++ b/ppdet/data/source/dataset.py @@ -68,6 +68,23 @@ class DetDataset(Dataset): return os.path.join(self.dataset_dir, self.anno_path) +def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')): + return f.lower().endswith(extensions) + + +def _make_dataset(dir): + dir = os.path.expanduser(dir) + if not os.path.isdir(d): + raise ('{} should be a dir'.format(dir)) + images = [] + for root, _, fnames in sorted(os.walk(dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + images.append(path) + return images + + @register @serializable class ImageFolder(DetDataset): @@ -76,11 +93,18 @@ class ImageFolder(DetDataset): image_dir=None, anno_path=None, sample_num=-1, + use_default_label=None, **kwargs): super(ImageFolder, self).__init__(dataset_dir, image_dir, anno_path, - sample_num) + sample_num, use_default_label) + self._imid2path = {} + self.roidbs = None - def parse_dataset(self): + def parse_dataset(self, with_background=True): + if not self.roidbs: + self.roidbs = self._load_images() + + def _parse(self): image_dir = self.image_dir if not isinstance(image_dir, Sequence): image_dir = [image_dir] @@ -91,4 +115,27 @@ class ImageFolder(DetDataset): images.extend(_make_dataset(im_dir)) elif os.path.isfile(im_dir) and _is_valid_file(im_dir): images.append(im_dir) - self.roidbs = images + return images + + def _load_images(self): + images = self._parse() + ct = 0 + records = [] + for image in images: + assert image != '' and os.path.isfile(image), \ + "Image {} not found".format(image) + if self.sample_num > 0 and ct >= self.sample_num: + break + rec = {'im_id': np.array([ct]), 'im_file': image} + self._imid2path[ct] = image + ct += 1 + records.append(rec) + assert len(records) > 0, "No image file found" + return records + + def get_imid2path(self): + return self._imid2path + + def set_images(self, images): + self.image_dir = images + self.roidbs = self._load_images() diff --git a/tools/infer.py b/tools/infer.py index e11a58d0e9c74397737ae80b5f3534696f618726..69a8ab5576c35b9bfc61f318be541bbfbfdf497e 100755 --- a/tools/infer.py +++ b/tools/infer.py @@ -33,7 +33,6 @@ from ppdet.core.workspace import load_config, merge_config, create from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.visualizer import visualize_results from ppdet.utils.cli import ArgsParser -from ppdet.data.reader import create_reader from ppdet.utils.checkpoint import load_weight from ppdet.utils.eval_utils import get_infer_results import logging @@ -120,22 +119,24 @@ def get_test_images(infer_dir, infer_img): return images -def run(FLAGS, cfg): +def run(FLAGS, cfg, place): # Model main_arch = cfg.architecture model = create(cfg.architecture) - dataset = cfg.TestReader['dataset'] + # data + dataset = cfg.TestDataset test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) dataset.set_images(test_images) + test_loader, _ = create('TestReader')(dataset, cfg['worker_num'], place) # TODO: support other metrics imid2path = dataset.get_imid2path() from ppdet.utils.coco_eval import get_category_info anno_file = dataset.get_anno() - with_background = dataset.with_background + with_background = cfg.with_background use_default_label = dataset.use_default_label clsid2catid, catid2name = get_category_info(anno_file, with_background, use_default_label) @@ -143,11 +144,8 @@ def run(FLAGS, cfg): # Init Model load_weight(model, cfg.weights) - # Data Reader - test_reader = create_reader(cfg.TestDataset, cfg.TestReader) - # Run Infer - for iter_id, data in enumerate(test_reader()): + for iter_id, data in enumerate(test_loader): # forward model.eval() outs = model(data, cfg.TestReader['inputs_def']['fields'], 'infer') @@ -208,7 +206,9 @@ def main(): check_gpu(cfg.use_gpu) check_version() - run(FLAGS, cfg) + place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu' + place = paddle.set_device(place) + run(FLAGS, cfg, place) if __name__ == '__main__':