Make open_files use buffer

Parent dc34effd
......@@ -16,7 +16,7 @@ function(reader_library TARGET_NAME)
endfunction()
cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
reader_library(open_files_op SRCS open_files_op.cc)
reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
......
......@@ -18,6 +18,7 @@
#include "ThreadPool.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
......@@ -232,12 +233,17 @@ class OpenFilesOp : public framework::OperatorBase {
container.reset(new OrderedReaderContainer());
} else {
container.reset(new PreemptiveReaderContainer(
std::min(file_names.size(),
static_cast<size_t>(std::thread::hardware_concurrency()))));
static_cast<size_t>(Attr<int>("thread_num"))));
}
out->Reset(
std::make_shared<MultiFileReader>(file_names, std::move(container)));
auto reader =
std::make_shared<MultiFileReader>(file_names, std::move(container));
auto buffer_size = Attr<int>("buffer_size");
if (buffer_size > 1) {
reader = framework::MakeDecoratedReader<BufferedReader>(
reader, platform::CPUPlace(), buffer_size);
}
out->Reset(reader);
}
};
......@@ -253,6 +259,8 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
An OpenFilesOp creates a MultiFileReader, which is able to
read data multi-threaded from multiple files.
)DOC");
AddAttr<int>("thread_num", "Number of thread to read files.");
AddAttr<int>("buffer_size", "The reading buffer of these files.");
}
};
......
......@@ -21,6 +21,7 @@ from ..layer_helper import LayerHelper
from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc
import sys
import multiprocessing
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
......@@ -549,10 +550,9 @@ def open_files(filenames,
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
thread_num(None): Deprecated argument. It will be set by open_files
automatically.
buffer_size(None): Deprecated argument. It will be set by open_files
automatically.
thread_num(None): The number of threads used to read the files.
Default: min(len(filenames), cpu_number).
buffer_size(None): The buffer size of the reader. Default: 3 * thread_num.
pass_num(int): Number of passes to run.
is_test(bool|None): Whether `open_files` used for testing or not. If it
is used for testing, the order of data generated is same as the file
......@@ -574,14 +574,15 @@ def open_files(filenames,
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
"""
if thread_num is not None:
print >> sys.stderr, "thread_num parameter of open_files is " \
"deprecated. It will be ignored and set " \
"automatically by open_files "
if buffer_size is not None:
print >> sys.stderr, "buffer_size parameter of open_files is " \
"deprecated. It will be ignored and set " \
"automatically by open_files "
if thread_num is None:
thread_num = min(len(filenames), multiprocessing.cpu_count())
else:
thread_num = int(thread_num)
if buffer_size is None:
buffer_size = 3 * thread_num
else:
buffer_size = int(buffer_size)
if isinstance(filenames, basestring):
filenames = [filenames]
......@@ -600,7 +601,9 @@ def open_files(filenames,
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames
'file_names': filenames,
'thread_num': thread_num,
'buffer_size': buffer_size
}
if is_test is not None:
attrs['is_test'] = is_test
......
......@@ -155,7 +155,7 @@ class TestDataBalance(unittest.TestCase):
main_program=main_prog,
build_strategy=build_strategy)
if (parallel_exe.device_count > self.batch_size):
if parallel_exe.device_count > self.batch_size:
print("WARNING: Unittest TestDataBalance skipped. \
For the result is not correct when device count \
is larger than batch size.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register to comment
Beginner
guide
Support · Back to
top