未验证 提交 a07b6262 编写于 作者: L lilong12 提交者: GitHub

hidden the dataset call of pipeline to train_from_dataset (#25834)

* hidden the explicit setting of dataset for pipeline training.
上级 f132c2f4
...@@ -1334,14 +1334,25 @@ class Executor(object): ...@@ -1334,14 +1334,25 @@ class Executor(object):
fetch_info=None, fetch_info=None,
print_period=100, print_period=100,
fetch_handler=None): fetch_handler=None):
if dataset is None: if program._pipeline_opt is not None:
raise RuntimeError("dataset is need and should be initialized") import paddle
if dataset is not None:
if program._pipeline_opt is not None and program._pipeline_opt[ raise RuntimeError("dataset should be None for pipeline mode")
"sync_steps"] != -1: # The following fake dataset is created to call
# hack for paddlebox: sync_steps(-1) denotes paddlebox # the _prepare_trainer api, and it is meaningless.
thread = self._adjust_pipeline_resource(program._pipeline_opt, data_vars = []
dataset, thread) for var in program.global_block().vars.values():
if var.is_data:
data_vars.append(var)
dataset = paddle.fluid.DatasetFactory().create_dataset(
'FileInstantDataset')
dataset.set_batch_size(1)
dataset.set_thread(1)
dataset.set_filelist(['None'])
dataset.set_use_var(data_vars)
else:
if dataset is None:
raise RuntimeError("dataset is need and should be initialized")
dataset._prepare_to_run() dataset._prepare_to_run()
......
...@@ -186,18 +186,10 @@ class TestPipeline(unittest.TestCase): ...@@ -186,18 +186,10 @@ class TestPipeline(unittest.TestCase):
data_loader.set_sample_generator(train_reader, batch_size=1) data_loader.set_sample_generator(train_reader, batch_size=1)
place = fluid.CPUPlace() place = fluid.CPUPlace()
# The following dataset is only used for the
# interface 'train_from_dataset'.
# And it has no actual meaning.
dataset = fluid.DatasetFactory().create_dataset('FileInstantDataset')
dataset.set_batch_size(1)
dataset.set_thread(1)
dataset.set_filelist(['/tmp/tmp_2.txt'])
dataset.set_use_var([image, label])
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
data_loader.start() data_loader.start()
exe.train_from_dataset(main_prog, dataset, debug=debug) exe.train_from_dataset(main_prog, debug=debug)
def test_pipeline(self): def test_pipeline(self):
self._run(False) self._run(False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册