From a07b62623ed63e8a191de410e2e54829771cbd22 Mon Sep 17 00:00:00 2001
From: lilong12
Date: Fri, 31 Jul 2020 12:33:39 +0800
Subject: [PATCH] hidden the dataset call of pipeline to train_from_dataset
 (#25834)

* hidden the explicit setting of dataset for pipeline training.
---
 python/paddle/fluid/executor.py                  | 27 +++++++++++++------
 .../fluid/tests/unittests/test_pipeline.py       | 10 +------
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index 23ab436c06..a8829a42f0 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1334,14 +1334,25 @@ class Executor(object):
                            fetch_info=None,
                            print_period=100,
                            fetch_handler=None):
-        if dataset is None:
-            raise RuntimeError("dataset is need and should be initialized")
-
-        if program._pipeline_opt is not None and program._pipeline_opt[
-                "sync_steps"] != -1:
-            # hack for paddlebox: sync_steps(-1) denotes paddlebox
-            thread = self._adjust_pipeline_resource(program._pipeline_opt,
-                                                    dataset, thread)
+        if program._pipeline_opt is not None:
+            import paddle
+            if dataset is not None:
+                raise RuntimeError("dataset should be None for pipeline mode")
+            # The following fake dataset is created to call
+            # the _prepare_trainer api, and it is meaningless.
+            data_vars = []
+            for var in program.global_block().vars.values():
+                if var.is_data:
+                    data_vars.append(var)
+            dataset = paddle.fluid.DatasetFactory().create_dataset(
+                'FileInstantDataset')
+            dataset.set_batch_size(1)
+            dataset.set_thread(1)
+            dataset.set_filelist(['None'])
+            dataset.set_use_var(data_vars)
+        else:
+            if dataset is None:
+                raise RuntimeError("dataset is need and should be initialized")
 
         dataset._prepare_to_run()
 
diff --git a/python/paddle/fluid/tests/unittests/test_pipeline.py b/python/paddle/fluid/tests/unittests/test_pipeline.py
index 1f884195a4..fe31add697 100644
--- a/python/paddle/fluid/tests/unittests/test_pipeline.py
+++ b/python/paddle/fluid/tests/unittests/test_pipeline.py
@@ -186,18 +186,10 @@ class TestPipeline(unittest.TestCase):
         data_loader.set_sample_generator(train_reader, batch_size=1)
 
         place = fluid.CPUPlace()
-        # The following dataset is only used for the
-        # interface 'train_from_dataset'.
-        # And it has no actual meaning.
-        dataset = fluid.DatasetFactory().create_dataset('FileInstantDataset')
-        dataset.set_batch_size(1)
-        dataset.set_thread(1)
-        dataset.set_filelist(['/tmp/tmp_2.txt'])
-        dataset.set_use_var([image, label])
         exe = fluid.Executor(place)
         exe.run(startup_prog)
         data_loader.start()
-        exe.train_from_dataset(main_prog, dataset, debug=debug)
+        exe.train_from_dataset(main_prog, debug=debug)
 
     def test_pipeline(self):
         self._run(False)
-- 
GitLab
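
For context, a minimal sketch (not part of the patch) of the user-visible effect of this change, based on the updated test above; main_prog, startup_prog, data_loader, and debug are assumed to be set up as in test_pipeline.py, and the helper name run_pipeline is illustrative only:

import paddle.fluid as fluid

def run_pipeline(main_prog, startup_prog, data_loader, debug=False):
    # Assumes main_prog was built with the pipeline optimizer and data_loader
    # feeds the pipeline, as in test_pipeline.py above.
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)
    data_loader.start()
    # Before this patch, the caller had to construct a throwaway
    # FileInstantDataset and pass it in; after this patch the executor builds
    # that fake dataset internally, so no dataset argument is passed here.
    exe.train_from_dataset(main_prog, debug=debug)

The design choice is simply to move the meaningless FileInstantDataset boilerplate out of every pipeline caller and into Executor.train_from_dataset itself, which now rejects an explicitly passed dataset when program._pipeline_opt is set.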