Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5528f599
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5528f599
编写于
7月 08, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Split ReInit() to Shutdown() and Start()
上级
de9a411f
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
136 addition
and
64 deletion
+136
-64
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+22
-1
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+34
-9
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
+3
-2
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+5
-3
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+14
-11
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+9
-6
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+6
-2
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+1
-2
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+3
-2
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+13
-2
paddle/fluid/operators/reader/create_threaded_reader_op.cc
paddle/fluid/operators/reader/create_threaded_reader_op.cc
+14
-7
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+11
-16
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-1
未找到文件。
paddle/fluid/framework/reader.cc
浏览文件 @
5528f599
...
@@ -16,8 +16,29 @@
...
@@ -16,8 +16,29 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
void
ReaderBase
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
if
(
status_
!=
ReaderStatus
::
kRunning
)
{
PADDLE_THROW
(
"The reader is not at the status of 'running'."
);
}
ReadNextImpl
(
out
);
}
void
ReaderBase
::
Shutdown
()
{
if
(
status_
!=
ReaderStatus
::
kStopped
)
{
ShutdownImpl
();
status_
=
ReaderStatus
::
kStopped
;
}
}
void
ReaderBase
::
Start
()
{
if
(
status_
!=
ReaderStatus
::
kRunning
)
{
StartImpl
();
status_
=
ReaderStatus
::
kRunning
;
}
}
ReaderBase
::~
ReaderBase
()
{}
ReaderBase
::~
ReaderBase
()
{}
void
FileReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
ReadNextImpl
(
out
);
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/reader.h
浏览文件 @
5528f599
...
@@ -24,13 +24,26 @@
...
@@ -24,13 +24,26 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
enum
ReaderStatus
{
kRunning
,
kStopped
};
class
ReaderBase
{
class
ReaderBase
{
public:
public:
virtual
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
);
void
Shutdown
();
v
irtual
void
ReInit
()
=
0
;
v
oid
Start
()
;
virtual
~
ReaderBase
();
virtual
~
ReaderBase
();
protected:
virtual
void
ReadNextImpl
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
virtual
void
ShutdownImpl
()
=
0
;
virtual
void
StartImpl
()
=
0
;
std
::
atomic
<
ReaderStatus
>
status_
{
kStopped
};
};
};
class
DecoratedReader
:
public
ReaderBase
{
class
DecoratedReader
:
public
ReaderBase
{
...
@@ -40,9 +53,11 @@ class DecoratedReader : public ReaderBase {
...
@@ -40,9 +53,11 @@ class DecoratedReader : public ReaderBase {
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
}
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
protected:
protected:
void
ShutdownImpl
()
override
{
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
}
std
::
shared_ptr
<
ReaderBase
>
reader_
;
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
};
...
@@ -50,10 +65,10 @@ class FileReader : public ReaderBase {
...
@@ -50,10 +65,10 @@ class FileReader : public ReaderBase {
public:
public:
FileReader
()
:
ReaderBase
()
{}
FileReader
()
:
ReaderBase
()
{}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
override
;
protected:
protected:
virtual
void
ReadNextImpl
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
void
ShutdownImpl
()
override
{}
void
StartImpl
()
override
{}
};
};
// The ReaderHolder is used as reader' unified wrapper,
// The ReaderHolder is used as reader' unified wrapper,
...
@@ -68,9 +83,19 @@ class ReaderHolder {
...
@@ -68,9 +83,19 @@ class ReaderHolder {
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
ReadNext
(
out
);
reader_
->
ReadNext
(
out
);
}
}
void
ReInit
()
{
void
ResetAll
()
{
// TODO(fengjiayi): The interface of reseting all.
}
void
Shutdown
()
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
Shutdown
();
}
void
Start
()
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
ReIni
t
();
reader_
->
Star
t
();
}
}
private:
private:
...
...
paddle/fluid/operators/reader/create_batch_reader_op.cc
浏览文件 @
5528f599
...
@@ -23,9 +23,10 @@ class BatchReader : public framework::DecoratedReader {
...
@@ -23,9 +23,10 @@ class BatchReader : public framework::DecoratedReader {
BatchReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
batch_size
)
BatchReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
batch_size
)
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
)
{
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
)
{
buffer_
.
reserve
(
batch_size_
);
buffer_
.
reserve
(
batch_size_
);
Start
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
private:
int
batch_size_
;
int
batch_size_
;
...
@@ -66,7 +67,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -66,7 +67,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
}
}
};
};
void
BatchReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
BatchReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
reserve
(
batch_size_
);
buffer_
.
reserve
(
batch_size_
);
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
...
...
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
5528f599
...
@@ -31,9 +31,11 @@ class CustomReader : public framework::DecoratedReader {
...
@@ -31,9 +31,11 @@ class CustomReader : public framework::DecoratedReader {
sub_block_id_
(
sub_block
.
ID
()),
sub_block_id_
(
sub_block
.
ID
()),
exe_
(
framework
::
Executor
(
platform
::
CPUPlace
())),
exe_
(
framework
::
Executor
(
platform
::
CPUPlace
())),
source_var_names_
(
source_var_names
),
source_var_names_
(
source_var_names
),
sink_var_names_
(
sink_var_names
)
{}
sink_var_names_
(
sink_var_names
)
{
Start
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
private:
const
framework
::
ProgramDesc
program_
;
const
framework
::
ProgramDesc
program_
;
...
@@ -143,7 +145,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
...
@@ -143,7 +145,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
}
}
};
};
void
CustomReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
CustomReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
out
->
clear
();
out
->
clear
();
std
::
vector
<
framework
::
LoDTensor
>
underlying_outs
;
std
::
vector
<
framework
::
LoDTensor
>
underlying_outs
;
reader_
->
ReadNext
(
&
underlying_outs
);
reader_
->
ReadNext
(
&
underlying_outs
);
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
5528f599
...
@@ -47,15 +47,24 @@ class DoubleBufferReader : public framework::DecoratedReader {
...
@@ -47,15 +47,24 @@ class DoubleBufferReader : public framework::DecoratedReader {
}
}
}
}
#endif
#endif
Start
Prefetcher
();
Start
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
~
DoubleBufferReader
()
{
Shutdown
();
}
private:
private:
void
ShutdownImpl
()
override
{
EndPrefetcher
();
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
StartPrefetcher
();
}
void
StartPrefetcher
()
{
void
StartPrefetcher
()
{
channel_
=
new
reader
::
BlockingQueue
<
size_t
>
(
kChannelSize
);
channel_
=
new
reader
::
BlockingQueue
<
size_t
>
(
kChannelSize
);
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
...
@@ -136,7 +145,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -136,7 +145,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
}
};
};
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
DoubleBufferReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
size_t
cached_tensor_id
;
size_t
cached_tensor_id
;
if
(
channel_
->
Receive
(
&
cached_tensor_id
))
{
if
(
channel_
->
Receive
(
&
cached_tensor_id
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
...
@@ -150,12 +159,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
...
@@ -150,12 +159,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
}
}
}
}
void
DoubleBufferReader
::
ReInit
()
{
EndPrefetcher
();
reader_
->
ReInit
();
StartPrefetcher
();
}
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
size_t
cached_tensor_id
=
0
;
size_t
cached_tensor_id
=
0
;
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
5528f599
...
@@ -22,25 +22,28 @@ namespace reader {
...
@@ -22,25 +22,28 @@ namespace reader {
class
MultiPassReader
:
public
framework
::
DecoratedReader
{
class
MultiPassReader
:
public
framework
::
DecoratedReader
{
public:
public:
MultiPassReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
pass_num
)
MultiPassReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
pass_num
)
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
)
{
Start
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
reader_
->
ReadNext
(
out
);
reader_
->
ReadNext
(
out
);
if
(
out
->
empty
())
{
if
(
out
->
empty
())
{
++
pass_count_
;
++
pass_count_
;
if
(
pass_count_
<
pass_num_
)
{
if
(
pass_count_
<
pass_num_
)
{
reader_
->
ReInit
();
reader_
->
Shutdown
();
reader_
->
Start
();
reader_
->
ReadNext
(
out
);
reader_
->
ReadNext
(
out
);
}
}
}
}
}
}
void
ReInit
()
override
{
private:
void
StartImpl
()
override
{
pass_count_
=
0
;
pass_count_
=
0
;
reader_
->
ReIni
t
();
reader_
->
Star
t
();
}
}
private:
int
pass_num_
;
int
pass_num_
;
mutable
int
pass_count_
;
mutable
int
pass_count_
;
};
};
...
...
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
5528f599
...
@@ -33,9 +33,13 @@ class PyReader : public framework::FileReader {
...
@@ -33,9 +33,13 @@ class PyReader : public framework::FileReader {
if
(
!
success
)
out
->
clear
();
if
(
!
success
)
out
->
clear
();
}
}
void
ReInit
()
override
{}
private:
private:
void
ShutdownImpl
()
override
{
/* TODO */
}
void
StartImpl
()
override
{
/* TODO */
}
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue_
;
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue_
;
};
};
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
5528f599
...
@@ -30,6 +30,7 @@ class RandomDataGenerator : public framework::FileReader {
...
@@ -30,6 +30,7 @@ class RandomDataGenerator : public framework::FileReader {
unsigned
int
seed
=
std
::
random_device
()();
unsigned
int
seed
=
std
::
random_device
()();
engine_
.
seed
(
seed
);
engine_
.
seed
(
seed
);
dist_
=
std
::
uniform_real_distribution
<
float
>
(
low_
,
high_
);
dist_
=
std
::
uniform_real_distribution
<
float
>
(
low_
,
high_
);
Start
();
}
}
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
...
@@ -51,8 +52,6 @@ class RandomDataGenerator : public framework::FileReader {
...
@@ -51,8 +52,6 @@ class RandomDataGenerator : public framework::FileReader {
}
}
}
}
void
ReInit
()
override
{
return
;
}
private:
private:
float
low_
;
float
low_
;
float
high_
;
float
high_
;
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
5528f599
...
@@ -30,10 +30,9 @@ class RecordIOFileReader : public framework::FileReader {
...
@@ -30,10 +30,9 @@ class RecordIOFileReader : public framework::FileReader {
mutex_
.
reset
(
new
std
::
mutex
());
mutex_
.
reset
(
new
std
::
mutex
());
}
}
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
Start
();
}
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
protected:
protected:
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
ThreadSafe
)
{
if
(
ThreadSafe
)
{
...
@@ -44,6 +43,8 @@ class RecordIOFileReader : public framework::FileReader {
...
@@ -44,6 +43,8 @@ class RecordIOFileReader : public framework::FileReader {
}
}
}
}
void
ShutdownImpl
()
override
{
scanner_
.
Reset
();
}
private:
private:
std
::
unique_ptr
<
std
::
mutex
>
mutex_
;
std
::
unique_ptr
<
std
::
mutex
>
mutex_
;
recordio
::
Scanner
scanner_
;
recordio
::
Scanner
scanner_
;
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
5528f599
...
@@ -31,10 +31,10 @@ class ShuffleReader : public framework::DecoratedReader {
...
@@ -31,10 +31,10 @@ class ShuffleReader : public framework::DecoratedReader {
std
::
random_device
device
;
std
::
random_device
device
;
seed_
=
device
();
seed_
=
device
();
}
}
ReloadBuffer
();
Start
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
out
->
clear
();
out
->
clear
();
if
(
iteration_pos_
>=
buffer_
.
size
())
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
...
@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
...
@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
}
}
private:
private:
void
ShutdownImpl
()
override
{
buffer_
.
clear
();
iteration_pos_
=
0
;
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
ReloadBuffer
();
}
void
ReloadBuffer
()
{
void
ReloadBuffer
()
{
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
buffer_
.
reserve
(
buffer_size_
);
...
...
paddle/fluid/operators/reader/create_threaded_reader_op.cc
浏览文件 @
5528f599
...
@@ -22,16 +22,26 @@ namespace reader {
...
@@ -22,16 +22,26 @@ namespace reader {
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
public:
public:
explicit
ThreadedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
explicit
ThreadedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
DecoratedReader
(
reader
)
{}
:
DecoratedReader
(
reader
)
{
Start
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
reader_
->
ReadNext
(
out
);
reader_
->
ReadNext
(
out
);
}
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
private:
private:
void
ShutdownImpl
()
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
reader_
->
Start
();
}
std
::
mutex
mutex_
;
std
::
mutex
mutex_
;
};
};
...
@@ -62,9 +72,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -62,9 +72,6 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
This operator creates a threaded reader. A threaded reader's
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
'ReadNext()' can be invoked by several threads at the same
time.
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC"
);
)DOC"
);
}
}
};
};
...
...
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
5528f599
...
@@ -31,17 +31,16 @@ class MultiFileReader : public framework::ReaderBase {
...
@@ -31,17 +31,16 @@ class MultiFileReader : public framework::ReaderBase {
readers_
.
emplace_back
(
CreateReaderByFileName
(
f_name
));
readers_
.
emplace_back
(
CreateReaderByFileName
(
f_name
));
}
}
prefetchers_
.
resize
(
thread_num
);
prefetchers_
.
resize
(
thread_num
);
Start
NewScheduler
();
Start
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
~
MultiFileReader
()
{
EndScheduler
();
}
~
MultiFileReader
()
{
Shutdown
();
}
private:
private:
void
Start
NewScheduler
()
;
void
Start
Impl
()
override
;
void
EndScheduler
()
;
void
ShutdownImpl
()
override
;
void
ScheduleThreadFunc
();
void
ScheduleThreadFunc
();
void
PrefetchThreadFunc
(
size_t
reader_idx
,
size_t
thread_idx
);
void
PrefetchThreadFunc
(
size_t
reader_idx
,
size_t
thread_idx
);
...
@@ -54,18 +53,13 @@ class MultiFileReader : public framework::ReaderBase {
...
@@ -54,18 +53,13 @@ class MultiFileReader : public framework::ReaderBase {
reader
::
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
reader
::
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
};
};
void
MultiFileReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
MultiFileReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
buffer_
->
Receive
(
out
))
{
if
(
!
buffer_
->
Receive
(
out
))
{
out
->
clear
();
out
->
clear
();
}
}
}
}
void
MultiFileReader
::
ReInit
()
{
void
MultiFileReader
::
StartImpl
()
{
EndScheduler
();
StartNewScheduler
();
}
void
MultiFileReader
::
StartNewScheduler
()
{
size_t
thread_num
=
prefetchers_
.
size
();
size_t
thread_num
=
prefetchers_
.
size
();
waiting_reader_idx_
=
new
reader
::
BlockingQueue
<
size_t
>
(
readers_
.
size
());
waiting_reader_idx_
=
new
reader
::
BlockingQueue
<
size_t
>
(
readers_
.
size
());
available_thread_idx_
=
new
reader
::
BlockingQueue
<
size_t
>
(
thread_num
);
available_thread_idx_
=
new
reader
::
BlockingQueue
<
size_t
>
(
thread_num
);
...
@@ -83,7 +77,7 @@ void MultiFileReader::StartNewScheduler() {
...
@@ -83,7 +77,7 @@ void MultiFileReader::StartNewScheduler() {
scheduler_
=
std
::
thread
([
this
]
{
ScheduleThreadFunc
();
});
scheduler_
=
std
::
thread
([
this
]
{
ScheduleThreadFunc
();
});
}
}
void
MultiFileReader
::
EndScheduler
()
{
void
MultiFileReader
::
ShutdownImpl
()
{
available_thread_idx_
->
Close
();
available_thread_idx_
->
Close
();
buffer_
->
Close
();
buffer_
->
Close
();
waiting_reader_idx_
->
Close
();
waiting_reader_idx_
->
Close
();
...
@@ -119,7 +113,7 @@ void MultiFileReader::ScheduleThreadFunc() {
...
@@ -119,7 +113,7 @@ void MultiFileReader::ScheduleThreadFunc() {
}
}
}
}
}
}
// If users invoke
ReInit
() when scheduler is running, it will close the
// If users invoke
Shutdown
() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
// to release their resource. So a check is needed before scheduler ends.
for
(
auto
&
p
:
prefetchers_
)
{
for
(
auto
&
p
:
prefetchers_
)
{
...
@@ -137,7 +131,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
...
@@ -137,7 +131,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std
::
vector
<
framework
::
LoDTensor
>
ins
;
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
if
(
ins
.
empty
())
{
reader
->
ReInit
();
reader
->
Shutdown
();
reader
->
Start
();
break
;
break
;
}
}
try
{
try
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
5528f599
...
@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
);
py
::
return_value_policy
::
reference
);
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
Re
Init
);
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
Re
setAll
);
using
LoDTensorBlockingQueue
=
using
LoDTensorBlockingQueue
=
::
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
::
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录