未验证 提交 b4d00c85 编写于 作者: W wangguanzhong 提交者: GitHub

add dataset_dir in args (#2956)

上级 493daa9b
...@@ -42,7 +42,7 @@ __all__ = [ ...@@ -42,7 +42,7 @@ __all__ = [
] ]
def create_reader(feed, max_iter=0): def create_reader(feed, max_iter=0, args_path=None):
""" """
Return iterable data reader. Return iterable data reader.
...@@ -52,11 +52,11 @@ def create_reader(feed, max_iter=0): ...@@ -52,11 +52,11 @@ def create_reader(feed, max_iter=0):
# if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory # if `DATASET_DIR` does not exists, search ~/.paddle/dataset for a directory
# named `DATASET_DIR` (e.g., coco, pascal), if not present either, download # named `DATASET_DIR` (e.g., coco, pascal), if not present either, download
if feed.dataset.dataset_dir: dataset_home = args_path if args_path else feed.dataset.dataset_dir
if dataset_home:
annotation = getattr(feed.dataset, 'annotation', None) annotation = getattr(feed.dataset, 'annotation', None)
image_dir = getattr(feed.dataset, 'image_dir', None) image_dir = getattr(feed.dataset, 'image_dir', None)
dataset_dir = get_dataset_path(feed.dataset.dataset_dir, dataset_dir = get_dataset_path(dataset_home, annotation, image_dir)
annotation, image_dir)
if annotation: if annotation:
feed.dataset.annotation = os.path.join(dataset_dir, annotation) feed.dataset.annotation = os.path.join(dataset_dir, annotation)
if image_dir: if image_dir:
......
...@@ -75,7 +75,7 @@ def main(): ...@@ -75,7 +75,7 @@ def main():
fetches = model.eval(feed_vars) fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True) eval_prog = eval_prog.clone(True)
reader = create_reader(eval_feed) reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
pyreader.decorate_sample_list_generator(reader, place) pyreader.decorate_sample_list_generator(reader, place)
# compile program for multi-devices # compile program for multi-devices
...@@ -114,8 +114,8 @@ def main(): ...@@ -114,8 +114,8 @@ def main():
resolution = None resolution = None
if 'mask' in results[0]: if 'mask' in results[0]:
resolution = model.mask_head.resolution resolution = model.mask_head.resolution
eval_results(results, eval_feed, cfg.metric, cfg.num_classes, eval_results(results, eval_feed, cfg.metric, cfg.num_classes, resolution,
resolution, is_bbox_normalized, FLAGS.output_file) is_bbox_normalized, FLAGS.output_file)
if __name__ == '__main__': if __name__ == '__main__':
...@@ -126,5 +126,11 @@ if __name__ == '__main__': ...@@ -126,5 +126,11 @@ if __name__ == '__main__':
default=None, default=None,
type=str, type=str,
help="Evaluation file name, default to bbox.json and mask.json.") help="Evaluation file name, default to bbox.json and mask.json.")
parser.add_argument(
"-d",
"--dataset_dir",
default=None,
type=str,
help="Dataset path, same as DataFeed.dataset.dataset_dir")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
...@@ -105,7 +105,8 @@ def main(): ...@@ -105,7 +105,8 @@ def main():
optimizer = optim_builder(lr) optimizer = optim_builder(lr)
optimizer.minimize(loss) optimizer.minimize(loss)
train_reader = create_reader(train_feed, cfg.max_iters * devices_num) train_reader = create_reader(train_feed, cfg.max_iters * devices_num,
FLAGS.dataset_dir)
train_pyreader.decorate_sample_list_generator(train_reader, place) train_pyreader.decorate_sample_list_generator(train_reader, place)
# parse train fetches # parse train fetches
...@@ -121,7 +122,7 @@ def main(): ...@@ -121,7 +122,7 @@ def main():
fetches = model.eval(feed_vars) fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True) eval_prog = eval_prog.clone(True)
eval_reader = create_reader(eval_feed) eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
eval_pyreader.decorate_sample_list_generator(eval_reader, place) eval_pyreader.decorate_sample_list_generator(eval_reader, place)
# parse eval fetches # parse eval fetches
...@@ -197,7 +198,7 @@ def main(): ...@@ -197,7 +198,7 @@ def main():
resolution = None resolution = None
if 'mask' in results[0]: if 'mask' in results[0]:
resolution = model.mask_head.resolution resolution = model.mask_head.resolution
eval_results(results, eval_feed, cfg.metric, cfg.num_classes, eval_results(results, eval_feed, cfg.metric, cfg.num_classes,
resolution, is_bbox_normalized, FLAGS.output_file) resolution, is_bbox_normalized, FLAGS.output_file)
train_pyreader.reset() train_pyreader.reset()
...@@ -222,5 +223,11 @@ if __name__ == '__main__': ...@@ -222,5 +223,11 @@ if __name__ == '__main__':
default=None, default=None,
type=str, type=str,
help="Evaluation file name, default to bbox.json and mask.json.") help="Evaluation file name, default to bbox.json and mask.json.")
parser.add_argument(
"-d",
"--dataset_dir",
default=None,
type=str,
help="Dataset path, same as DataFeed.dataset.dataset_dir")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册