diff --git a/ppdet/data/data_feed.py b/ppdet/data/data_feed.py index d5a50696b6c17cff3007e3adf2450c34c895b4ba..e15b9da7ceb04e3ba34b6d5e870703648c0d8349 100644 --- a/ppdet/data/data_feed.py +++ b/ppdet/data/data_feed.py @@ -42,7 +42,7 @@ __all__ = [ ] -def create_reader(feed, max_iter=0): +def create_reader(feed, max_iter=0, args_path=None): """ Return iterable data reader. @@ -52,11 +52,11 @@ def create_reader(feed, max_iter=0): # if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory # named `DATASET_DIR` (e.g., coco, pascal), if not present either, download - if feed.dataset.dataset_dir: + dataset_home = args_path if args_path else feed.dataset.dataset_dir + if dataset_home: annotation = getattr(feed.dataset, 'annotation', None) image_dir = getattr(feed.dataset, 'image_dir', None) - dataset_dir = get_dataset_path(feed.dataset.dataset_dir, - annotation, image_dir) + dataset_dir = get_dataset_path(dataset_home, annotation, image_dir) if annotation: feed.dataset.annotation = os.path.join(dataset_dir, annotation) if image_dir: diff --git a/tools/eval.py b/tools/eval.py index 1e104b132eeb706aa08216505a0fb4f0bb4c906e..8ed450074b9003260fe7790c0eba149feed61fff 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -75,7 +75,7 @@ def main(): fetches = model.eval(feed_vars) eval_prog = eval_prog.clone(True) - reader = create_reader(eval_feed) + reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir) pyreader.decorate_sample_list_generator(reader, place) # compile program for multi-devices @@ -114,8 +114,8 @@ def main(): resolution = None if 'mask' in results[0]: resolution = model.mask_head.resolution - eval_results(results, eval_feed, cfg.metric, cfg.num_classes, - resolution, is_bbox_normalized, FLAGS.output_file) + eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution, + is_bbox_normalized, FLAGS.output_file) if __name__ == '__main__': @@ -126,5 +126,11 @@ if __name__ == '__main__': default=None, type=str, help="Evaluation file name, default to bbox.json and mask.json.") + parser.add_argument( + "-d", + "--dataset_dir", + default=None, + type=str, + help="Dataset path, same as DataFeed.dataset.dataset_dir") FLAGS = parser.parse_args() main() diff --git a/tools/train.py b/tools/train.py index 906eb08a0060155375cf94d3ad53c0b7aa3640c6..7ffb204129cdab8c7bfd566c129d4cf5a6d5759d 100644 --- a/tools/train.py +++ b/tools/train.py @@ -105,7 +105,8 @@ def main(): optimizer = optim_builder(lr) optimizer.minimize(loss) - train_reader = create_reader(train_feed, cfg.max_iters * devices_num) + train_reader = create_reader(train_feed, cfg.max_iters * devices_num, + FLAGS.dataset_dir) train_pyreader.decorate_sample_list_generator(train_reader, place) # parse train fetches @@ -121,7 +122,7 @@ def main(): fetches = model.eval(feed_vars) eval_prog = eval_prog.clone(True) - eval_reader = create_reader(eval_feed) + eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir) eval_pyreader.decorate_sample_list_generator(eval_reader, place) # parse eval fetches @@ -197,7 +198,7 @@ def main(): resolution = None if 'mask' in results[0]: resolution = model.mask_head.resolution - eval_results(results, eval_feed, cfg.metric, cfg.num_classes, + eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution, is_bbox_normalized, FLAGS.output_file) train_pyreader.reset() @@ -222,5 +223,11 @@ if __name__ == '__main__': default=None, type=str, help="Evaluation file name, default to bbox.json and mask.json.") + parser.add_argument( + "-d", + "--dataset_dir", + default=None, + type=str, + help="Dataset path, same as DataFeed.dataset.dataset_dir") FLAGS = parser.parse_args() main()