diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt
index c6df2646c4c83eea804e507c00ec2fc34a269e76..728197377df04df8c993a48bc282431473fe9959 100644
--- a/paddle/fluid/operators/reader/CMakeLists.txt
+++ b/paddle/fluid/operators/reader/CMakeLists.txt
@@ -16,7 +16,7 @@ function(reader_library TARGET_NAME)
 endfunction()
 
 cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
-reader_library(open_files_op SRCS open_files_op.cc)
+reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader)
 reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
 reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
 reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
diff --git a/paddle/fluid/operators/reader/open_files_op.cc b/paddle/fluid/operators/reader/open_files_op.cc
index daeacdb8b4d540e94d1c03bd33b0bbdb024f958d..74cda3ad980a72d68b2faf81f97aae95c72cebfb 100644
--- a/paddle/fluid/operators/reader/open_files_op.cc
+++ b/paddle/fluid/operators/reader/open_files_op.cc
@@ -18,6 +18,7 @@
 #include "ThreadPool.h"
 #include "paddle/fluid/framework/blocking_queue.h"
 #include "paddle/fluid/operators/reader/blocking_queue.h"
+#include "paddle/fluid/operators/reader/buffered_reader.h"
 #include "paddle/fluid/operators/reader/reader_op_registry.h"
 
 namespace paddle {
@@ -232,12 +233,17 @@ class OpenFilesOp : public framework::OperatorBase {
       container.reset(new OrderedReaderContainer());
     } else {
       container.reset(new PreemptiveReaderContainer(
-          std::min(file_names.size(),
-                   static_cast<size_t>(std::thread::hardware_concurrency()))));
+          static_cast<size_t>(Attr<int>("thread_num"))));
     }
 
-    out->Reset(
-        std::make_shared<MultiFileReader>(file_names, std::move(container)));
+    auto reader =
+        std::make_shared<MultiFileReader>(file_names, std::move(container));
+    auto buffer_size = Attr<int>("buffer_size");
+    if (buffer_size > 1) {
+      reader = framework::MakeDecoratedReader<BufferedReader>(
+          reader, platform::CPUPlace(), buffer_size);
+    }
+    out->Reset(reader);
   }
 };
 
@@ -253,6 +259,8 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
       An OpenFilesOp creates a MultiFileReader, which is able to
       read data multi-threaded from multiple files.
     )DOC");
+    AddAttr<int>("thread_num", "The number of threads to read files.");
+    AddAttr<int>("buffer_size", "The reading buffer size of these files.");
   }
 };
 
diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py
index 9133038de237a0a9475c8667c5b9265c2d60e782..bfcf5ee7d87824e2c68da8ec935da8d77e7bab12 100644
--- a/python/paddle/fluid/layers/io.py
+++ b/python/paddle/fluid/layers/io.py
@@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper
 from ..executor import global_scope
 from layer_function_generator import generate_layer_fn, templatedoc
 import sys
+import multiprocessing
 
 __all__ = [
     'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
@@ -549,10 +550,9 @@ def open_files(filenames,
         shapes(list): List of tuples which declaring data shapes.
         lod_levels(list): List of ints which declaring data lod_level.
         dtypes(list): List of strs which declaring data type.
-        thread_num(None): Deprecated argument. It will be set by open_files
-            automatically.
-        buffer_size(None): Deprecated argument. It will be set by open_files
-            automatically.
+        thread_num(None): The number of threads to read files.
+            Default: min(len(filenames), cpu_number).
+        buffer_size(None): The buffer size of the reader. Default: 3 * thread_num
         pass_num(int): Number of passes to run.
         is_test(bool|None): Whether `open_files` used for testing or not. If it is
             used for testing, the order of data generated is same as the file
@@ -574,14 +574,15 @@ def open_files(filenames,
         # Via the reader, we can use 'read_file' layer to get data:
         image, label = fluid.layers.io.read_file(reader)
     """
-    if thread_num is not None:
-        print >> sys.stderr, "thread_num parameter of open_files is " \
-                             "deprecated. It will be ignored and set " \
-                             "automatically by open_files "
-    if buffer_size is not None:
-        print >> sys.stderr, "buffer_size parameter of open_files is " \
-                             "deprecated. It will be ignored and set " \
-                             "automatically by open_files "
+    if thread_num is None:
+        thread_num = min(len(filenames), multiprocessing.cpu_count())
+    else:
+        thread_num = int(thread_num)
+
+    if buffer_size is None:
+        buffer_size = 3 * thread_num
+    else:
+        buffer_size = int(buffer_size)
 
     if isinstance(filenames, basestring):
         filenames = [filenames]
@@ -600,7 +601,9 @@ def open_files(filenames,
         'shape_concat': shape_concat,
         'lod_levels': lod_levels,
         'ranks': ranks,
-        'file_names': filenames
+        'file_names': filenames,
+        'thread_num': thread_num,
+        'buffer_size': buffer_size
     }
     if is_test is not None:
         attrs['is_test'] = is_test
diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py
index 7fa4abc611621adcee1f2e32051efb19e53c4f04..aa09b0ea445adccae3f741b53850f8182f3270cc 100644
--- a/python/paddle/fluid/tests/unittests/test_data_balance.py
+++ b/python/paddle/fluid/tests/unittests/test_data_balance.py
@@ -155,7 +155,7 @@ class TestDataBalance(unittest.TestCase):
             main_program=main_prog,
             build_strategy=build_strategy)
 
-        if (parallel_exe.device_count > self.batch_size):
+        if parallel_exe.device_count > self.batch_size:
             print("WARNING: Unittest TestDataBalance skipped. \
                 For the result is not correct when device count \
                 is larger than batch size.")
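
For reference, a minimal usage sketch of the updated Python API after this patch. The file names, shapes, and dtypes below are made up for illustration; the only behavior taken from the patch is that thread_num and buffer_size can now be passed explicitly and otherwise default to min(len(filenames), cpu_count) and 3 * thread_num.

    import paddle.fluid as fluid

    # Open two recordio files with explicit reader settings. Omitting
    # thread_num/buffer_size would fall back to the defaults described above.
    # The file paths and shapes here are hypothetical.
    reader = fluid.layers.io.open_files(
        filenames=['./data1.recordio', './data2.recordio'],
        shapes=[(-1, 784), (-1, 1)],
        lod_levels=[0, 0],
        dtypes=['float32', 'int64'],
        thread_num=2,
        buffer_size=6)

    # Via the reader, the 'read_file' layer yields the declared tensors.
    image, label = fluid.layers.io.read_file(reader)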