提交 32478fe0 编写于 作者: F fengjiayi

Make buffers of DoubleBufferReader and open_files bigger

上级 9dccca96
...@@ -23,13 +23,13 @@ namespace reader { ...@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same // 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2. // time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 3; static constexpr size_t kCacheSize = 5;
// There will be two bacthes out of the channel during training: // There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel // 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by // 2. the one just be received from the channel, which is also being used by
// subsequent operators. // subsequent operators.
// So the channel size should be kChacheSize - 2 // So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 1; // kCacheSize - 2 static constexpr size_t kChannelSize = 3; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
......
...@@ -110,7 +110,7 @@ class BlockGuardServ(BlockGuard): ...@@ -110,7 +110,7 @@ class BlockGuardServ(BlockGuard):
class ListenAndServ(object): class ListenAndServ(object):
""" """
**ListenAndServ Layer** **ListenAndServ Layer**
ListenAndServ is used to create a rpc server bind and listen ListenAndServ is used to create a rpc server bind and listen
on specific TCP port, this server will run the sub-block when on specific TCP port, this server will run the sub-block when
received variables from clients. received variables from clients.
...@@ -212,7 +212,7 @@ def Send(endpoints, send_vars, sync=True): ...@@ -212,7 +212,7 @@ def Send(endpoints, send_vars, sync=True):
of send_vars to send of send_vars to send
send_vars (list): variables to send to server send_vars (list): variables to send to server
sync (bool): whether to wait the request finish sync (bool): whether to wait the request finish
""" """
assert (type(send_vars) == list) assert (type(send_vars) == list)
...@@ -469,10 +469,13 @@ def open_files(filenames, ...@@ -469,10 +469,13 @@ def open_files(filenames,
lod_levels(list): List of ints which declaring data lod_level. lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type. dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number. thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer. buffer_size(int|None): The size of prefetch buffer. If it is setted None,
buffer size will be thread_num * 3.
Default: None
pass_num(int): Number of passes to run. pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel. subsequent operators in parallel.
Default: True
Returns: Returns:
Variable: A Reader Variable via which we can get file data. Variable: A Reader Variable via which we can get file data.
...@@ -492,7 +495,7 @@ def open_files(filenames, ...@@ -492,7 +495,7 @@ def open_files(filenames,
image, label = fluid.layers.io.read_file(reader) image, label = fluid.layers.io.read_file(reader)
""" """
if buffer_size is None: if buffer_size is None:
buffer_size = thread_num buffer_size = thread_num * 3
if isinstance(filenames, basestring): if isinstance(filenames, basestring):
filenames = [filenames] filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册