Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7c041e48
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7c041e48
编写于
3月 21, 2018
作者:
F
fengjiayi
提交者:
GitHub
3月 21, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9182 from JiayiFeng/dev_MultipleReader
Multi-threaded reader in C++
上级
e4bd63d0
2532b922
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
348 addition
and
4 deletion
+348
-4
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+1
-0
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+4
-1
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+212
-0
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+15
-0
paddle/fluid/operators/reader/reader_op_registry.h
paddle/fluid/operators/reader/reader_op_registry.h
+7
-2
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+32
-1
python/paddle/fluid/tests/unittests/.gitignore
python/paddle/fluid/tests/unittests/.gitignore
+3
-0
python/paddle/fluid/tests/unittests/test_multiple_reader.py
python/paddle/fluid/tests/unittests/test_multiple_reader.py
+74
-0
未找到文件。
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
7c041e48
...
...
@@ -15,6 +15,7 @@ function(reader_library TARGET_NAME)
PARENT_SCOPE
)
endfunction
()
reader_library
(
open_files_op SRCS open_files_op.cc
)
reader_library
(
create_random_data_generator_op SRCS create_random_data_generator_op.cc
)
reader_library
(
create_shuffle_reader_op SRCS create_shuffle_reader_op.cc
)
reader_library
(
create_batch_reader_op SRCS create_batch_reader_op.cc
)
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
7c041e48
...
...
@@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
};
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
if
(
local_buffer_
.
payloads_
.
empty
())
{
buffer_
->
Receive
(
&
local_buffer_
);
}
*
out
=
local_buffer_
.
payloads_
;
local_buffer_
.
payloads_
.
clear
();
if
(
local_buffer_
.
ctx_
)
{
...
...
paddle/fluid/operators/reader/open_files_op.cc
0 → 100644
浏览文件 @
7c041e48
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
MultipleReader
:
public
framework
::
ReaderBase
{
public:
MultipleReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
const
std
::
vector
<
framework
::
DDim
>&
dims
,
size_t
thread_num
)
:
file_names_
(
file_names
),
dims_
(
dims
)
{
prefetchers_
.
resize
(
thread_num
);
StartNewScheduler
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
bool
HasNext
()
const
override
;
void
ReInit
()
override
;
~
MultipleReader
()
{
EndScheduler
();
}
private:
void
StartNewScheduler
();
void
EndScheduler
();
void
ScheduleThreadFunc
();
void
PrefetchThreadFunc
(
std
::
string
file_name
,
size_t
thread_idx
);
std
::
vector
<
std
::
string
>
file_names_
;
std
::
vector
<
framework
::
DDim
>
dims_
;
std
::
thread
scheduler_
;
std
::
vector
<
std
::
thread
>
prefetchers_
;
framework
::
Channel
<
size_t
>*
waiting_file_idx_
;
framework
::
Channel
<
size_t
>*
available_thread_idx_
;
framework
::
Channel
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
mutable
std
::
vector
<
framework
::
LoDTensor
>
local_buffer_
;
};
void
MultipleReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
HasNext
())
{
PADDLE_THROW
(
"There is no next data!"
);
}
if
(
local_buffer_
.
empty
())
{
buffer_
->
Receive
(
&
local_buffer_
);
}
*
out
=
local_buffer_
;
local_buffer_
.
clear
();
}
bool
MultipleReader
::
HasNext
()
const
{
return
local_buffer_
.
empty
()
?
buffer_
->
Receive
(
&
local_buffer_
)
:
true
;
}
void
MultipleReader
::
ReInit
()
{
EndScheduler
();
local_buffer_
.
clear
();
StartNewScheduler
();
}
void
MultipleReader
::
StartNewScheduler
()
{
size_t
thread_num
=
prefetchers_
.
size
();
waiting_file_idx_
=
framework
::
MakeChannel
<
size_t
>
(
file_names_
.
size
());
available_thread_idx_
=
framework
::
MakeChannel
<
size_t
>
(
thread_num
);
buffer_
=
framework
::
MakeChannel
<
std
::
vector
<
framework
::
LoDTensor
>>
(
thread_num
);
for
(
size_t
i
=
0
;
i
<
file_names_
.
size
();
++
i
)
{
waiting_file_idx_
->
Send
(
&
i
);
}
waiting_file_idx_
->
Close
();
for
(
size_t
i
=
0
;
i
<
thread_num
;
++
i
)
{
available_thread_idx_
->
Send
(
&
i
);
}
scheduler_
=
std
::
thread
([
this
]
{
ScheduleThreadFunc
();
});
}
void
MultipleReader
::
EndScheduler
()
{
available_thread_idx_
->
Close
();
buffer_
->
Close
();
waiting_file_idx_
->
Close
();
if
(
scheduler_
.
joinable
())
{
scheduler_
.
join
();
}
delete
buffer_
;
delete
available_thread_idx_
;
delete
waiting_file_idx_
;
}
void
MultipleReader
::
ScheduleThreadFunc
()
{
VLOG
(
5
)
<<
"MultipleReader schedule thread starts."
;
size_t
completed_thread_num
=
0
;
size_t
thread_idx
;
while
(
available_thread_idx_
->
Receive
(
&
thread_idx
))
{
std
::
thread
&
prefetcher
=
prefetchers_
[
thread_idx
];
if
(
prefetcher
.
joinable
())
{
prefetcher
.
join
();
}
size_t
file_idx
;
if
(
waiting_file_idx_
->
Receive
(
&
file_idx
))
{
// Still have files to read. Start a new prefetch thread.
std
::
string
file_name
=
file_names_
[
file_idx
];
prefetcher
=
std
::
thread
([
this
,
file_name
,
thread_idx
]
{
PrefetchThreadFunc
(
file_name
,
thread_idx
);
});
}
else
{
// No more file to read.
++
completed_thread_num
;
if
(
completed_thread_num
==
prefetchers_
.
size
())
{
buffer_
->
Close
();
break
;
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for
(
auto
&
p
:
prefetchers_
)
{
if
(
p
.
joinable
())
{
p
.
join
();
}
}
VLOG
(
5
)
<<
"MultipleReader schedule thread terminates."
;
}
void
MultipleReader
::
PrefetchThreadFunc
(
std
::
string
file_name
,
size_t
thread_idx
)
{
VLOG
(
5
)
<<
"The prefetch thread of file '"
<<
file_name
<<
"' starts."
;
std
::
unique_ptr
<
framework
::
ReaderBase
>
reader
=
CreateReaderByFileName
(
file_name
,
dims_
);
while
(
reader
->
HasNext
())
{
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
if
(
!
buffer_
->
Send
(
&
ins
))
{
VLOG
(
5
)
<<
"WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
<<
file_name
<<
"' will terminate."
;
break
;
}
}
if
(
!
available_thread_idx_
->
Send
(
&
thread_idx
))
{
VLOG
(
5
)
<<
"WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx."
;
}
VLOG
(
5
)
<<
"The prefetch thread of file '"
<<
file_name
<<
"' terminates."
;
}
class
OpenFilesOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
int
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
const
auto
&
file_names
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"file_names"
);
PADDLE_ENFORCE
(
!
file_names
.
empty
(),
"No file to be read!"
);
const
size_t
thread_num
=
Attr
<
int
>
(
"thread_num"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
MultipleReader
(
file_names
,
RestoreShapes
(
shape_concat
,
ranks
),
thread_num
));
}
};
class
OpenFilesOpMaker
:
public
FileReaderMakerBase
{
public:
OpenFilesOpMaker
(
OpProto
*
op_proto
,
OpAttrChecker
*
op_checker
)
:
FileReaderMakerBase
(
op_proto
,
op_checker
)
{
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"file_names"
,
"Files to be read."
);
AddAttr
<
int
>
(
"thread_num"
,
"The maximal concurrent prefetch thread number."
)
.
GreaterThan
(
0
);
AddComment
(
R"DOC(
OpenFiles Operator
An OpenFilesOp creates a MultipleReader, which is able to
read data multi-threaded from multiple files.
)DOC"
);
}
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
reader
=
paddle
::
operators
::
reader
;
REGISTER_FILE_READER_OPERATOR
(
open_files
,
reader
::
OpenFilesOp
,
reader
::
OpenFilesOpMaker
);
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
7c041e48
...
...
@@ -36,6 +36,21 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
return
regs
;
}
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
{
size_t
separator_pos
=
file_name
.
find_last_of
(
kFileFormatSeparator
);
PADDLE_ENFORCE_NE
(
separator_pos
,
std
::
string
::
npos
,
"File name illegal! A legal file name should be like: "
"[file_name].[file_format] (e.g., 'data_file.recordio')."
);
std
::
string
filetype
=
file_name
.
substr
(
separator_pos
+
1
);
auto
itor
=
FileReaderRegistry
().
find
(
filetype
);
PADDLE_ENFORCE
(
itor
!=
FileReaderRegistry
().
end
(),
"No file reader registered for '%s' format."
,
filetype
);
framework
::
ReaderBase
*
reader
=
(
itor
->
second
)(
file_name
,
dims
);
return
std
::
unique_ptr
<
framework
::
ReaderBase
>
(
reader
);
}
FileReaderMakerBase
::
FileReaderMakerBase
(
framework
::
OpProtoAndCheckerMaker
::
OpProto
*
op_proto
,
framework
::
OpAttrChecker
*
op_checker
)
...
...
paddle/fluid/operators/reader/reader_op_registry.h
浏览文件 @
7c041e48
...
...
@@ -21,6 +21,8 @@ namespace paddle {
namespace
operators
{
namespace
reader
{
static
constexpr
char
kFileFormatSeparator
[]
=
"."
;
using
FileReaderCreator
=
std
::
function
<
framework
::
ReaderBase
*
(
const
std
::
string
&
,
const
std
::
vector
<
framework
::
DDim
>&
)
>
;
...
...
@@ -29,12 +31,15 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template
<
typename
Reader
>
int
RegisterFileReader
(
const
std
::
string
&
filetype
)
{
FileReaderRegistry
()[
filetype
]
=
[](
const
std
::
string
&
fn
,
const
std
::
vector
<
paddle
::
framework
::
DDim
>&
dim
)
{
return
new
Reader
(
fn
,
dim
);
const
std
::
string
&
fn
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
{
return
new
Reader
(
fn
,
dim
s
);
};
return
0
;
}
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
framework
::
DDim
>&
dims
);
extern
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
);
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
7c041e48
...
...
@@ -21,7 +21,8 @@ from ..executor import global_scope
__all__
=
[
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'read_file'
,
'create_shuffle_reader'
,
'create_double_buffer_reader'
'open_files'
,
'read_file'
,
'create_shuffle_reader'
,
'create_double_buffer_reader'
]
...
...
@@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var
)
def
open_files
(
filenames
,
thread_num
,
shapes
,
lod_levels
,
dtypes
):
dtypes
=
[
convert_np_dtype_to_dtype_
(
dt
)
for
dt
in
dtypes
]
shape_concat
=
[]
ranks
=
[]
for
shape
in
shapes
:
shape_concat
.
extend
(
shape
)
ranks
.
append
(
len
(
shape
))
var_name
=
unique_name
(
'multiple_reader'
)
startup_blk
=
default_startup_program
().
current_block
()
startup_var
=
startup_blk
.
create_var
(
name
=
var_name
)
startup_blk
.
append_op
(
type
=
'open_files'
,
outputs
=
{
'Out'
:
[
startup_var
]},
attrs
=
{
'shape_concat'
:
shape_concat
,
'lod_levels'
:
lod_levels
,
'ranks'
:
ranks
,
'file_names'
:
filenames
,
'thread_num'
:
thread_num
})
startup_var
.
desc
.
set_dtypes
(
dtypes
)
startup_var
.
persistable
=
True
return
_copy_reader_var_
(
default_main_program
().
current_block
(),
startup_var
)
def
__create_decorated_reader__
(
op_type
,
reader
,
attrs
):
var_name
=
unique_name
(
op_type
)
startup_blk
=
default_startup_program
().
current_block
()
...
...
python/paddle/fluid/tests/unittests/.gitignore
浏览文件 @
7c041e48
mnist.recordio
mnist_0.recordio
mnist_1.recordio
mnist_2.recordio
python/paddle/fluid/tests/unittests/test_multiple_reader.py
0 → 100644
浏览文件 @
7c041e48
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
paddle.v2
as
paddle
import
paddle.v2.dataset.mnist
as
mnist
from
shutil
import
copyfile
class
TestMultipleReader
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
64
# Convert mnist to recordio file
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
reader
=
paddle
.
batch
(
mnist
.
train
(),
batch_size
=
self
.
batch_size
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
# order is image and label
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
784
]),
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
),
],
place
=
fluid
.
CPUPlace
())
self
.
num_batch
=
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
'./mnist_0.recordio'
,
reader
,
feeder
)
copyfile
(
'./mnist_0.recordio'
,
'./mnist_1.recordio'
)
copyfile
(
'./mnist_0.recordio'
,
'./mnist_2.recordio'
)
def
main
(
self
,
thread_num
):
file_list
=
[
'./mnist_0.recordio'
,
'./mnist_1.recordio'
,
'./mnist_2.recordio'
]
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
data_files
=
fluid
.
layers
.
open_files
(
filenames
=
file_list
,
thread_num
=
thread_num
,
shapes
=
[(
-
1
,
784
),
(
-
1
,
1
)],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
img
,
label
=
fluid
.
layers
.
read_file
(
data_files
)
if
fluid
.
core
.
is_compiled_with_cuda
():
place
=
fluid
.
CUDAPlace
(
0
)
else
:
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
batch_count
=
0
while
not
data_files
.
eof
():
img_val
,
=
exe
.
run
(
fetch_list
=
[
img
])
batch_count
+=
1
self
.
assertLessEqual
(
img_val
.
shape
[
0
],
self
.
batch_size
)
data_files
.
reset
()
self
.
assertEqual
(
batch_count
,
self
.
num_batch
*
3
)
def
test_main
(
self
):
self
.
main
(
thread_num
=
3
)
# thread number equals to file number
self
.
main
(
thread_num
=
10
)
# thread number is larger than file number
self
.
main
(
thread_num
=
2
)
# thread number is less than file number
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录