未验证 提交 6e0df310 编写于 作者: H hutuxian 提交者: GitHub

Refactor for Pipeline Thread Check (#18459)

move the thread-check code from train_from_dataset to a single function
add UT for the thread check function
上级 8f5fffca
......@@ -764,6 +764,22 @@ class Executor(object):
with open("fleet_desc.prototxt", "w") as fout:
fout.write(str(program._fleet_opt["fleet_desc"]))
def _adjust_pipeline_resource(self, pipeline_opt, dataset, pipeline_num):
filelist_length = len(dataset.dataset.get_filelist())
if filelist_length < pipeline_num:
pipeline_num = filelist_length
print(
"Pipeline training: setting the pipeline num to %d is enough because there are only %d files"
% (filelist_length, filelist_length))
if filelist_length < pipeline_num * pipeline_opt["concurrency_list"][0]:
print(
"Pipeline training: setting the 1st element in concurrency_list to %d is enough because there are only %d files"
% (filelist_length // pipeline_num, filelist_length))
pipeline_opt["concurrency_list"][
0] = filelist_length // pipeline_num
dataset.set_thread(pipeline_opt["concurrency_list"][0] * pipeline_num)
return pipeline_num
def _prepare_trainer(self,
program=None,
dataset=None,
......@@ -952,25 +968,10 @@ class Executor(object):
if dataset == None:
raise RuntimeError("dataset is need and should be initialized")
# Adjust the reader size for small file num
if program._pipeline_opt:
dataset.set_thread(thread *
program._pipeline_opt["concurrency_list"][0])
file_size = len(dataset.dataset.get_filelist())
if file_size < thread:
thread = file_size
print(
"Pipeline: setting the pipeline num to %d is enough because there are only %d files"
% (file_size, file_size))
if file_size < thread * program._pipeline_opt["concurrency_list"][
0]:
print(
"Pipeline: setting the 1st element in concurrency_list to %d is enough because there are only %d files"
% (file_size / thread, file_size))
program._pipeline_opt["concurrency_list"][
0] = file_size / thread
dataset.set_thread(
program._pipeline_opt["concurrency_list"][0] * thread)
thread = self._adjust_pipeline_resource(program._pipeline_opt,
dataset, thread)
dataset._prepare_to_run()
scope, trainer = self._prepare_trainer(
program=program,
......
......@@ -39,6 +39,12 @@ class TestDataset(unittest.TestCase):
except:
self.assertTrue(False)
try:
dataset = fluid.DatasetFactory().create_dataset(
"FileInstantDataset")
except:
self.assertTrue(False)
try:
dataset = fluid.DatasetFactory().create_dataset("MyOwnDataset")
self.assertTrue(False)
......
......@@ -21,6 +21,46 @@ import shutil
import unittest
class TestPipelineConfig(unittest.TestCase):
    """TestCases for Config in Pipeline Training."""

    def config(self, filelist_length, pipeline_num, reader_concurrency):
        # Install a synthetic filelist plus the requested resource settings.
        self.dataset.set_filelist(
            ["file" + str(i) for i in range(filelist_length)])
        self.pipeline_opt["concurrency_list"][0] = reader_concurrency
        self.pipeline_num = pipeline_num

    def helper(self, in_filelist_length, in_pipeline_num, in_reader_concurrency,
               out_pipeline_num, out_reader_concurrency, out_dataset_thread):
        # Feed one input configuration through the adjuster and check that
        # the pipeline num, reader concurrency and dataset thread num all
        # come out as expected.
        self.config(in_filelist_length, in_pipeline_num, in_reader_concurrency)
        adjusted_num = self.exe._adjust_pipeline_resource(
            self.pipeline_opt, self.dataset, self.pipeline_num)
        self.assertEqual(self.pipeline_opt["concurrency_list"][0],
                         out_reader_concurrency)
        self.assertEqual(adjusted_num, out_pipeline_num)
        self.assertEqual(self.dataset.thread_num, out_dataset_thread)

    def test_adjust_pipeline_resource(self):
        self.exe = fluid.Executor(fluid.CPUPlace())
        self.dataset = fluid.DatasetFactory().create_dataset(
            "FileInstantDataset")
        self.pipeline_opt = {"concurrency_list": [0, 1, 2]}
        self.pipeline_num = 0
        # (files, pipelines in, concurrency in,
        #  pipelines out, concurrency out, dataset threads out)
        cases = [
            (7, 2, 2, 2, 2, 4),
            (7, 2, 3, 2, 3, 6),
            (7, 2, 4, 2, 3, 6),
            (8, 2, 3, 2, 3, 6),
            (8, 2, 4, 2, 4, 8),
            (8, 2, 5, 2, 4, 8),
            (3, 4, 1, 3, 1, 3),
            (3, 4, 2, 3, 1, 3),
        ]
        for case in cases:
            self.helper(*case)
class TestPipeline(unittest.TestCase):
""" TestCases for Pipeline Training. """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册