From 6e0df3102ec586279d3dbcb5fc528d656f372e84 Mon Sep 17 00:00:00 2001
From: hutuxian <hutuxian2011@sina.cn>
Date: Wed, 3 Jul 2019 09:35:55 +0800
Subject: [PATCH] Refactor for Pipeline Thread Check (#18459)

Move the thread-check code from train_from_dataset into a single helper function.
Add a UT for the thread-check function.
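
For context, below is a minimal standalone sketch of the clamping behavior the
new helper implements. The plain-value function here is illustrative only: the
real _adjust_pipeline_resource in executor.py takes the pipeline_opt dict and
the Dataset object, and the example numbers come from one of the new test cases.

    # Illustrative sketch of the pipeline resource clamping (not the executor API).
    def adjust_pipeline_resource(filelist_length, pipeline_num, reader_concurrency):
        # Each pipeline instance needs at least one file, so cap the pipeline count.
        if filelist_length < pipeline_num:
            pipeline_num = filelist_length
        # The reader stage cannot usefully run more threads than files per pipeline.
        if filelist_length < pipeline_num * reader_concurrency:
            reader_concurrency = filelist_length // pipeline_num
        # Total dataset threads = reader concurrency per pipeline * pipeline count.
        dataset_thread = reader_concurrency * pipeline_num
        return pipeline_num, reader_concurrency, dataset_thread

    # Mirrors a case in test_pipeline.py: 7 files, 2 pipelines, concurrency 4
    # is clamped to concurrency 3, giving 6 dataset threads in total.
    assert adjust_pipeline_resource(7, 2, 4) == (2, 3, 6)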
---
 python/paddle/fluid/executor.py               | 37 ++++++++---------
 .../fluid/tests/unittests/test_dataset.py     |  6 +++
 .../fluid/tests/unittests/test_pipeline.py    | 40 +++++++++++++++++++
 3 files changed, 65 insertions(+), 18 deletions(-)

diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index 80719e9b39..ec7404135b 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -764,6 +764,22 @@ class Executor(object):
             with open("fleet_desc.prototxt", "w") as fout:
                 fout.write(str(program._fleet_opt["fleet_desc"]))
 
+    def _adjust_pipeline_resource(self, pipeline_opt, dataset, pipeline_num):
+        filelist_length = len(dataset.dataset.get_filelist())
+        if filelist_length < pipeline_num:
+            pipeline_num = filelist_length
+            print(
+                "Pipeline training: setting the pipeline num to %d is enough because there are only %d files"
+                % (filelist_length, filelist_length))
+        if filelist_length < pipeline_num * pipeline_opt["concurrency_list"][0]:
+            print(
+                "Pipeline training: setting the 1st element in concurrency_list to %d is enough because there are only %d files"
+                % (filelist_length // pipeline_num, filelist_length))
+            pipeline_opt["concurrency_list"][
+                0] = filelist_length // pipeline_num
+        dataset.set_thread(pipeline_opt["concurrency_list"][0] * pipeline_num)
+        return pipeline_num
+
     def _prepare_trainer(self,
                          program=None,
                          dataset=None,
@@ -952,25 +968,10 @@ class Executor(object):
         if dataset == None:
             raise RuntimeError("dataset is need and should be initialized")
 
-        # Adjust the reader size for small file num
         if program._pipeline_opt:
-            dataset.set_thread(thread *
-                               program._pipeline_opt["concurrency_list"][0])
-            file_size = len(dataset.dataset.get_filelist())
-            if file_size < thread:
-                thread = file_size
-                print(
-                    "Pipeline: setting the pipeline num to %d is enough because there are only %d files"
-                    % (file_size, file_size))
-            if file_size < thread * program._pipeline_opt["concurrency_list"][
-                    0]:
-                print(
-                    "Pipeline: setting the 1st element in concurrency_list to %d is enough because there are only %d files"
-                    % (file_size / thread, file_size))
-                program._pipeline_opt["concurrency_list"][
-                    0] = file_size / thread
-                dataset.set_thread(
-                    program._pipeline_opt["concurrency_list"][0] * thread)
+            thread = self._adjust_pipeline_resource(program._pipeline_opt,
+                                                    dataset, thread)
+
         dataset._prepare_to_run()
         scope, trainer = self._prepare_trainer(
             program=program,
diff --git a/python/paddle/fluid/tests/unittests/test_dataset.py b/python/paddle/fluid/tests/unittests/test_dataset.py
index 4cfd991505..bce3c24dc8 100644
--- a/python/paddle/fluid/tests/unittests/test_dataset.py
+++ b/python/paddle/fluid/tests/unittests/test_dataset.py
@@ -39,6 +39,12 @@ class TestDataset(unittest.TestCase):
         except:
             self.assertTrue(False)
 
+        try:
+            dataset = fluid.DatasetFactory().create_dataset(
+                "FileInstantDataset")
+        except:
+            self.assertTrue(False)
+
         try:
             dataset = fluid.DatasetFactory().create_dataset("MyOwnDataset")
             self.assertTrue(False)
diff --git a/python/paddle/fluid/tests/unittests/test_pipeline.py b/python/paddle/fluid/tests/unittests/test_pipeline.py
index 97d63fe8f2..f6454b4907 100644
--- a/python/paddle/fluid/tests/unittests/test_pipeline.py
+++ b/python/paddle/fluid/tests/unittests/test_pipeline.py
@@ -21,6 +21,46 @@ import shutil
 import unittest
 
 
+class TestPipelineConfig(unittest.TestCase):
+    """  TestCases for Config in Pipeline Training. """
+
+    def config(self, filelist_length, pipeline_num, reader_concurrency):
+        filelist = []
+        for i in range(filelist_length):
+            filelist.append("file" + str(i))
+        self.dataset.set_filelist(filelist)
+        self.pipeline_opt["concurrency_list"][0] = reader_concurrency
+        self.pipeline_num = pipeline_num
+
+    def helper(self, in_filelist_length, in_pipeline_num, in_reader_concurrency,
+               out_pipeline_num, out_reader_concurrency, out_dataset_thread):
+        self.config(in_filelist_length, in_pipeline_num, in_reader_concurrency)
+        res = self.exe._adjust_pipeline_resource(
+            self.pipeline_opt, self.dataset, self.pipeline_num)
+        self.assertEqual(self.pipeline_opt["concurrency_list"][0],
+                         out_reader_concurrency)
+        self.assertEqual(res, out_pipeline_num)
+        self.assertEqual(self.dataset.thread_num, out_dataset_thread)
+
+    def test_adjust_pipeline_resource(self):
+        self.exe = fluid.Executor(fluid.CPUPlace())
+        self.dataset = fluid.DatasetFactory().create_dataset(
+            "FileInstantDataset")
+        self.pipeline_opt = {"concurrency_list": [0, 1, 2]}
+        self.pipeline_num = 0
+
+        self.helper(7, 2, 2, 2, 2, 4)
+        self.helper(7, 2, 3, 2, 3, 6)
+        self.helper(7, 2, 4, 2, 3, 6)
+
+        self.helper(8, 2, 3, 2, 3, 6)
+        self.helper(8, 2, 4, 2, 4, 8)
+        self.helper(8, 2, 5, 2, 4, 8)
+
+        self.helper(3, 4, 1, 3, 1, 3)
+        self.helper(3, 4, 2, 3, 1, 3)
+
+
 class TestPipeline(unittest.TestCase):
     """  TestCases for Pipeline Training. """
 
-- 
GitLab