Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
a84b8150
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a84b8150
编写于
4月 11, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove Readers' HasNext()
上级
53aea5e1
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
105 addition
and
170 deletion
+105
-170
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+17
-15
paddle/fluid/framework/lod_tensor.h
paddle/fluid/framework/lod_tensor.h
+5
-2
paddle/fluid/framework/lod_tensor_test.cc
paddle/fluid/framework/lod_tensor_test.cc
+9
-9
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+3
-10
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+21
-17
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+3
-13
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+1
-3
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+3
-7
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+11
-13
paddle/fluid/operators/reader/create_threaded_reader_op.cc
paddle/fluid/operators/reader/create_threaded_reader_op.cc
+15
-60
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+14
-11
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-1
paddle/fluid/pybind/recordio.cc
paddle/fluid/pybind/recordio.cc
+1
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+2
-8
未找到文件。
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
a84b8150
...
@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/memory/memory.h"
...
@@ -22,11 +27,6 @@ limitations under the License. */
...
@@ -22,11 +27,6 @@ limitations under the License. */
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
...
@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
),
dev_ctx
);
TensorFromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
),
dev_ctx
);
}
}
void
WriteToRecordIO
(
recordio
::
Writer
&
writer
,
void
WriteToRecordIO
(
recordio
::
Writer
*
writer
,
const
std
::
vector
<
LoDTensor
>
&
tensor
,
const
std
::
vector
<
LoDTensor
>
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
stringstream
buffer
;
std
::
stringstream
buffer
;
...
@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
...
@@ -303,18 +303,20 @@ void WriteToRecordIO(recordio::Writer &writer,
for
(
auto
&
each
:
tensor
)
{
for
(
auto
&
each
:
tensor
)
{
SerializeToStream
(
buffer
,
each
,
dev_ctx
);
SerializeToStream
(
buffer
,
each
,
dev_ctx
);
}
}
writer
.
Write
(
buffer
.
str
());
writer
->
Write
(
buffer
.
str
());
}
}
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
recordio
::
Scanner
&
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
recordio
::
Scanner
*
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
istringstream
sin
(
scanner
.
Next
());
uint32_t
sz
;
sin
.
read
(
reinterpret_cast
<
char
*>
(
&
sz
),
sizeof
(
uint32_t
));
std
::
vector
<
LoDTensor
>
result
;
std
::
vector
<
LoDTensor
>
result
;
result
.
resize
(
sz
);
if
(
scanner
->
HasNext
())
{
for
(
uint32_t
i
=
0
;
i
<
sz
;
++
i
)
{
std
::
istringstream
sin
(
scanner
->
Next
());
DeserializeFromStream
(
sin
,
&
result
[
i
],
dev_ctx
);
uint32_t
sz
;
sin
.
read
(
reinterpret_cast
<
char
*>
(
&
sz
),
sizeof
(
uint32_t
));
result
.
resize
(
sz
);
for
(
uint32_t
i
=
0
;
i
<
sz
;
++
i
)
{
DeserializeFromStream
(
sin
,
&
result
[
i
],
dev_ctx
);
}
}
}
return
result
;
return
result
;
}
}
...
...
paddle/fluid/framework/lod_tensor.h
浏览文件 @
a84b8150
...
@@ -15,6 +15,9 @@ limitations under the License. */
...
@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#pragma once
#include <memory>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/host_vector.h>
...
@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
...
@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void
DeserializeFromStream
(
std
::
istream
&
is
,
LoDTensor
*
tensor
,
void
DeserializeFromStream
(
std
::
istream
&
is
,
LoDTensor
*
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
const
platform
::
DeviceContext
&
dev_ctx
);
extern
void
WriteToRecordIO
(
recordio
::
Writer
&
writer
,
extern
void
WriteToRecordIO
(
recordio
::
Writer
*
writer
,
const
std
::
vector
<
LoDTensor
>&
tensor
,
const
std
::
vector
<
LoDTensor
>&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
const
platform
::
DeviceContext
&
dev_ctx
);
extern
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
extern
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
recordio
::
Scanner
&
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
);
recordio
::
Scanner
*
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/lod_tensor_test.cc
浏览文件 @
a84b8150
...
@@ -12,17 +12,17 @@
...
@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <algorithm>
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
...
@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
{
{
recordio
::
Writer
writer
(
stream
,
recordio
::
Compressor
::
kSnappy
);
recordio
::
Writer
writer
(
stream
,
recordio
::
Compressor
::
kSnappy
);
WriteToRecordIO
(
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
&
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
&
writer
,
{
tensor
,
tensor
},
ctx
);
writer
.
Flush
();
writer
.
Flush
();
}
}
...
@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
...
@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
{
{
std
::
unique_ptr
<
std
::
istream
>
stream_ptr
(
stream
);
std
::
unique_ptr
<
std
::
istream
>
stream_ptr
(
stream
);
recordio
::
Scanner
scanner
(
std
::
move
(
stream_ptr
));
recordio
::
Scanner
scanner
(
std
::
move
(
stream_ptr
));
auto
tensors
=
ReadFromRecordIO
(
scanner
,
ctx
);
auto
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
assert_tensor_ok
(
tensors
[
1
]);
tensors
=
ReadFromRecordIO
(
scanner
,
ctx
);
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
assert_tensor_ok
(
tensors
[
1
]);
...
...
paddle/fluid/framework/reader.h
浏览文件 @
a84b8150
...
@@ -14,14 +14,13 @@
...
@@ -14,14 +14,13 @@
#pragma once
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -31,8 +30,6 @@ class ReaderBase {
...
@@ -31,8 +30,6 @@ class ReaderBase {
virtual
void
ReInit
()
=
0
;
virtual
void
ReInit
()
=
0
;
virtual
bool
HasNext
()
const
=
0
;
virtual
~
ReaderBase
();
virtual
~
ReaderBase
();
};
};
...
@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
...
@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
void
ReInit
()
override
{
reader_
->
ReInit
();
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
bool
HasNext
()
const
override
{
return
reader_
->
HasNext
();
}
protected:
protected:
ReaderBase
*
reader_
;
ReaderBase
*
reader_
;
};
};
...
@@ -80,8 +75,6 @@ class ReaderHolder {
...
@@ -80,8 +75,6 @@ class ReaderHolder {
reader_
->
ReInit
();
reader_
->
ReInit
();
}
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
private:
private:
std
::
unique_ptr
<
ReaderBase
>
reader_
;
std
::
unique_ptr
<
ReaderBase
>
reader_
;
};
};
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
a84b8150
...
@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
...
@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher
();
StartPrefetcher
();
}
}
bool
HasNext
()
const
override
;
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
void
ReInit
()
override
;
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
private:
private:
bool
HasNext
()
const
;
void
StartPrefetcher
()
{
void
StartPrefetcher
()
{
channel_
=
framework
::
MakeChannel
<
Item
>
(
kChannelSize
);
channel_
=
framework
::
MakeChannel
<
Item
>
(
kChannelSize
);
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
...
@@ -149,22 +150,15 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -149,22 +150,15 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
}
};
};
bool
DoubleBufferReader
::
HasNext
()
const
{
while
(
!
channel_
->
IsClosed
()
&&
!
channel_
->
CanReceive
())
{
}
return
channel_
->
CanReceive
();
}
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
out
->
clear
();
PADDLE_THROW
(
"There is no next data!"
);
if
(
HasNext
())
{
}
Item
batch
;
channel_
->
Receive
(
&
batch
);
Item
batch
;
*
out
=
batch
.
payloads_
;
channel_
->
Receive
(
&
batch
);
if
(
batch
.
ctx_
)
{
*
out
=
batch
.
payloads_
;
batch
.
ctx_
->
Wait
();
if
(
batch
.
ctx_
)
{
}
batch
.
ctx_
->
Wait
();
}
}
}
}
...
@@ -174,16 +168,26 @@ void DoubleBufferReader::ReInit() {
...
@@ -174,16 +168,26 @@ void DoubleBufferReader::ReInit() {
StartPrefetcher
();
StartPrefetcher
();
}
}
bool
DoubleBufferReader
::
HasNext
()
const
{
while
(
!
channel_
->
IsClosed
()
&&
!
channel_
->
CanReceive
())
{
}
return
channel_
->
CanReceive
();
}
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
cpu_tensor_cache
(
kCacheSize
);
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
cpu_tensor_cache
(
kCacheSize
);
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
gpu_tensor_cache
(
kCacheSize
);
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
gpu_tensor_cache
(
kCacheSize
);
size_t
cached_tensor_id
=
0
;
size_t
cached_tensor_id
=
0
;
while
(
reader_
->
HasNext
()
)
{
while
(
true
)
{
Item
batch
;
Item
batch
;
auto
&
cpu_batch
=
cpu_tensor_cache
[
cached_tensor_id
];
auto
&
cpu_batch
=
cpu_tensor_cache
[
cached_tensor_id
];
reader_
->
ReadNext
(
&
cpu_batch
);
reader_
->
ReadNext
(
&
cpu_batch
);
if
(
cpu_batch
.
empty
())
{
// The underlying reader have no next data.
break
;
}
if
(
platform
::
is_gpu_place
(
place_
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
auto
&
gpu_batch
=
gpu_tensor_cache
[
cached_tensor_id
];
auto
&
gpu_batch
=
gpu_tensor_cache
[
cached_tensor_id
];
auto
*
gpu_ctx
=
ctxs_
[
cached_tensor_id
].
get
();
auto
*
gpu_ctx
=
ctxs_
[
cached_tensor_id
].
get
();
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
a84b8150
...
@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
...
@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
reader_
->
ReadNext
(
out
);
reader_
->
ReadNext
(
out
);
}
if
(
out
->
empty
())
{
bool
HasNext
()
const
override
{
if
(
reader_
->
HasNext
())
{
return
true
;
}
else
{
++
pass_count_
;
++
pass_count_
;
if
(
pass_count_
>=
pass_num_
)
{
if
(
pass_count_
<
pass_num_
)
{
return
false
;
}
else
{
reader_
->
ReInit
();
reader_
->
ReInit
();
re
turn
true
;
re
ader_
->
ReadNext
(
out
)
;
}
}
}
}
}
}
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
a84b8150
...
@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
...
@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
void
ReInit
()
override
{
return
;
}
void
ReInit
()
override
{
return
;
}
bool
HasNext
()
const
override
{
return
true
;
}
private:
private:
float
min_
;
float
min_
;
float
max_
;
float
max_
;
...
@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
...
@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
a84b8150
...
@@ -12,8 +12,6 @@
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <mutex>
#include <thread>
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/scanner.h"
...
@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
...
@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
}
}
bool
HasNext
()
const
override
{
return
scanner_
.
HasNext
();
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
protected:
protected:
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
ThreadSafe
)
{
if
(
ThreadSafe
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
*
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
guard
(
*
mutex_
);
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
*
out
=
framework
::
ReadFromRecordIO
(
&
scanner_
,
dev_ctx_
);
}
else
{
}
else
{
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
*
out
=
framework
::
ReadFromRecordIO
(
&
scanner_
,
dev_ctx_
);
}
}
}
}
...
@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
...
@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
"shape concat's length."
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
a84b8150
...
@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
...
@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
std
::
random_device
device
;
std
::
random_device
device
;
seed_
=
device
();
seed_
=
device
();
}
}
Re
adIntoBuffers
();
Re
loadBuffer
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
!
HasNext
())
{
out
->
clear
();
PADDLE_THROW
(
"There is no next data!"
);
}
if
(
iteration_pos_
>=
buffer_
.
size
())
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
ReadIntoBuffers
();
ReloadBuffer
();
if
(
buffer_
.
empty
())
{
return
;
}
}
}
*
out
=
buffer_
[
iteration_pos_
++
];
*
out
=
buffer_
[
iteration_pos_
++
];
}
}
bool
HasNext
()
const
override
{
return
iteration_pos_
<
buffer_
.
size
()
||
reader_
->
HasNext
();
}
private:
private:
void
Re
adIntoBuffers
()
{
void
Re
loadBuffer
()
{
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
buffer_
.
reserve
(
buffer_size_
);
iteration_pos_
=
0
;
iteration_pos_
=
0
;
for
(
size_t
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
if
(
!
reader_
->
HasNext
())
{
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader_
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
break
;
break
;
}
}
buffer_
.
emplace_back
();
buffer_
.
emplace_back
(
ins
);
reader_
->
ReadNext
(
&
buffer_
.
back
());
}
}
std
::
mt19937
g
(
seed_
);
std
::
mt19937
g
(
seed_
);
std
::
shuffle
(
buffer_
.
begin
(),
buffer_
.
end
(),
g
);
std
::
shuffle
(
buffer_
.
begin
(),
buffer_
.
end
(),
g
);
...
...
paddle/fluid/operators/reader/create_threaded_reader_op.cc
浏览文件 @
a84b8150
...
@@ -21,67 +21,27 @@ namespace reader {
...
@@ -21,67 +21,27 @@ namespace reader {
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
public:
public:
ThreadedReader
(
ReaderBase
*
reader
,
bool
un
safe_mode
)
ThreadedReader
(
ReaderBase
*
reader
,
bool
safe_mode
)
:
DecoratedReader
(
reader
),
unsafe_mode_
(
un
safe_mode
)
{}
:
DecoratedReader
(
reader
),
safe_mode_
(
safe_mode
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
!
unsafe_mode_
)
{
reader_
->
ReadNext
(
out
);
if
(
!
reader_
->
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
reader_
->
ReadNext
(
out
);
}
else
{
auto
&
thread_buffer
=
thread_buffers_
[
std
::
this_thread
::
get_id
()];
if
(
thread_buffer
.
empty
())
{
PADDLE_THROW
(
"thread_buffer is empty! HasNext() must be invoked before "
"ReadNext() in the same thread."
);
}
*
out
=
thread_buffer
;
thread_buffer
.
clear
();
}
}
bool
HasNext
()
const
override
{
if
(
!
unsafe_mode_
)
{
PADDLE_THROW
(
"ThreadedReader::HasNext() is disabled when 'unsafe_mode' is false."
);
}
std
::
thread
::
id
thread_id
=
std
::
this_thread
::
get_id
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
auto
&
thread_buffer
=
thread_buffers_
[
thread_id
];
if
(
thread_buffer
.
empty
()
&&
reader_
->
HasNext
())
{
reader_
->
ReadNext
(
&
thread_buffer
);
}
return
!
thread_buffer
.
empty
();
}
}
void
ReInit
()
override
{
void
ReInit
()
override
{
if
(
!
un
safe_mode_
)
{
if
(
safe_mode_
)
{
PADDLE_THROW
(
PADDLE_THROW
(
"ThreadedReader::ReInit() is disabled when '
unsafe_mode' is fals
e."
);
"ThreadedReader::ReInit() is disabled when '
safe_mode' is tru
e."
);
}
}
VLOG
(
5
)
<<
"ThreadedReader::ReInit() is invoked! It might be buggy in "
VLOG
(
5
)
<<
"ThreadedReader::ReInit() is invoked! It might be buggy in "
"multi-thread environment."
;
"multi-thread environment."
;
reader_
->
ReInit
();
reader_
->
ReInit
();
}
}
~
ThreadedReader
()
{
for
(
auto
&
p
:
thread_buffers_
)
{
if
(
!
p
.
second
.
empty
())
{
PADDLE_THROW
(
"Find an unused data batch in ThreadedReader! Maybe one thread "
"invokes 'HasNext()' without subsequent 'ReadNext()'."
);
}
}
}
private:
private:
bool
unsafe_mode_
;
bool
safe_mode_
;
mutable
std
::
mutex
mutex_
;
std
::
mutex
mutex_
;
mutable
std
::
unordered_map
<
std
::
thread
::
id
,
std
::
vector
<
framework
::
LoDTensor
>>
thread_buffers_
;
};
};
class
CreateThreadedReaderOp
:
public
framework
::
OperatorBase
{
class
CreateThreadedReaderOp
:
public
framework
::
OperatorBase
{
...
@@ -98,8 +58,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
...
@@ -98,8 +58,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
}
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
bool
unsafe_mode
=
Attr
<
bool
>
(
"un
safe_mode"
);
bool
safe_mode
=
Attr
<
bool
>
(
"
safe_mode"
);
out
->
Reset
(
new
ThreadedReader
(
underlying_reader
.
Get
(),
un
safe_mode
));
out
->
Reset
(
new
ThreadedReader
(
underlying_reader
.
Get
(),
safe_mode
));
}
}
};
};
...
@@ -107,10 +67,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -107,10 +67,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
public:
public:
CreateThreadedReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
CreateThreadedReaderOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
DecoratedReaderMakerBase
(
op_proto
,
op_checker
)
{
:
DecoratedReaderMakerBase
(
op_proto
,
op_checker
)
{
AddAttr
<
bool
>
(
"unsafe_mode"
,
AddAttr
<
bool
>
(
"safe_mode"
,
"When 'unsafe_mode' is false, invoking 'HasNext()' or "
"When 'safe_mode' is true, 'ReInit()' is disabled to avoid "
"'ReInit()' is not allowed to avoid unexpected bugs in "
"unexpected bugs in multi-thread environment."
)
"multi-thread environment."
)
.
SetDefault
(
true
);
.
SetDefault
(
true
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
CreateThreadedReader Operator
CreateThreadedReader Operator
...
@@ -118,13 +77,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -118,13 +77,9 @@ class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
This operator creates a threaded reader. A threaded reader's
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
'ReadNext()' can be invoked by several threads at the same
time.
time.
When the attribute 'unsafe_mode' is false, the threaded reader's
When the attribute 'safe_mode' is true, the threaded reader's
'HasNext()' and 'ReInit()' will be disabled to avoid unexpected
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
bugs in multi-thread environment. If you really need them, you
environment.
can enable them by setting 'unsafe_mode' true. In this case,
'HasNext()' returning true only guarantees the safety of
invoking 'ReadNext()' in the same thread. Each thread must
invoke 'HasNext()' and 'ReadNext()' in pairs.
)DOC"
);
)DOC"
);
}
}
};
};
...
...
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
a84b8150
...
@@ -30,12 +30,12 @@ class MultiFileReader : public framework::ReaderBase {
...
@@ -30,12 +30,12 @@ class MultiFileReader : public framework::ReaderBase {
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
bool
HasNext
()
const
override
;
void
ReInit
()
override
;
void
ReInit
()
override
;
~
MultiFileReader
()
{
EndScheduler
();
}
~
MultiFileReader
()
{
EndScheduler
();
}
private:
private:
bool
HasNext
();
void
StartNewScheduler
();
void
StartNewScheduler
();
void
EndScheduler
();
void
EndScheduler
();
void
ScheduleThreadFunc
();
void
ScheduleThreadFunc
();
...
@@ -52,16 +52,10 @@ class MultiFileReader : public framework::ReaderBase {
...
@@ -52,16 +52,10 @@ class MultiFileReader : public framework::ReaderBase {
};
};
void
MultiFileReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
MultiFileReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
out
->
clear
();
PADDLE_THROW
(
"There is no next data!"
);
if
(
HasNext
())
{
buffer_
->
Receive
(
out
);
}
}
buffer_
->
Receive
(
out
);
}
bool
MultiFileReader
::
HasNext
()
const
{
while
(
!
buffer_
->
IsClosed
()
&&
!
buffer_
->
CanReceive
())
{
}
return
buffer_
->
CanReceive
();
}
}
void
MultiFileReader
::
ReInit
()
{
void
MultiFileReader
::
ReInit
()
{
...
@@ -69,6 +63,12 @@ void MultiFileReader::ReInit() {
...
@@ -69,6 +63,12 @@ void MultiFileReader::ReInit() {
StartNewScheduler
();
StartNewScheduler
();
}
}
bool
MultiFileReader
::
HasNext
()
{
while
(
!
buffer_
->
IsClosed
()
&&
!
buffer_
->
CanReceive
())
{
}
return
buffer_
->
CanReceive
();
}
void
MultiFileReader
::
StartNewScheduler
()
{
void
MultiFileReader
::
StartNewScheduler
()
{
size_t
thread_num
=
prefetchers_
.
size
();
size_t
thread_num
=
prefetchers_
.
size
();
waiting_file_idx_
=
framework
::
MakeChannel
<
size_t
>
(
file_names_
.
size
());
waiting_file_idx_
=
framework
::
MakeChannel
<
size_t
>
(
file_names_
.
size
());
...
@@ -140,9 +140,12 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
...
@@ -140,9 +140,12 @@ void MultiFileReader::PrefetchThreadFunc(std::string file_name,
VLOG
(
5
)
<<
"The prefetch thread of file '"
<<
file_name
<<
"' starts."
;
VLOG
(
5
)
<<
"The prefetch thread of file '"
<<
file_name
<<
"' starts."
;
std
::
unique_ptr
<
framework
::
ReaderBase
>
reader
=
std
::
unique_ptr
<
framework
::
ReaderBase
>
reader
=
CreateReaderByFileName
(
file_name
,
dims_
);
CreateReaderByFileName
(
file_name
,
dims_
);
while
(
reader
->
HasNext
()
)
{
while
(
true
)
{
std
::
vector
<
framework
::
LoDTensor
>
ins
;
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
break
;
}
try
{
try
{
buffer_
->
Send
(
&
ins
);
buffer_
->
Send
(
&
ins
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
a84b8150
...
@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
);
py
::
return_value_policy
::
reference
);
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
.
def
(
"has_next"
,
&
framework
::
ReaderHolder
::
HasNext
)
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
ReInit
);
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
ReInit
);
py
::
class_
<
Scope
>
(
m
,
"Scope"
,
""
)
py
::
class_
<
Scope
>
(
m
,
"Scope"
,
""
)
...
...
paddle/fluid/pybind/recordio.cc
浏览文件 @
a84b8150
...
@@ -39,7 +39,7 @@ class RecordIOWriter {
...
@@ -39,7 +39,7 @@ class RecordIOWriter {
void
CompleteAppendTensor
()
{
void
CompleteAppendTensor
()
{
auto
&
ctx
=
auto
&
ctx
=
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
framework
::
WriteToRecordIO
(
writer_
,
tensors_
,
ctx
);
framework
::
WriteToRecordIO
(
&
writer_
,
tensors_
,
ctx
);
tensors_
.
clear
();
tensors_
.
clear
();
}
}
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
a84b8150
...
@@ -236,13 +236,9 @@ def monkey_patch_reader_methods(reader):
...
@@ -236,13 +236,9 @@ def monkey_patch_reader_methods(reader):
var
=
scope
.
find_var
(
reader
.
name
)
var
=
scope
.
find_var
(
reader
.
name
)
return
var
.
get_reader
()
return
var
.
get_reader
()
def
eof
():
return
not
__get_reader__
().
has_next
()
def
reset
():
def
reset
():
return
__get_reader__
().
reset
()
return
__get_reader__
().
reset
()
reader
.
eof
=
eof
reader
.
reset
=
reset
reader
.
reset
=
reset
reader
.
stop_gradient
=
True
reader
.
stop_gradient
=
True
reader
.
persistable
=
True
reader
.
persistable
=
True
...
@@ -299,8 +295,7 @@ def open_recordio_file(filename,
...
@@ -299,8 +295,7 @@ def open_recordio_file(filename,
shapes(list): List of tuples which declaring data shapes.
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
dtypes(list): List of strs which declaring data type.
pass_num(int): Number of passes to run. After completing the
pass_num(int): Number of passes to run.
given number of passes, 'has_next()' will return False.
for_parallel(Bool): Set it as True if you are going to run
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
subsequent operators in parallel.
...
@@ -377,8 +372,7 @@ def open_files(filenames,
...
@@ -377,8 +372,7 @@ def open_files(filenames,
dtypes(list): List of strs which declaring data type.
dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run. After completing the
pass_num(int): Number of passes to run.
given number of passes, 'has_next()' will return False.
for_parallel(Bool): Set it as True if you are going to run
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
subsequent operators in parallel.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录