未验证 提交 1478a5fc 编写于 作者: Y yuyang18

Make open_files use buffer

上级 dc34effd
...@@ -16,7 +16,7 @@ function(reader_library TARGET_NAME) ...@@ -16,7 +16,7 @@ function(reader_library TARGET_NAME)
endfunction() endfunction()
cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool) cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
reader_library(open_files_op SRCS open_files_op.cc) reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "ThreadPool.h" #include "ThreadPool.h"
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle { namespace paddle {
...@@ -232,12 +233,17 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -232,12 +233,17 @@ class OpenFilesOp : public framework::OperatorBase {
container.reset(new OrderedReaderContainer()); container.reset(new OrderedReaderContainer());
} else { } else {
container.reset(new PreemptiveReaderContainer( container.reset(new PreemptiveReaderContainer(
std::min(file_names.size(), static_cast<size_t>(Attr<int>("thread_num"))));
static_cast<size_t>(std::thread::hardware_concurrency()))));
} }
out->Reset( auto reader =
std::make_shared<MultiFileReader>(file_names, std::move(container))); std::make_shared<MultiFileReader>(file_names, std::move(container));
auto buffer_size = Attr<int>("buffer_size");
if (buffer_size > 1) {
reader = framework::MakeDecoratedReader<BufferedReader>(
reader, platform::CPUPlace(), buffer_size);
}
out->Reset(reader);
} }
}; };
...@@ -253,6 +259,8 @@ class OpenFilesOpMaker : public FileReaderMakerBase { ...@@ -253,6 +259,8 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
An OpenFilesOp creates a MultiFileReader, which is able to An OpenFilesOp creates a MultiFileReader, which is able to
read data multi-threaded from multiple files. read data multi-threaded from multiple files.
)DOC"); )DOC");
AddAttr<int>("thread_num", "Number of thread to read files.");
AddAttr<int>("buffer_size", "The reading buffer of these files.");
} }
}; };
......
...@@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper ...@@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper
from ..executor import global_scope from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc from layer_function_generator import generate_layer_fn, templatedoc
import sys import sys
import multiprocessing
__all__ = [ __all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv', 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
...@@ -549,10 +550,9 @@ def open_files(filenames, ...@@ -549,10 +550,9 @@ def open_files(filenames,
shapes(list): List of tuples which declaring data shapes. shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level. lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type. dtypes(list): List of strs which declaring data type.
thread_num(None): Deprecated argument. It will be set by open_files thread_num(None): The number of thread to read files.
automatically. Default: min(len(filenames), cpu_number).
buffer_size(None): Deprecated argument. It will be set by open_files buffer_size(None): The buffer size of reader. Default: 3 * thread_num
automatically.
pass_num(int): Number of passes to run. pass_num(int): Number of passes to run.
is_test(bool|None): Whether `open_files` used for testing or not. If it is_test(bool|None): Whether `open_files` used for testing or not. If it
is used for testing, the order of data generated is same as the file is used for testing, the order of data generated is same as the file
...@@ -574,14 +574,15 @@ def open_files(filenames, ...@@ -574,14 +574,15 @@ def open_files(filenames,
# Via the reader, we can use 'read_file' layer to get data: # Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader) image, label = fluid.layers.io.read_file(reader)
""" """
if thread_num is not None: if thread_num is None:
print >> sys.stderr, "thread_num parameter of open_files is " \ thread_num = min(len(filenames), multiprocessing.cpu_count())
"deprecated. It will be ignored and set " \ else:
"automatically by open_files " thread_num = int(thread_num)
if buffer_size is not None:
print >> sys.stderr, "buffer_size parameter of open_files is " \ if buffer_size is None:
"deprecated. It will be ignored and set " \ buffer_size = 3 * thread_num
"automatically by open_files " else:
buffer_size = int(buffer_size)
if isinstance(filenames, basestring): if isinstance(filenames, basestring):
filenames = [filenames] filenames = [filenames]
...@@ -600,7 +601,9 @@ def open_files(filenames, ...@@ -600,7 +601,9 @@ def open_files(filenames,
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
'ranks': ranks, 'ranks': ranks,
'file_names': filenames 'file_names': filenames,
'thread_num': thread_num,
'buffer_size': buffer_size
} }
if is_test is not None: if is_test is not None:
attrs['is_test'] = is_test attrs['is_test'] = is_test
......
...@@ -155,7 +155,7 @@ class TestDataBalance(unittest.TestCase): ...@@ -155,7 +155,7 @@ class TestDataBalance(unittest.TestCase):
main_program=main_prog, main_program=main_prog,
build_strategy=build_strategy) build_strategy=build_strategy)
if (parallel_exe.device_count > self.batch_size): if parallel_exe.device_count > self.batch_size:
print("WARNING: Unittest TestDataBalance skipped. \ print("WARNING: Unittest TestDataBalance skipped. \
For the result is not correct when device count \ For the result is not correct when device count \
is larger than batch size.") is larger than batch size.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册