From 6e0df3102ec586279d3dbcb5fc528d656f372e84 Mon Sep 17 00:00:00 2001 From: hutuxian Date: Wed, 3 Jul 2019 09:35:55 +0800 Subject: [PATCH] Refactor for Pipeline Thread Check (#18459) move the thread-check code from train_from_dataset to a single function add UT for the thread check function --- python/paddle/fluid/executor.py | 37 ++++++++--------- .../fluid/tests/unittests/test_dataset.py | 6 +++ .../fluid/tests/unittests/test_pipeline.py | 40 +++++++++++++++++++ 3 files changed, 65 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 80719e9b39c..ec7404135be 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -764,6 +764,22 @@ class Executor(object): with open("fleet_desc.prototxt", "w") as fout: fout.write(str(program._fleet_opt["fleet_desc"])) + def _adjust_pipeline_resource(self, pipeline_opt, dataset, pipeline_num): + filelist_length = len(dataset.dataset.get_filelist()) + if filelist_length < pipeline_num: + pipeline_num = filelist_length + print( + "Pipeline training: setting the pipeline num to %d is enough because there are only %d files" + % (filelist_length, filelist_length)) + if filelist_length < pipeline_num * pipeline_opt["concurrency_list"][0]: + print( + "Pipeline training: setting the 1st element in concurrency_list to %d is enough because there are only %d files" + % (filelist_length // pipeline_num, filelist_length)) + pipeline_opt["concurrency_list"][ + 0] = filelist_length // pipeline_num + dataset.set_thread(pipeline_opt["concurrency_list"][0] * pipeline_num) + return pipeline_num + def _prepare_trainer(self, program=None, dataset=None, @@ -952,25 +968,10 @@ class Executor(object): if dataset == None: raise RuntimeError("dataset is need and should be initialized") - # Adjust the reader size for small file num if program._pipeline_opt: - dataset.set_thread(thread * - program._pipeline_opt["concurrency_list"][0]) - file_size = len(dataset.dataset.get_filelist()) - if file_size < thread: - thread = file_size - print( - "Pipeline: setting the pipeline num to %d is enough because there are only %d files" - % (file_size, file_size)) - if file_size < thread * program._pipeline_opt["concurrency_list"][ - 0]: - print( - "Pipeline: setting the 1st element in concurrency_list to %d is enough because there are only %d files" - % (file_size / thread, file_size)) - program._pipeline_opt["concurrency_list"][ - 0] = file_size / thread - dataset.set_thread( - program._pipeline_opt["concurrency_list"][0] * thread) + thread = self._adjust_pipeline_resource(program._pipeline_opt, + dataset, thread) + dataset._prepare_to_run() scope, trainer = self._prepare_trainer( program=program, diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py index 4cfd9915056..bce3c24dc81 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset.py +++ b/python/paddle/fluid/tests/unittests/test_dataset.py @@ -39,6 +39,12 @@ class TestDataset(unittest.TestCase): except: self.assertTrue(False) + try: + dataset = fluid.DatasetFactory().create_dataset( + "FileInstantDataset") + except: + self.assertTrue(False) + try: dataset = fluid.DatasetFactory().create_dataset("MyOwnDataset") self.assertTrue(False) diff --git a/python/paddle/fluid/tests/unittests/test_pipeline.py b/python/paddle/fluid/tests/unittests/test_pipeline.py index 97d63fe8f21..f6454b49076 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline.py @@ -21,6 +21,46 @@ import shutil import unittest +class TestPipelineConfig(unittest.TestCase): + """ TestCases for Config in Pipeline Training. """ + + def config(self, filelist_length, pipeline_num, reader_concurrency): + filelist = [] + for i in range(filelist_length): + filelist.append("file" + str(i)) + self.dataset.set_filelist(filelist) + self.pipeline_opt["concurrency_list"][0] = reader_concurrency + self.pipeline_num = pipeline_num + + def helper(self, in_filelist_length, in_pipeline_num, in_reader_concurrency, + out_pipeline_num, out_reader_concurrency, out_dataset_thread): + self.config(in_filelist_length, in_pipeline_num, in_reader_concurrency) + res = self.exe._adjust_pipeline_resource( + self.pipeline_opt, self.dataset, self.pipeline_num) + self.assertEqual(self.pipeline_opt["concurrency_list"][0], + out_reader_concurrency) + self.assertEqual(res, out_pipeline_num) + self.assertEqual(self.dataset.thread_num, out_dataset_thread) + + def test_adjust_pipeline_resource(self): + self.exe = fluid.Executor(fluid.CPUPlace()) + self.dataset = fluid.DatasetFactory().create_dataset( + "FileInstantDataset") + self.pipeline_opt = {"concurrency_list": [0, 1, 2]} + self.pipeline_num = 0 + + self.helper(7, 2, 2, 2, 2, 4) + self.helper(7, 2, 3, 2, 3, 6) + self.helper(7, 2, 4, 2, 3, 6) + + self.helper(8, 2, 3, 2, 3, 6) + self.helper(8, 2, 4, 2, 4, 8) + self.helper(8, 2, 5, 2, 4, 8) + + self.helper(3, 4, 1, 3, 1, 3) + self.helper(3, 4, 2, 3, 1, 3) + + class TestPipeline(unittest.TestCase): """ TestCases for Pipeline Training. """ -- GitLab