Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a84b8150
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a84b8150
编写于
4月 11, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove Readers' HasNext()
上级
53aea5e1
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
105 addition
and
170 deletion
+105
-170
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+17
-15
paddle/fluid/framework/lod_tensor.h
paddle/fluid/framework/lod_tensor.h
+5
-2
paddle/fluid/framework/lod_tensor_test.cc
paddle/fluid/framework/lod_tensor_test.cc
+9
-9
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+3
-10
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+21
-17
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+3
-13
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+1
-3
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+3
-7
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+11
-13
paddle/fluid/operators/reader/create_threaded_reader_op.cc
paddle/fluid/operators/reader/create_threaded_reader_op.cc
+15
-60
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+14
-11
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-1
paddle/fluid/pybind/recordio.cc
paddle/fluid/pybind/recordio.cc
+1
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+2
-8
未找到文件。
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
a84b8150
...
...
@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
...
...
@@ -22,11 +27,6 @@ limitations under the License. */
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
namespace
paddle
{
namespace
framework
{
...
...
@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
),
dev_ctx
);
}
void
WriteToRecordIO
(
recordio
::
Writer
&
writer
,
void
WriteToRecordIO
(
recordio
::
Writer
*
writer
,
const
std
::
vector
<
LoDTensor
>
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
stringstream
buffer
;
...
...
@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
for
(
auto
&
each
:
tensor
)
{
SerializeToStream
(
buffer
,
each
,
dev_ctx
);
}
writer
.
Write
(
buffer
.
str
());
writer
->
Write
(
buffer
.
str
());
}
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
recordio
::
Scanner
&
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
istringstream
sin
(
scanner
.
Next
());
uint32_t
sz
;
sin
.
read
(
reinterpret_cast
<
char
*>
(
&
sz
),
sizeof
(
uint32_t
));
recordio
::
Scanner
*
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
vector
<
LoDTensor
>
result
;
result
.
resize
(
sz
);
for
(
uint32_t
i
=
0
;
i
<
sz
;
++
i
)
{
DeserializeFromStream
(
sin
,
&
result
[
i
],
dev_ctx
);
if
(
scanner
->
HasNext
())
{
std
::
istringstream
sin
(
scanner
->
Next
());
uint32_t
sz
;
sin
.
read
(
reinterpret_cast
<
char
*>
(
&
sz
),
sizeof
(
uint32_t
));
result
.
resize
(
sz
);
for
(
uint32_t
i
=
0
;
i
<
sz
;
++
i
)
{
DeserializeFromStream
(
sin
,
&
result
[
i
],
dev_ctx
);
}
}
return
result
;
}
...
...
paddle/fluid/framework/lod_tensor.h
浏览文件 @
a84b8150
...
...
@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
...
...
@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void
DeserializeFromStream
(
std
::
istream
&
is
,
LoDTensor
*
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
extern
void
WriteToRecordIO
(
recordio
::
Writer
&
writer
,
extern
void
WriteToRecordIO
(
recordio
::
Writer
*
writer
,
const
std
::
vector
<
LoDTensor
>&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
extern
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
recordio
::
Scanner
&
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
);
recordio
::
Scanner
*
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/lod_tensor_test.cc
浏览文件 @
a84b8150
...
...
@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
{
recordio
::
Writer
writer
(
stream
,
recordio
::
Compressor
::
kSnappy
);
WriteToRecordIO
(
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
&
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
&
writer
,
{
tensor
,
tensor
},
ctx
);
writer
.
Flush
();
}
...
...
@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
{
std
::
unique_ptr
<
std
::
istream
>
stream_ptr
(
stream
);
recordio
::
Scanner
scanner
(
std
::
move
(
stream_ptr
));
auto
tensors
=
ReadFromRecordIO
(
scanner
,
ctx
);
auto
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
tensors
=
ReadFromRecordIO
(
scanner
,
ctx
);
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
...
...
paddle/fluid/framework/reader.h
浏览文件 @
a84b8150
...
...
@@ -14,14 +14,13 @@
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace
paddle
{
namespace
framework
{
...
...
@@ -31,8 +30,6 @@ class ReaderBase {
virtual
void
ReInit
()
=
0
;
virtual
bool
HasNext
()
const
=
0
;
virtual
~
ReaderBase
();
};
...
...
@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
void
ReInit
()
override
{
reader_
->
ReInit
();
}
bool
HasNext
()
const
override
{
return
reader_
->
HasNext
();
}
protected:
ReaderBase
*
reader_
;
};
...
...
@@ -80,8 +75,6 @@ class ReaderHolder {
reader_
->
ReInit
();
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
private:
std
::
unique_ptr
<
ReaderBase
>
reader_
;
};
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
a84b8150
...
...
@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher
();
}
bool
HasNext
()
const
override
;
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
private:
bool
HasNext
()
const
;
void
StartPrefetcher
()
{
channel_
=
framework
::
MakeChannel
<
Item
>
(
kChannelSize
);
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
...
...
@@ -149,22 +150,15 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};
bool
DoubleBufferReader
::
HasNext
()
const
{
while
(
!
channel_
->
IsClosed
()
&&
!
channel_
->
CanReceive
())
{
}
return
channel_
->
CanReceive
();
}
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
Item
batch
;
channel_
->
Receive
(
&
batch
);
*
out
=
batch
.
payloads_
;
if
(
batch
.
ctx_
)
{
batch
.
ctx_
->
Wait
();
out
->
clear
();
if
(
HasNext
())
{
Item
batch
;
channel_
->
Receive
(
&
batch
);
*
out
=
batch
.
payloads_
;
if
(
batch
.
ctx_
)
{
batch
.
ctx_
->
Wait
();
}
}
}
...
...
@@ -174,16 +168,26 @@ void DoubleBufferReader::ReInit() {
StartPrefetcher
();
}
bool
DoubleBufferReader
::
HasNext
()
const
{
while
(
!
channel_
->
IsClosed
()
&&
!
channel_
->
CanReceive
())
{
}
return
channel_
->
CanReceive
();
}
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
cpu_tensor_cache
(
kCacheSize
);
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
gpu_tensor_cache
(
kCacheSize
);
size_t
cached_tensor_id
=
0
;
while
(
reader_
->
HasNext
()
)
{
while
(
true
)
{
Item
batch
;
auto
&
cpu_batch
=
cpu_tensor_cache
[
cached_tensor_id
];
reader_
->
ReadNext
(
&
cpu_batch
);
if
(
cpu_batch
.
empty
())
{
// The underlying reader have no next data.
break
;
}
if
(
platform
::
is_gpu_place
(
place_
))
{
auto
&
gpu_batch
=
gpu_tensor_cache
[
cached_tensor_id
];
auto
*
gpu_ctx
=
ctxs_
[
cached_tensor_id
].
get
();
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
a84b8150
...
...
@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
reader_
->
ReadNext
(
out
);
}
bool
HasNext
()
const
override
{
if
(
reader_
->
HasNext
())
{
return
true
;
}
else
{
if
(
out
->
empty
())
{
++
pass_count_
;
if
(
pass_count_
>=
pass_num_
)
{
return
false
;
}
else
{
if
(
pass_count_
<
pass_num_
)
{
reader_
->
ReInit
();
re
turn
true
;
re
ader_
->
ReadNext
(
out
)
;
}
}
}
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
a84b8150
...
...
@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
void
ReInit
()
override
{
return
;
}
bool
HasNext
()
const
override
{
return
true
;
}
private:
float
min_
;
float
max_
;
...
...
@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
a84b8150
...
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include <thread>
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
...
...
@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
}
bool
HasNext
()
const
override
{
return
scanner_
.
HasNext
();
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
protected:
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
ThreadSafe
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
*
mutex_
);
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
*
out
=
framework
::
ReadFromRecordIO
(
&
scanner_
,
dev_ctx_
);
}
else
{
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
*
out
=
framework
::
ReadFromRecordIO
(
&
scanner_
,
dev_ctx_
);
}
}
...
...
@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
a84b8150
...
...
@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
std
::
random_device
device
;
seed_
=
device
();
}
Re
adIntoBuffers
();
Re
loadBuffer
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
out
->
clear
();
if
(
iteration_pos_
>=
buffer_
.
size
())
{
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
ReadIntoBuffers
();
ReloadBuffer
();
if
(
buffer_
.
empty
())
{
return
;
}
}
*
out
=
buffer_
[
iteration_pos_
++
];
}
bool
HasNext
()
const
override
{
return
iteration_pos_
<
buffer_
.
size
()
||
reader_
->
HasNext
();
}
private:
void
Re
adIntoBuffers
()
{
void
Re
loadBuffer
()
{
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
iteration_pos_
=
0
;
for
(
size_t
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
if
(
!
reader_
->
HasNext
())
{
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader_
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
break
;
}
buffer_
.
emplace_back
();
reader_
->
ReadNext
(
&
buffer_
.
back
());
buffer_
.
emplace_back
(
ins
);
}
std
::
mt19937
g
(
seed_
);
std
::
shuffle
(
buffer_
.
begin
(),
buffer_
.
end
(),
g
);
...
...
paddle/fluid/operators/reader/create_threaded_reader_op.cc
浏览文件 @
a84b8150
...
...
@@ -21,67 +21,27 @@ namespace reader {
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
public:
ThreadedReader
(
ReaderBase
*
reader
,
bool
un
safe_mode
)
:
DecoratedReader
(
reader
),
unsafe_mode_
(
un
safe_mode
)
{}
ThreadedReader
(
ReaderBase
*
reader
,
bool
safe_mode
)
:
DecoratedReader
(
reader
),
safe_mode_
(
safe_mode
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
!
unsafe_mode_
)
{
if
(
!
reader_
->
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
reader_
->
ReadNext
(
out
);
}
else
{
auto
&
thread_buffer
=
thread_buffers_
[
std
::
this_thread
::
get_id
()];
if
(
thread_buffer
.
empty
())
{
PADDLE_THROW
(
"thread_buffer is empty! HasNext() must be invoked before "
"ReadNext() in the same thread."
);
}
*
out
=
thread_buffer
;
thread_buffer
.
clear
();
}
}
bool
HasNext
()
const
override
{
if
(
!
unsafe_mode_
)
{
PADDLE_THROW
(
"ThreadedReader::HasNext() is disabled when 'unsafe_mode' is false."
);
}
std
::
thread
::
id
thread_id
=
std
::
this_thread
::
get_id
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
auto
&
thread_buffer
=
thread_buffers_
[
thread_id
];
if
(
thread_buffer
.
empty
()
&&
reader_
->
HasNext
())
{
reader_
->
ReadNext
(
&
thread_buffer
);
}
return
!
thread_buffer
.
empty
();
reader_
->
ReadNext
(
out
);
}
void
ReInit
()
override
{
if
(
!
un
safe_mode_
)
{
if
(
safe_mode_
)
{
PADDLE_THROW
(
"ThreadedReader::ReInit() is disabled when '
unsafe_mode' is fals
e."
);
"ThreadedReader::ReInit() is disabled when '
safe_mode' is tru
e."
);
}
VLOG
(
5
)
<<
"ThreadedReader::ReInit() is invoked! It might be buggy in "
"multi-thread environment."
;
reader_
->
ReInit
();
}
~
ThreadedReader
()
{
for
(
auto
&
p
:
thread_buffers_
)
{
if
(
!
p
.
second
.
empty
())
{
PADDLE_THROW
(
"Find an unused data batch in ThreadedReader! Maybe one thread "
"invokes 'HasNext()' without subsequent 'ReadNext()'."
);
}
}
}
private:
bool
unsafe_mode_
;
mutable
std
::
mutex
mutex_
;
mutable
std
::
unordered_map
<
std
::
thread
::
id
,
std
::
vector
<
framework
::
LoDTensor
>>
thread_buffers_
;
bool
safe_mode_
;
std
::
mutex
mutex_
;
};
class
CreateThreadedReaderOp
:
public
framework
::
OperatorBase
{
...
...
@@ -98,8 +58,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
bool
unsafe_mode
=
Attr
<
bool
>
(
"un
safe_mode"
);
out
->
Reset
(
new
ThreadedReader
(
underlying_reader
.
Get
(),
un
safe_mode
));
bool
safe_mode
=
Attr
<
bool
>
(
"
safe_mode"
);
out
->
Reset
(
new
ThreadedReader
(
underlying_reader
.
Get
(),
safe_mode
));
}
};
...
...
@@ -107,10 +67,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateThreadedReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
DecoratedReaderMakerBase
(
op_proto
,
op_checker
)
{
AddAttr
<
bool
>
(
"unsafe_mode"
,
"When 'unsafe_mode' is false, invoking 'HasNext()' or "
"'ReInit()' is not allowed to avoid unexpected bugs in "
"multi-thread environment."
)
AddAttr
<
bool
>
(
"safe_mode"
,
"When 'safe_mode' is true, 'ReInit()' is disabled to avoid "
"unexpected bugs in multi-thread environment."
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
CreateThreadedReader Operator
...
...
@@ -118,13 +77,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'unsafe_mode' is false, the threaded reader's
'HasNext()' and 'ReInit()' will be disabled to avoid unexpected
bugs in multi-thread environment. If you really need them, you
can enable them by setting 'unsafe_mode' true. In this case,
'HasNext()' returning true only guarantees the safety of
invoking 'ReadNext()' in the same thread. Each thread must
invoke 'HasNext()' and 'ReadNext()' in pairs.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC"
);
}
};
...
...
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
a84b8150
...
...
@@ -30,12 +30,12 @@ class MultiFileReader : public framework::ReaderBase {
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
bool
HasNext
()
const
override
;
void
ReInit
()
override
;
~
MultiFileReader
()
{
EndScheduler
();
}
private:
bool
HasNext
();
void
StartNewScheduler
();
void
EndScheduler
();
void
ScheduleThreadFunc
();
...
...
@@ -52,16 +52,10 @@ class MultiFileReader : public framework::ReaderBase {
};
void
MultiFileReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
out
->
clear
();
if
(
HasNext
())
{
buffer_
->
Receive
(
out
);
}
buffer_
->
Receive
(
out
);
}
bool
MultiFileReader
::
HasNext
()
const
{
while
(
!
buffer_
->
IsClosed
()
&&
!
buffer_
->
CanReceive
())
{
}
return
buffer_
->
CanReceive
();
}
void
MultiFileReader
::
ReInit
()
{
...
...
@@ -69,6 +63,12 @@ void MultiFileReader::ReInit() {
StartNewScheduler
();
}
bool
MultiFileReader
::
HasNext
()
{
while
(
!
buffer_
->
IsClosed
()
&&
!
buffer_
->
CanReceive
())
{
}
return
buffer_
->
CanReceive
();
}
void
MultiFileReader
::
StartNewScheduler
()
{
size_t
thread_num
=
prefetchers_
.
size
();
waiting_file_idx_
=
framework
::
MakeChannel
<
size_t
>
(
file_names_
.
size
());
...
...
@@ -140,9 +140,12 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
VLOG
(
5
)
<<
"The prefetch thread of file '"
<<
file_name
<<
"' starts."
;
std
::
unique_ptr
<
framework
::
ReaderBase
>
reader
=
CreateReaderByFileName
(
file_name
,
dims_
);
while
(
reader
->
HasNext
()
)
{
while
(
true
)
{
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
break
;
}
try
{
buffer_
->
Send
(
&
ins
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
a84b8150
...
...
@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
);
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
.
def
(
"has_next"
,
&
framework
::
ReaderHolder
::
HasNext
)
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
ReInit
);
py
::
class_
<
Scope
>
(
m
,
"Scope"
,
""
)
...
...
paddle/fluid/pybind/recordio.cc
浏览文件 @
a84b8150
...
...
@@ -39,7 +39,7 @@ class RecordIOWriter {
void
CompleteAppendTensor
()
{
auto
&
ctx
=
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
framework
::
WriteToRecordIO
(
writer_
,
tensors_
,
ctx
);
framework
::
WriteToRecordIO
(
&
writer_
,
tensors_
,
ctx
);
tensors_
.
clear
();
}
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
a84b8150
...
...
@@ -236,13 +236,9 @@ def monkey_patch_reader_methods(reader):
var
=
scope
.
find_var
(
reader
.
name
)
return
var
.
get_reader
()
def
eof
():
return
not
__get_reader__
().
has_next
()
def
reset
():
return
__get_reader__
().
reset
()
reader
.
eof
=
eof
reader
.
reset
=
reset
reader
.
stop_gradient
=
True
reader
.
persistable
=
True
...
...
@@ -299,8 +295,7 @@ def open_recordio_file(filename,
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
pass_num(int): Number of passes to run. After completing the
given number of passes, 'has_next()' will return False.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
...
...
@@ -377,8 +372,7 @@ def open_files(filenames,
dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run. After completing the
given number of passes, 'has_next()' will return False.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录