Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
90084a25
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
90084a25
编写于
4月 11, 2018
作者:
F
fengjiayi
提交者:
GitHub
4月 11, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9743 from JiayiFeng/modify_readers_to_fit_parallel_executor
Modify readers to fit the parallel executor
上级
718e1807
8c1eb869
变更
22
显示空白变更内容
内联
并排
Showing
22 changed file
with
390 addition
and
197 deletion
+390
-197
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+17
-15
paddle/fluid/framework/lod_tensor.h
paddle/fluid/framework/lod_tensor.h
+5
-2
paddle/fluid/framework/lod_tensor_test.cc
paddle/fluid/framework/lod_tensor_test.cc
+9
-9
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+1
-3
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+3
-1
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+3
-10
paddle/fluid/operators/read_op.cc
paddle/fluid/operators/read_op.cc
+1
-7
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+1
-0
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+27
-20
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+3
-13
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+1
-3
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+3
-7
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+11
-13
paddle/fluid/operators/reader/create_threaded_reader_op.cc
paddle/fluid/operators/reader/create_threaded_reader_op.cc
+94
-0
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+39
-49
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+0
-1
paddle/fluid/pybind/recordio.cc
paddle/fluid/pybind/recordio.cc
+1
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+137
-28
python/paddle/fluid/tests/unittests/test_multi_file_reader.py
...on/paddle/fluid/tests/unittests/test_multi_file_reader.py
+6
-2
python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
...on/paddle/fluid/tests/unittests/test_multi_pass_reader.py
+7
-3
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+12
-6
python/paddle/fluid/tests/unittests/test_recordio_reader.py
python/paddle/fluid/tests/unittests/test_recordio_reader.py
+9
-4
未找到文件。
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
90084a25
...
...
@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
...
...
@@ -22,11 +27,6 @@ limitations under the License. */
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
namespace
paddle
{
namespace
framework
{
...
...
@@ -294,7 +294,7 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
),
dev_ctx
);
}
void
WriteToRecordIO
(
recordio
::
Writer
&
writer
,
void
WriteToRecordIO
(
recordio
::
Writer
*
writer
,
const
std
::
vector
<
LoDTensor
>
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
stringstream
buffer
;
...
...
@@ -303,19 +303,21 @@ void WriteToRecordIO(recordio::Writer &writer,
for
(
auto
&
each
:
tensor
)
{
SerializeToStream
(
buffer
,
each
,
dev_ctx
);
}
writer
.
Write
(
buffer
.
str
());
writer
->
Write
(
buffer
.
str
());
}
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
recordio
::
Scanner
&
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
istringstream
sin
(
scanner
.
Next
());
recordio
::
Scanner
*
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
std
::
vector
<
LoDTensor
>
result
;
if
(
scanner
->
HasNext
())
{
std
::
istringstream
sin
(
scanner
->
Next
());
uint32_t
sz
;
sin
.
read
(
reinterpret_cast
<
char
*>
(
&
sz
),
sizeof
(
uint32_t
));
std
::
vector
<
LoDTensor
>
result
;
result
.
resize
(
sz
);
for
(
uint32_t
i
=
0
;
i
<
sz
;
++
i
)
{
DeserializeFromStream
(
sin
,
&
result
[
i
],
dev_ctx
);
}
}
return
result
;
}
...
...
paddle/fluid/framework/lod_tensor.h
浏览文件 @
90084a25
...
...
@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
...
...
@@ -216,12 +219,12 @@ void SerializeToStream(std::ostream& os, const LoDTensor& tensor,
void
DeserializeFromStream
(
std
::
istream
&
is
,
LoDTensor
*
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
extern
void
WriteToRecordIO
(
recordio
::
Writer
&
writer
,
extern
void
WriteToRecordIO
(
recordio
::
Writer
*
writer
,
const
std
::
vector
<
LoDTensor
>&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
);
extern
std
::
vector
<
LoDTensor
>
ReadFromRecordIO
(
recordio
::
Scanner
&
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
);
recordio
::
Scanner
*
scanner
,
const
platform
::
DeviceContext
&
dev_ctx
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/lod_tensor_test.cc
浏览文件 @
90084a25
...
...
@@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/recordio/scanner.h"
#include "paddle/fluid/recordio/writer.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -240,8 +240,8 @@ TEST(LoDTensor, RecordIO) {
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
{
recordio
::
Writer
writer
(
stream
,
recordio
::
Compressor
::
kSnappy
);
WriteToRecordIO
(
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
&
writer
,
{
tensor
,
tensor
},
ctx
);
WriteToRecordIO
(
&
writer
,
{
tensor
,
tensor
},
ctx
);
writer
.
Flush
();
}
...
...
@@ -254,11 +254,11 @@ TEST(LoDTensor, RecordIO) {
{
std
::
unique_ptr
<
std
::
istream
>
stream_ptr
(
stream
);
recordio
::
Scanner
scanner
(
std
::
move
(
stream_ptr
));
auto
tensors
=
ReadFromRecordIO
(
scanner
,
ctx
);
auto
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
tensors
=
ReadFromRecordIO
(
scanner
,
ctx
);
tensors
=
ReadFromRecordIO
(
&
scanner
,
ctx
);
ASSERT_EQ
(
tensors
.
size
(),
2
);
assert_tensor_ok
(
tensors
[
0
]);
assert_tensor_ok
(
tensors
[
1
]);
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
90084a25
...
...
@@ -115,14 +115,12 @@ void ParallelExecutor::BCastParamsToGPUs(
for
(
auto
&
var
:
vars
)
{
auto
*
main_var
=
main_scope
->
FindVar
(
var
);
if
(
!
main_var
->
IsType
<
LoDTensor
>
())
{
if
(
main_var
==
nullptr
||
!
main_var
->
IsType
<
LoDTensor
>
())
{
continue
;
}
auto
&
main_tensor
=
main_var
->
Get
<
LoDTensor
>
();
auto
&
dims
=
main_tensor
.
dims
();
if
(
paddle
::
platform
::
is_gpu_place
(
main_tensor
.
place
()))
{
size_t
numel
=
main_tensor
.
numel
();
ncclDataType_t
data_type
=
platform
::
ToNCCLDataType
(
main_tensor
.
type
());
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
90084a25
...
...
@@ -22,7 +22,9 @@ FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void
FileReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
ReadNextImpl
(
out
);
PADDLE_ENFORCE_EQ
(
out
->
size
(),
dims_
.
size
());
if
(
out
->
empty
())
{
return
;
}
for
(
size_t
i
=
0
;
i
<
dims_
.
size
();
++
i
)
{
auto
&
actual
=
out
->
at
(
i
).
dims
();
auto
&
expect
=
dims_
[
i
];
...
...
paddle/fluid/framework/reader.h
浏览文件 @
90084a25
...
...
@@ -14,14 +14,13 @@
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace
paddle
{
namespace
framework
{
...
...
@@ -31,8 +30,6 @@ class ReaderBase {
virtual
void
ReInit
()
=
0
;
virtual
bool
HasNext
()
const
=
0
;
virtual
~
ReaderBase
();
};
...
...
@@ -44,8 +41,6 @@ class DecoratedReader : public ReaderBase {
void
ReInit
()
override
{
reader_
->
ReInit
();
}
bool
HasNext
()
const
override
{
return
reader_
->
HasNext
();
}
protected:
ReaderBase
*
reader_
;
};
...
...
@@ -80,8 +75,6 @@ class ReaderHolder {
reader_
->
ReInit
();
}
bool
HasNext
()
const
{
return
reader_
->
HasNext
();
}
private:
std
::
unique_ptr
<
ReaderBase
>
reader_
;
};
...
...
paddle/fluid/operators/read_op.cc
浏览文件 @
90084a25
...
...
@@ -66,13 +66,7 @@ class ReadOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
out_arg_names
=
Outputs
(
"Out"
);
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
reader
->
ReInit
();
reader
->
ReadNext
(
&
ins
);
PADDLE_ENFORCE
(
!
ins
.
empty
(),
"Reader can not read the next data even it has been re-initialized."
);
}
PADDLE_ENFORCE
(
!
ins
.
empty
(),
"There is no next data."
);
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
auto
*
out
=
...
...
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
90084a25
...
...
@@ -22,5 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library
(
create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc
)
reader_library
(
create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc
)
reader_library
(
create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc
)
reader_library
(
create_threaded_reader_op SRCS create_threaded_reader_op.cc
)
# Export local libraries to parent
set
(
READER_LIBRARY
${
LOCAL_READER_LIBS
}
PARENT_SCOPE
)
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
90084a25
...
...
@@ -63,13 +63,14 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher
();
}
bool
HasNext
()
const
override
;
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
private:
bool
HasNext
()
const
;
void
StartPrefetcher
()
{
channel_
=
framework
::
MakeChannel
<
Item
>
(
kChannelSize
);
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
...
...
@@ -109,7 +110,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
auto
place_str
=
Attr
<
std
::
string
>
(
"place"
);
platform
::
Place
place
;
if
(
place_str
==
"CPU"
)
{
if
(
place_str
==
"AUTO"
)
{
place
=
dev_place
;
}
else
if
(
place_str
==
"CPU"
)
{
place
=
platform
::
CPUPlace
();
}
else
{
std
::
istringstream
sin
(
place_str
);
...
...
@@ -140,29 +143,23 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
enum_range
.
insert
(
string
::
Sprintf
(
"CUDA:%d"
,
i
));
}
enum_range
.
insert
(
"CPU"
);
AddAttr
<
std
::
string
>
(
"place"
,
"The double buffer place, default is CPU"
)
.
SetDefault
(
"CPU"
)
enum_range
.
insert
(
"AUTO"
);
AddAttr
<
std
::
string
>
(
"place"
,
"The double buffer place"
)
.
SetDefault
(
"AUTO"
)
.
InEnum
({
enum_range
});
}
};
bool
DoubleBufferReader
::
HasNext
()
const
{
while
(
!
channel_
->
IsClosed
()
&&
!
channel_
->
CanReceive
())
{
}
return
channel_
->
CanReceive
();
}
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
out
->
clear
();
if
(
HasNext
())
{
Item
batch
;
channel_
->
Receive
(
&
batch
);
*
out
=
batch
.
payloads_
;
if
(
batch
.
ctx_
)
{
batch
.
ctx_
->
Wait
();
}
}
}
void
DoubleBufferReader
::
ReInit
()
{
...
...
@@ -171,16 +168,26 @@ void DoubleBufferReader::ReInit() {
StartPrefetcher
();
}
// Reports whether a batch can currently be received from channel_.
// Loops with an empty body while the channel is neither closed nor ready to
// receive, then returns the result of channel_->CanReceive().
// NOTE(review): the wait is a busy loop with no sleep or yield; it relies on
// some other thread filling or closing channel_ -- presumably the prefetch
// thread; confirm this cannot spin indefinitely.
bool
DoubleBufferReader
::
HasNext
()
const
{
// Busy-wait until the channel is closed or has something to receive.
while
(
!
channel_
->
IsClosed
()
&&
!
channel_
->
CanReceive
())
{
}
return
channel_
->
CanReceive
();
}
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
cpu_tensor_cache
(
kCacheSize
);
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
gpu_tensor_cache
(
kCacheSize
);
size_t
cached_tensor_id
=
0
;
while
(
reader_
->
HasNext
()
)
{
while
(
true
)
{
Item
batch
;
auto
&
cpu_batch
=
cpu_tensor_cache
[
cached_tensor_id
];
reader_
->
ReadNext
(
&
cpu_batch
);
if
(
cpu_batch
.
empty
())
{
// The underlying reader has no more data.
break
;
}
if
(
platform
::
is_gpu_place
(
place_
))
{
auto
&
gpu_batch
=
gpu_tensor_cache
[
cached_tensor_id
];
auto
*
gpu_ctx
=
ctxs_
[
cached_tensor_id
].
get
();
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
90084a25
...
...
@@ -25,22 +25,12 @@ class MultiPassReader : public framework::DecoratedReader {
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
reader_
->
ReadNext
(
out
);
}
bool
HasNext
()
const
override
{
if
(
reader_
->
HasNext
())
{
return
true
;
}
else
{
if
(
out
->
empty
())
{
++
pass_count_
;
if
(
pass_count_
>=
pass_num_
)
{
return
false
;
}
else
{
if
(
pass_count_
<
pass_num_
)
{
reader_
->
ReInit
();
re
turn
true
;
re
ader_
->
ReadNext
(
out
)
;
}
}
}
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
90084a25
...
...
@@ -52,8 +52,6 @@ class RandomDataGenerator : public framework::ReaderBase {
void
ReInit
()
override
{
return
;
}
bool
HasNext
()
const
override
{
return
true
;
}
private:
float
min_
;
float
max_
;
...
...
@@ -74,7 +72,7 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
90084a25
...
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include <thread>
#include "paddle/fluid/operators/reader/reader_op_registry.h"
#include "paddle/fluid/recordio/scanner.h"
...
...
@@ -35,17 +33,15 @@ class RecordIOFileReader : public framework::FileReader {
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
}
bool
HasNext
()
const
override
{
return
scanner_
.
HasNext
();
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
protected:
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
ThreadSafe
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
*
mutex_
);
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
*
out
=
framework
::
ReadFromRecordIO
(
&
scanner_
,
dev_ctx_
);
}
else
{
*
out
=
framework
::
ReadFromRecordIO
(
scanner_
,
dev_ctx_
);
*
out
=
framework
::
ReadFromRecordIO
(
&
scanner_
,
dev_ctx_
);
}
}
...
...
@@ -66,7 +62,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
90084a25
...
...
@@ -30,35 +30,33 @@ class ShuffleReader : public framework::DecoratedReader {
std
::
random_device
device
;
seed_
=
device
();
}
Re
adIntoBuffers
();
Re
loadBuffer
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
out
->
clear
();
if
(
iteration_pos_
>=
buffer_
.
size
())
{
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
ReadIntoBuffers
();
ReloadBuffer
();
if
(
buffer_
.
empty
())
{
return
;
}
*
out
=
buffer_
[
iteration_pos_
++
];
}
bool
HasNext
()
const
override
{
return
iteration_pos_
<
buffer_
.
size
()
||
reader_
->
HasNext
();
*
out
=
buffer_
[
iteration_pos_
++
];
}
private:
void
Re
adIntoBuffers
()
{
void
Re
loadBuffer
()
{
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
iteration_pos_
=
0
;
for
(
size_t
i
=
0
;
i
<
buffer_size_
;
++
i
)
{
if
(
!
reader_
->
HasNext
())
{
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader_
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
break
;
}
buffer_
.
emplace_back
();
reader_
->
ReadNext
(
&
buffer_
.
back
());
buffer_
.
emplace_back
(
ins
);
}
std
::
mt19937
g
(
seed_
);
std
::
shuffle
(
buffer_
.
begin
(),
buffer_
.
end
(),
g
);
...
...
paddle/fluid/operators/reader/create_threaded_reader_op.cc
0 → 100644
浏览文件 @
90084a25
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
public:
ThreadedReader
(
ReaderBase
*
reader
,
bool
safe_mode
)
:
DecoratedReader
(
reader
),
safe_mode_
(
safe_mode
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
reader_
->
ReadNext
(
out
);
}
void
ReInit
()
override
{
if
(
safe_mode_
)
{
PADDLE_THROW
(
"ThreadedReader::ReInit() is disabled when 'safe_mode' is true."
);
}
VLOG
(
5
)
<<
"ThreadedReader::ReInit() is invoked! It might be buggy in "
"multi-thread environment."
;
reader_
->
ReInit
();
}
private:
bool
safe_mode_
;
std
::
mutex
mutex_
;
};
// Operator that wraps the reader held by the "UnderlyingReader" input in a
// ThreadedReader and stores it in the ReaderHolder of the "Out" variable.
class CreateThreadedReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  // scope:     scope in which the "Out" and "UnderlyingReader" variables are
  //            looked up.
  // dev_place: not used by this operator.
  void RunImpl(const framework::Scope& scope,
               const platform::Place& dev_place) const override {
    auto* out = detail::Ref(scope.FindVar(Output("Out")))
                    .GetMutable<framework::ReaderHolder>();
    // The output holder already owns a reader: nothing to create.
    if (out->Get() != nullptr) {
      return;
    }
    // detail::Ref enforces that the input variable exists instead of
    // dereferencing a possibly-null FindVar() result, matching the lookup of
    // "Out" above.
    const auto& underlying_reader =
        detail::Ref(scope.FindVar(Input("UnderlyingReader")))
            .Get<framework::ReaderHolder>();
    bool safe_mode = Attr<bool>("safe_mode");
    out->Reset(new ThreadedReader(underlying_reader.Get(), safe_mode));
  }
};
// Proto maker for the create_threaded_reader operator: declares the boolean
// 'safe_mode' attribute (default true) and the operator's description on top
// of what DecoratedReaderMakerBase's constructor sets up.
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
 public:
  CreateThreadedReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : DecoratedReaderMakerBase(op_proto, op_checker) {
    // Help text of the 'safe_mode' attribute.
    const char* const safe_mode_help =
        "When 'safe_mode' is true, 'ReInit()' is disabled to avoid "
        "unexpected bugs in multi-thread environment.";
    AddAttr<bool>("safe_mode", safe_mode_help).SetDefault(true);

    AddComment(R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
  }
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
// Short alias for the namespace that holds the classes defined above.
namespace
reader
=
paddle
::
operators
::
reader
;
// Registers the operator under the name create_threaded_reader, pairing the
// operator class with its proto maker.
// NOTE(review): REGISTER_DECORATED_READER_OPERATOR is defined outside this
// file -- presumably in reader_op_registry.h; confirm its exact effect there.
REGISTER_DECORATED_READER_OPERATOR
(
create_threaded_reader
,
reader
::
CreateThreadedReaderOp
,
reader
::
CreateThreadedReaderOpMaker
);
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
90084a25
...
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thread> // NOLINT
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
...
...
@@ -19,38 +21,23 @@ namespace paddle {
namespace
operators
{
namespace
reader
{
class
MultipleReader
:
public
framework
::
ReaderBase
{
public:
class
ThreadBufferMap
{
class
MultiFileReader
:
public
framework
::
ReaderBase
{
public:
std
::
vector
<
framework
::
LoDTensor
>&
operator
[](
const
std
::
thread
::
id
&
thread_id
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
buffer_
[
thread_id
];
}
void
Clear
()
{
buffer_
.
clear
();
}
private:
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
thread
::
id
,
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
};
MultipleReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
const
std
::
vector
<
framework
::
DDim
>&
dims
,
size_t
thread_num
)
:
file_names_
(
file_names
),
dims_
(
dims
)
{
MultiFileReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
const
std
::
vector
<
framework
::
DDim
>&
dims
,
size_t
thread_num
,
size_t
buffer_size
)
:
file_names_
(
file_names
),
dims_
(
dims
),
buffer_size_
(
buffer_size
)
{
prefetchers_
.
resize
(
thread_num
);
StartNewScheduler
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
bool
HasNext
()
const
override
;
void
ReInit
()
override
;
~
Multi
p
leReader
()
{
EndScheduler
();
}
~
Multi
Fi
leReader
()
{
EndScheduler
();
}
private:
bool
HasNext
();
void
StartNewScheduler
();
void
EndScheduler
();
void
ScheduleThreadFunc
();
...
...
@@ -60,39 +47,36 @@ class MultipleReader : public framework::ReaderBase {
std
::
vector
<
framework
::
DDim
>
dims_
;
std
::
thread
scheduler_
;
std
::
vector
<
std
::
thread
>
prefetchers_
;
size_t
buffer_size_
;
framework
::
Channel
<
size_t
>*
waiting_file_idx_
;
framework
::
Channel
<
size_t
>*
available_thread_idx_
;
framework
::
Channel
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
mutable
ThreadBufferMap
thread_buffer_map_
;
};
void
MultipleReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
void
MultiFileReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
out
->
clear
();
if
(
HasNext
())
{
buffer_
->
Receive
(
out
);
}
auto
&
thread_local_buffer
=
thread_buffer_map_
[
std
::
this_thread
::
get_id
()];
*
out
=
thread_local_buffer
;
thread_local_buffer
.
clear
();
}
bool
MultipleReader
::
HasNext
()
const
{
auto
&
thread_local_buffer
=
thread_buffer_map_
[
std
::
this_thread
::
get_id
()];
return
thread_local_buffer
.
empty
()
?
buffer_
->
Receive
(
&
thread_local_buffer
)
:
true
;
}
void
MultipleReader
::
ReInit
()
{
void
MultiFileReader
::
ReInit
()
{
EndScheduler
();
thread_buffer_map_
.
Clear
();
StartNewScheduler
();
}
void
MultipleReader
::
StartNewScheduler
()
{
// Reports whether a batch can currently be received from buffer_.
// Loops with an empty body while the buffer channel is neither closed nor
// ready to receive, then returns the result of buffer_->CanReceive().
// NOTE(review): the wait is a busy loop with no sleep or yield; it relies on
// other threads sending to or closing buffer_ -- confirm this cannot spin
// indefinitely.
bool
MultiFileReader
::
HasNext
()
{
// Busy-wait until the channel is closed or has something to receive.
while
(
!
buffer_
->
IsClosed
()
&&
!
buffer_
->
CanReceive
())
{
}
return
buffer_
->
CanReceive
();
}
void
MultiFileReader
::
StartNewScheduler
()
{
size_t
thread_num
=
prefetchers_
.
size
();
waiting_file_idx_
=
framework
::
MakeChannel
<
size_t
>
(
file_names_
.
size
());
available_thread_idx_
=
framework
::
MakeChannel
<
size_t
>
(
thread_num
);
buffer_
=
framework
::
MakeChannel
<
std
::
vector
<
framework
::
LoDTensor
>>
(
thread_num
);
framework
::
MakeChannel
<
std
::
vector
<
framework
::
LoDTensor
>>
(
buffer_size_
);
for
(
size_t
i
=
0
;
i
<
file_names_
.
size
();
++
i
)
{
waiting_file_idx_
->
Send
(
&
i
);
...
...
@@ -105,7 +89,7 @@ void MultipleReader::StartNewScheduler() {
scheduler_
=
std
::
thread
([
this
]
{
ScheduleThreadFunc
();
});
}
void
Multi
p
leReader
::
EndScheduler
()
{
void
Multi
Fi
leReader
::
EndScheduler
()
{
available_thread_idx_
->
Close
();
buffer_
->
Close
();
waiting_file_idx_
->
Close
();
...
...
@@ -117,8 +101,8 @@ void MultipleReader::EndScheduler() {
delete
waiting_file_idx_
;
}
void
Multi
p
leReader
::
ScheduleThreadFunc
()
{
VLOG
(
5
)
<<
"Multi
p
leReader schedule thread starts."
;
void
Multi
Fi
leReader
::
ScheduleThreadFunc
()
{
VLOG
(
5
)
<<
"Multi
Fi
leReader schedule thread starts."
;
size_t
completed_thread_num
=
0
;
size_t
thread_idx
;
while
(
available_thread_idx_
->
Receive
(
&
thread_idx
))
{
...
...
@@ -150,17 +134,20 @@ void MultipleReader::ScheduleThreadFunc() {
p
.
join
();
}
}
VLOG
(
5
)
<<
"Multi
p
leReader schedule thread terminates."
;
VLOG
(
5
)
<<
"Multi
Fi
leReader schedule thread terminates."
;
}
void
Multi
p
leReader
::
PrefetchThreadFunc
(
std
::
string
file_name
,
void
Multi
Fi
leReader
::
PrefetchThreadFunc
(
std
::
string
file_name
,
size_t
thread_idx
)
{
VLOG
(
5
)
<<
"The prefetch thread of file '"
<<
file_name
<<
"' starts."
;
std
::
unique_ptr
<
framework
::
ReaderBase
>
reader
=
CreateReaderByFileName
(
file_name
,
dims_
);
while
(
reader
->
HasNext
()
)
{
while
(
true
)
{
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
break
;
}
try
{
buffer_
->
Send
(
&
ins
);
}
catch
(
paddle
::
platform
::
EnforceNotMet
e
)
{
...
...
@@ -197,11 +184,13 @@ class OpenFilesOp : public framework::OperatorBase {
const
auto
&
file_names
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"file_names"
);
PADDLE_ENFORCE
(
!
file_names
.
empty
(),
"No file to be read!"
);
const
size_t
thread_num
=
Attr
<
int
>
(
"thread_num"
);
const
size_t
buffer_size
=
Attr
<
int
>
(
"buffer_size"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
MultipleReader
(
file_names
,
RestoreShapes
(
shape_concat
,
ranks
),
thread_num
));
out
->
Reset
(
new
MultiFileReader
(
file_names
,
RestoreShapes
(
shape_concat
,
ranks
),
thread_num
,
buffer_size
));
}
};
...
...
@@ -212,11 +201,12 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"file_names"
,
"Files to be read."
);
AddAttr
<
int
>
(
"thread_num"
,
"The maximal concurrent prefetch thread number."
)
.
GreaterThan
(
0
);
AddAttr
<
int
>
(
"buffer_size"
,
"The size of prefetch buffer."
).
GreaterThan
(
0
);
AddComment
(
R"DOC(
OpenFiles Operator
An OpenFilesOp creates a Multi
p
leReader, which is able to
An OpenFilesOp creates a Multi
Fi
leReader, which is able to
read data multi-threaded from multiple files.
)DOC"
);
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
90084a25
...
...
@@ -252,7 +252,6 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
);
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
.
def
(
"has_next"
,
&
framework
::
ReaderHolder
::
HasNext
)
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
ReInit
);
py
::
class_
<
Scope
>
(
m
,
"Scope"
,
""
)
...
...
paddle/fluid/pybind/recordio.cc
浏览文件 @
90084a25
...
...
@@ -39,7 +39,7 @@ class RecordIOWriter {
void
CompleteAppendTensor
()
{
auto
&
ctx
=
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
framework
::
WriteToRecordIO
(
writer_
,
tensors_
,
ctx
);
framework
::
WriteToRecordIO
(
&
writer_
,
tensors_
,
ctx
);
tensors_
.
clear
();
}
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
90084a25
...
...
@@ -21,8 +21,7 @@ from ..executor import global_scope
__all__
=
[
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'open_files'
,
'read_file'
,
'create_shuffle_reader'
,
'create_double_buffer_reader'
,
'create_multi_pass_reader'
'open_files'
,
'read_file'
,
'shuffle'
,
'double_buffer'
]
...
...
@@ -237,13 +236,9 @@ def monkey_patch_reader_methods(reader):
var
=
scope
.
find_var
(
reader
.
name
)
return
var
.
get_reader
()
def
eof
():
return
not
__get_reader__
().
has_next
()
def
reset
():
return
__get_reader__
().
reset
()
reader
.
eof
=
eof
reader
.
reset
=
reset
reader
.
stop_gradient
=
True
reader
.
persistable
=
True
...
...
@@ -283,7 +278,42 @@ def _copy_reader_create_op_(block, op):
return
new_op
def
open_recordio_file
(
filename
,
shapes
,
lod_levels
,
dtypes
):
def
open_recordio_file
(
filename
,
shapes
,
lod_levels
,
dtypes
,
pass_num
=
1
,
for_parallel
=
False
):
"""
Open a RecordIO file
This layer takes a RecordIO file to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from the given RecordIO file.
Args:
filename(str): The RecordIO file's name.
shapes(list): List of tuples declaring the data shapes.
lod_levels(list): List of ints declaring the data lod_level.
dtypes(list): List of strs declaring the data types.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get RecordIO file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_recordio_file(
filename='./data.recordio',
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
"""
dtypes
=
[
convert_np_dtype_to_dtype_
(
dt
)
for
dt
in
dtypes
]
shape_concat
=
[]
ranks
=
[]
...
...
@@ -310,10 +340,63 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var
.
persistable
=
True
main_prog_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
if
pass_num
>
1
:
main_prog_var
=
multi_pass
(
reader
=
main_prog_var
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_var
=
parallel
(
reader
=
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
def
open_files
(
filenames
,
thread_num
,
shapes
,
lod_levels
,
dtypes
):
def
open_files
(
filenames
,
shapes
,
lod_levels
,
dtypes
,
thread_num
,
buffer_size
=
None
,
pass_num
=
1
,
for_parallel
=
False
):
"""
Open files
This layer takes a list of files to read from and returns a Reader Variable.
Via the Reader Variable, we can get data from given files. All files must
have name suffixes to indicate their formats, e.g., '*.recordio'.
Args:
filenames(list): The list of file names.
shapes(list): List of tuples declaring the data shapes.
lod_levels(list): List of ints declaring the data lod_level.
dtypes(list): List of strs declaring the data types.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int): The size of prefetch buffer.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Returns:
Variable: A Reader Variable via which we can get file data.
Examples:
.. code-block:: python
reader = fluid.layers.io.open_files(filenames=['./data1.recordio',
'./data2.recordio'],
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=2,
buffer_size=2)
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
"""
if
buffer_size
is
None
:
buffer_size
=
thread_num
if
isinstance
(
filenames
,
basestring
):
filenames
=
[
filenames
]
dtypes
=
[
convert_np_dtype_to_dtype_
(
dt
)
for
dt
in
dtypes
]
shape_concat
=
[]
ranks
=
[]
...
...
@@ -322,29 +405,36 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
shape_concat
.
extend
(
shape
)
ranks
.
append
(
len
(
shape
))
var_name
=
unique_name
(
'multiple_reader'
)
multi_file_reader_name
=
unique_name
(
'multi_file_reader'
)
startup_blk
=
default_startup_program
().
current_block
()
startup_
var
=
startup_blk
.
create_var
(
name
=
va
r_name
)
startup_
reader
=
startup_blk
.
create_var
(
name
=
multi_file_reade
r_name
)
startup_blk
.
append_op
(
type
=
'open_files'
,
outputs
=
{
'Out'
:
[
startup_
va
r
]},
outputs
=
{
'Out'
:
[
startup_
reade
r
]},
attrs
=
{
'shape_concat'
:
shape_concat
,
'lod_levels'
:
lod_levels
,
'ranks'
:
ranks
,
'file_names'
:
filenames
,
'thread_num'
:
thread_num
'thread_num'
:
thread_num
,
'buffer_size'
:
buffer_size
})
startup_var
.
desc
.
set_dtypes
(
dtypes
)
startup_var
.
persistable
=
True
main_prog_var
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
startup_reader
.
desc
.
set_dtypes
(
dtypes
)
startup_reader
.
persistable
=
True
main_prog_reader
=
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_reader
)
if
pass_num
>
1
:
main_prog_reader
=
multi_pass
(
reader
=
main_prog_reader
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_reader
=
parallel
(
reader
=
main_prog_reader
)
return
monkey_patch_reader_methods
(
main_prog_reader
)
def
__create_decorated_reader__
(
op_type
,
reader
,
attrs
):
def
__create_shared_decorated_reader__
(
op_type
,
reader
,
attrs
):
var_name
=
unique_name
(
op_type
)
startup_blk
=
default_startup_program
().
current_block
()
startup_var
=
startup_blk
.
create_var
(
name
=
var_name
)
...
...
@@ -360,22 +450,41 @@ def __create_decorated_reader__(op_type, reader, attrs):
return
monkey_patch_reader_methods
(
main_prog_var
)
def
create_shuffle_reader
(
reader
,
buffer_size
):
return
__create_decorated_reader__
(
'create_shuffle_reader'
,
reader
,
{
'buffer_size'
:
int
(
buffer_size
)})
def
__create_unshared_decorated_reader__
(
op_type
,
reader
,
attrs
):
new_reader_name
=
unique_name
(
op_type
)
main_blk
=
default_main_program
().
current_block
()
new_reader
=
main_blk
.
create_var
(
name
=
new_reader_name
)
main_blk
.
append_op
(
type
=
op_type
,
inputs
=
{
'UnderlyingReader'
:
reader
},
outputs
=
{
'Out'
:
[
new_reader
]},
attrs
=
attrs
)
new_reader
.
persistable
=
True
new_reader
.
stop_gradient
=
True
return
monkey_patch_reader_methods
(
new_reader
)
def
shuffle
(
reader
,
buffer_size
):
return
__create_unshared_decorated_reader__
(
'create_shuffle_reader'
,
reader
,
{
'buffer_size'
:
int
(
buffer_size
)})
def
create_double_buffer_read
er
(
reader
,
place
=
None
):
def
double_buff
er
(
reader
,
place
=
None
):
attrs
=
dict
()
if
place
is
not
None
:
attrs
[
'place'
]
=
str
(
place
).
upper
()
return
__create_decorated_reader__
(
'create_double_buffer_reader'
,
reader
,
attrs
)
return
__create_unshared_decorated_reader__
(
'create_double_buffer_reader'
,
reader
,
attrs
)
def
multi_pass
(
reader
,
pass_num
):
return
__create_shared_decorated_reader__
(
'create_multi_pass_reader'
,
reader
,
{
'pass_num'
:
int
(
pass_num
)})
def
create_multi_pass_reader
(
reader
,
pass_num
):
return
__create_
decorated_reader__
(
'create_multi_pass
_reader'
,
reader
,
{
'pass_num'
:
int
(
pass_num
)
})
def
parallel
(
reader
):
return
__create_
shared_decorated_reader__
(
'create_threaded
_reader'
,
reader
,
{
})
def
read_file
(
file_obj
):
...
...
python/paddle/fluid/tests/unittests/test_multi
p
le_reader.py
→
python/paddle/fluid/tests/unittests/test_multi
_fi
le_reader.py
浏览文件 @
90084a25
...
...
@@ -61,8 +61,12 @@ class TestMultipleReader(unittest.TestCase):
exe
.
run
(
fluid
.
default_startup_program
())
batch_count
=
0
while
not
data_files
.
eof
():
while
True
:
try
:
img_val
,
=
exe
.
run
(
fetch_list
=
[
img
])
except
fluid
.
core
.
EnforceNotMet
as
ex
:
self
.
assertIn
(
"There is no next data."
,
ex
.
message
)
break
batch_count
+=
1
self
.
assertLessEqual
(
img_val
.
shape
[
0
],
self
.
batch_size
)
data_files
.
reset
()
...
...
python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
浏览文件 @
90084a25
...
...
@@ -44,7 +44,7 @@ class TestMultipleReader(unittest.TestCase):
shapes
=
[(
-
1
,
784
),
(
-
1
,
1
)],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
data_file
=
fluid
.
layers
.
create_multi_pass_reader
(
data_file
=
fluid
.
layers
.
io
.
multi_pass
(
reader
=
data_file
,
pass_num
=
self
.
pass_num
)
img
,
label
=
fluid
.
layers
.
read_file
(
data_file
)
...
...
@@ -57,8 +57,12 @@ class TestMultipleReader(unittest.TestCase):
exe
.
run
(
fluid
.
default_startup_program
())
batch_count
=
0
while
not
data_file
.
eof
():
while
True
:
try
:
img_val
,
=
exe
.
run
(
fetch_list
=
[
img
])
except
fluid
.
core
.
EnforceNotMet
as
ex
:
self
.
assertIn
(
"There is no next data."
,
ex
.
message
)
break
batch_count
+=
1
self
.
assertLessEqual
(
img_val
.
shape
[
0
],
self
.
batch_size
)
data_file
.
reset
()
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
90084a25
...
...
@@ -26,11 +26,14 @@ def simple_fc_net(use_feed):
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
else
:
reader
=
fluid
.
layers
.
open_
recordio_file
(
filename
=
'./mnist.recordio'
,
reader
=
fluid
.
layers
.
open_
files
(
filename
s
=
[
'./mnist.recordio'
]
,
shapes
=
[[
-
1
,
784
],
[
-
1
,
1
]],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
dtypes
=
[
'float32'
,
'int64'
],
thread_num
=
1
,
for_parallel
=
True
)
reader
=
fluid
.
layers
.
io
.
double_buffer
(
reader
)
img
,
label
=
fluid
.
layers
.
read_file
(
reader
)
hidden
=
img
for
_
in
xrange
(
4
):
...
...
@@ -51,11 +54,14 @@ def fc_with_batchnorm(use_feed):
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
else
:
reader
=
fluid
.
layers
.
open_
recordio_file
(
filename
=
'./mnist.recordio'
,
reader
=
fluid
.
layers
.
open_
files
(
filename
s
=
[
'mnist.recordio'
]
,
shapes
=
[[
-
1
,
784
],
[
-
1
,
1
]],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
dtypes
=
[
'float32'
,
'int64'
],
thread_num
=
1
,
for_parallel
=
True
)
reader
=
fluid
.
layers
.
io
.
double_buffer
(
reader
)
img
,
label
=
fluid
.
layers
.
read_file
(
reader
)
hidden
=
img
...
...
python/paddle/fluid/tests/unittests/test_recordio_reader.py
浏览文件 @
90084a25
...
...
@@ -65,8 +65,13 @@ class TestRecordIO(unittest.TestCase):
# train a pass
batch_id
=
0
while
not
data_file
.
eof
():
while
True
:
try
:
tmp
,
=
exe
.
run
(
fetch_list
=
[
avg_loss
])
except
fluid
.
core
.
EnforceNotMet
as
ex
:
self
.
assertIn
(
"There is no next data."
,
ex
.
message
)
break
avg_loss_np
.
append
(
tmp
)
batch_id
+=
1
data_file
.
reset
()
...
...
@@ -74,8 +79,8 @@ class TestRecordIO(unittest.TestCase):
self
.
assertLess
(
avg_loss_np
[
-
1
],
avg_loss_np
[
0
])
def
test_shuffle_reader
(
self
):
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
create_shuffle_reader
(
reader
,
buffer_size
=
200
))
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
io
.
shuffle
(
reader
,
buffer_size
=
200
))
def
test_double_buffer_reader
(
self
):
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
create_double_buffer_read
er
(
reader
,
self
.
test_main
(
decorator_callback
=
lambda
reader
:
fluid
.
layers
.
io
.
double_buff
er
(
reader
,
place
=
'cuda:0'
if
fluid
.
core
.
is_compiled_with_cuda
()
else
'cpu'
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录