提交 5fc84f39 编写于 作者: W wangguanzhong 提交者: GitHub

add dataset_dir in args (#2956)

上级 b00deb54
......@@ -42,7 +42,7 @@ __all__ = [
]
def create_reader(feed, max_iter=0):
def create_reader(feed, max_iter=0, args_path=None):
"""
Return iterable data reader.
......@@ -52,11 +52,11 @@ def create_reader(feed, max_iter=0):
# if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
if feed.dataset.dataset_dir:
dataset_home = args_path if args_path else feed.dataset.dataset_dir
if dataset_home:
annotation = getattr(feed.dataset, 'annotation', None)
image_dir = getattr(feed.dataset, 'image_dir', None)
dataset_dir = get_dataset_path(feed.dataset.dataset_dir,
annotation, image_dir)
dataset_dir = get_dataset_path(dataset_home, annotation, image_dir)
if annotation:
feed.dataset.annotation = os.path.join(dataset_dir, annotation)
if image_dir:
......
......@@ -75,7 +75,7 @@ def main():
fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True)
reader = create_reader(eval_feed)
reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
pyreader.decorate_sample_list_generator(reader, place)
# compile program for multi-devices
......@@ -114,8 +114,8 @@ def main():
resolution = None
if 'mask' in results[0]:
resolution = model.mask_head.resolution
eval_results(results, eval_feed, cfg.metric, cfg.num_classes,
resolution, is_bbox_normalized, FLAGS.output_file)
eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution,
is_bbox_normalized, FLAGS.output_file)
if __name__ == '__main__':
......@@ -126,5 +126,11 @@ if __name__ == '__main__':
default=None,
type=str,
help="Evaluation file name, default to bbox.json and mask.json.")
parser.add_argument(
"-d",
"--dataset_dir",
default=None,
type=str,
help="Dataset path, same as DataFeed.dataset.dataset_dir")
FLAGS = parser.parse_args()
main()
......@@ -105,7 +105,8 @@ def main():
optimizer = optim_builder(lr)
optimizer.minimize(loss)
train_reader = create_reader(train_feed, cfg.max_iters * devices_num)
train_reader = create_reader(train_feed, cfg.max_iters * devices_num,
FLAGS.dataset_dir)
train_pyreader.decorate_sample_list_generator(train_reader, place)
# parse train fetches
......@@ -121,7 +122,7 @@ def main():
fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True)
eval_reader = create_reader(eval_feed)
eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
eval_pyreader.decorate_sample_list_generator(eval_reader, place)
# parse eval fetches
......@@ -197,7 +198,7 @@ def main():
resolution = None
if 'mask' in results[0]:
resolution = model.mask_head.resolution
eval_results(results, eval_feed, cfg.metric, cfg.num_classes,
eval_results(results, eval_feed, cfg.metric, cfg.num_classes,
resolution, is_bbox_normalized, FLAGS.output_file)
train_pyreader.reset()
......@@ -222,5 +223,11 @@ if __name__ == '__main__':
default=None,
type=str,
help="Evaluation file name, default to bbox.json and mask.json.")
parser.add_argument(
"-d",
"--dataset_dir",
default=None,
type=str,
help="Dataset path, same as DataFeed.dataset.dataset_dir")
FLAGS = parser.parse_args()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册