提交 fe350028 编写于 作者: L Liufang Sang 提交者: whs

updata reader api test=develop (#3534)

上级 e7684a4d
......@@ -132,7 +132,7 @@ def main():
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
eval_pyreader, test_feed_vars = create_feed(eval_feed, use_pyreader=False)
_, test_feed_vars = create_feed(eval_feed, iterable=True)
eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
#eval_pyreader.decorate_sample_list_generator(eval_reader, place)
......
......@@ -47,7 +47,7 @@ from ppdet.data.data_feed import create_reader
from ppdet.utils.eval_utils import parse_fetches, eval_results
from ppdet.utils.stats import TrainingStats
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu
from ppdet.utils.check import check_gpu, check_version
import ppdet.utils.checkpoint as checkpoint
from ppdet.modeling.model_input import create_feed
......@@ -118,6 +118,7 @@ def main():
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu(cfg.use_gpu)
#check_version()
if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count()
......@@ -147,7 +148,7 @@ def main():
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
train_pyreader, feed_vars = create_feed(train_feed)
train_loader, feed_vars = create_feed(train_feed, iterable=True)
train_fetches = model.train(feed_vars)
loss = train_fetches['loss']
lr = lr_builder()
......@@ -157,7 +158,7 @@ def main():
train_reader = create_reader(train_feed, cfg.max_iters * devices_num,
FLAGS.dataset_dir)
train_pyreader.decorate_sample_list_generator(train_reader, place)
train_loader.set_sample_list_generator(train_reader, place)
# parse train fetches
train_keys, train_values, _ = parse_fetches(train_fetches)
......@@ -172,7 +173,7 @@ def main():
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
model = create(main_arch)
eval_pyreader, test_feed_vars = create_feed(eval_feed, use_pyreader=False)
_, test_feed_vars = create_feed(eval_feed, iterable=True)
fetches = model.eval(test_feed_vars)
eval_prog = eval_prog.clone(True)
......@@ -231,8 +232,8 @@ def main():
place,
fluid.global_scope(),
train_prog,
train_reader=train_pyreader,
train_feed_list=None,
train_reader=train_reader,
train_feed_list=[(key, value.name) for key, value in feed_vars.items()],
train_fetch_list=train_fetch_list,
eval_program=eval_prog,
eval_reader=eval_reader,
......
......@@ -132,7 +132,7 @@ def main():
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
eval_pyreader, test_feed_vars = create_feed(eval_feed, use_pyreader=False)
_, test_feed_vars = create_feed(eval_feed, iterable=True)
eval_reader = create_reader(eval_feed, args_path=FLAGS.dataset_dir)
#eval_pyreader.decorate_sample_list_generator(eval_reader, place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册