Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
26ae6111
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
26ae6111
编写于
7月 10, 2018
作者:
F
fengjiayi
提交者:
GitHub
7月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12051 from JiayiFeng/dev_reader_ResetAll
[WIP] Dev reader reset all
上级
10fbb831
d55919c6
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
283 addition
and
219 deletion
+283
-219
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-0
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+46
-14
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+76
-20
paddle/fluid/framework/reader_test.cc
paddle/fluid/framework/reader_test.cc
+52
-0
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+0
-1
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
+18
-6
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+6
-6
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+16
-12
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+10
-10
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+13
-8
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+5
-7
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+5
-16
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+14
-4
paddle/fluid/operators/reader/create_threaded_reader_op.cc
paddle/fluid/operators/reader/create_threaded_reader_op.cc
+0
-79
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+13
-16
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+2
-2
paddle/fluid/operators/reader/reader_op_registry.h
paddle/fluid/operators/reader/reader_op_registry.h
+5
-6
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+0
-11
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
26ae6111
...
...
@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test
(
lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor
)
cc_library
(
reader SRCS reader.cc DEPS lod_tensor ddim
)
cc_test
(
reader_test SRCS reader_test.cc DEPS reader
)
cc_test
(
variable_test SRCS variable_test.cc
)
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
26ae6111
...
...
@@ -13,29 +13,61 @@
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <deque>
namespace
paddle
{
namespace
framework
{
ReaderBase
::~
ReaderBase
()
{}
FileReader
::
FileReader
(
const
std
::
vector
<
DDim
>
&
dims
)
:
dims_
(
dims
)
{}
void
FileReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
void
ReaderBase
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
PADDLE_ENFORCE_EQ
(
status_
,
ReaderStatus
::
kRunning
);
ReadNextImpl
(
out
);
if
(
out
->
empty
())
{
return
;
}
}
PADDLE_ENFORCE_EQ
(
out
->
size
(),
dims_
.
size
());
for
(
size_t
i
=
0
;
i
<
dims_
.
size
();
++
i
)
{
auto
&
actual
=
(
*
out
)[
i
].
dims
();
auto
&
expect
=
dims_
[
i
];
void
ReaderBase
::
InsertDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>
&
decorated_reader
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mu_
);
decorated_readers_
.
emplace_back
(
decorated_reader
);
}
PADDLE_ENFORCE_EQ
(
actual
.
size
(),
expect
.
size
());
for
(
int
j
=
0
;
j
<
actual
.
size
();
++
j
)
{
// PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
std
::
unordered_set
<
ReaderBase
*>
ReaderBase
::
GetEndPoints
()
{
std
::
unordered_set
<
ReaderBase
*>
result
;
std
::
deque
<
ReaderBase
*>
queue
;
queue
.
emplace_back
(
this
);
while
(
!
queue
.
empty
())
{
// BFS search
auto
*
front
=
queue
.
front
();
queue
.
pop_front
();
if
(
front
->
decorated_readers_
.
empty
())
{
result
.
emplace
(
front
);
}
else
{
for
(
auto
&
reader
:
front
->
decorated_readers_
)
{
if
(
auto
*
reader_ptr
=
reader
.
lock
().
get
())
{
queue
.
emplace_back
(
reader_ptr
);
}
}
}
}
return
result
;
}
void
ReaderBase
::
Shutdown
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
if
(
status_
!=
ReaderStatus
::
kStopped
)
{
ShutdownImpl
();
status_
=
ReaderStatus
::
kStopped
;
}
}
void
ReaderBase
::
Start
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
if
(
status_
!=
ReaderStatus
::
kRunning
)
{
StartImpl
();
status_
=
ReaderStatus
::
kRunning
;
}
}
ReaderBase
::~
ReaderBase
()
{
Shutdown
();
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/reader.h
浏览文件 @
26ae6111
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <memory>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
...
...
@@ -24,61 +25,116 @@
namespace
paddle
{
namespace
framework
{
enum
ReaderStatus
{
kRunning
,
kStopped
};
class
ReaderBase
{
public:
virtual
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
);
void
Shutdown
();
virtual
void
ReInit
()
=
0
;
void
Start
();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std
::
unordered_set
<
ReaderBase
*>
GetEndPoints
();
virtual
~
ReaderBase
();
protected:
virtual
void
ReadNextImpl
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
virtual
void
ShutdownImpl
()
{}
virtual
void
StartImpl
()
{}
ReaderStatus
status_
{
kRunning
};
mutable
std
::
mutex
mu_
;
private:
friend
class
DecoratedReader
;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void
InsertDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
decorated_reader
);
// A set of which readers that decorated this reader.
std
::
vector
<
std
::
weak_ptr
<
ReaderBase
>>
decorated_readers_
;
};
class
DecoratedReader
:
public
ReaderBase
{
class
DecoratedReader
:
public
ReaderBase
,
public
std
::
enable_shared_from_this
<
DecoratedReader
>
{
public:
explicit
DecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
ReaderBase
(),
reader_
(
reader
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
void
RegisterDecorateChain
()
{
reader_
->
InsertDecoratedReader
(
shared_from_this
());
}
protected:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
class
FileReader
:
public
ReaderBase
{
public:
explicit
FileReader
(
const
std
::
vector
<
DDim
>&
dims
);
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
override
;
void
ShutdownImpl
()
override
{
reader_
->
Shutdown
();
}
protected:
virtual
void
ReadNextImpl
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
void
StartImpl
()
override
{
reader_
->
Start
();
}
private:
std
::
vector
<
DDim
>
dims_
;
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
// FileReader is just a conceptual class.
class
FileReader
:
public
ReaderBase
{};
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class
ReaderHolder
{
public:
void
Reset
(
ReaderBase
*
reader
)
{
reader_
.
reset
(
reader
);
}
template
<
typename
T
>
void
Reset
(
const
std
::
shared_ptr
<
T
>&
reader
)
{
auto
reader_base
=
std
::
dynamic_pointer_cast
<
ReaderBase
>
(
reader
);
PADDLE_ENFORCE_NOT_NULL
(
reader_base
);
reader_
=
reader_base
;
}
std
::
shared_ptr
<
ReaderBase
>
Get
()
const
{
return
reader_
;
}
const
std
::
shared_ptr
<
ReaderBase
>&
Get
()
const
{
return
reader_
;
}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
ReadNext
(
out
);
}
void
ReInit
()
{
void
ResetAll
()
{
auto
end_readers
=
reader_
->
GetEndPoints
();
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Shutdown
();
}
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Start
();
}
}
void
Shutdown
()
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
Shutdown
();
}
void
Start
()
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
ReIni
t
();
reader_
->
Star
t
();
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
private:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
template
<
typename
T
,
typename
...
ARGS
>
inline
std
::
shared_ptr
<
DecoratedReader
>
MakeDecoratedReader
(
ARGS
&&
...
args
)
{
std
::
shared_ptr
<
DecoratedReader
>
reader
(
new
T
(
std
::
forward
<
ARGS
>
(
args
)...));
reader
->
RegisterDecorateChain
();
return
reader
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/reader_test.cc
0 → 100644
浏览文件 @
26ae6111
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class
StubDecoratedReader
:
public
paddle
::
framework
::
DecoratedReader
{
public:
explicit
StubDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>
&
reader
)
:
DecoratedReader
(
reader
)
{}
void
ReadNextImpl
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
*
out
)
override
{}
};
class
StubRootReader
:
public
paddle
::
framework
::
ReaderBase
{
public:
void
ReadNextImpl
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
*
out
)
override
{}
};
TEST
(
READER
,
decorate_chain
)
{
auto
root
=
std
::
make_shared
<
StubRootReader
>
();
auto
end_point1
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
auto
end_point2
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
{
auto
endpoints
=
root
->
GetEndPoints
();
ASSERT_EQ
(
endpoints
.
size
(),
2U
);
ASSERT_NE
(
endpoints
.
count
(
end_point1
.
get
()),
0
);
ASSERT_NE
(
endpoints
.
count
(
end_point2
.
get
()),
0
);
}
{
auto
end_point3
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
3U
);
}
{
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
2U
);
}
}
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
26ae6111
...
...
@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library
(
create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc
)
reader_library
(
create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc
)
reader_library
(
create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc
)
reader_library
(
create_threaded_reader_op SRCS create_threaded_reader_op.cc
)
reader_library
(
create_custom_reader_op SRCS create_custom_reader_op.cc
)
reader_library
(
create_py_reader_op SRCS create_py_reader_op.cc
)
...
...
paddle/fluid/operators/reader/create_batch_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -20,15 +20,19 @@ namespace reader {
class
BatchReader
:
public
framework
::
DecoratedReader
{
public:
BatchReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
batch_size
)
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
)
{
BatchReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
batch_size
,
bool
discard_leftover
)
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
),
discard_leftover_
(
discard_leftover
)
{
buffer_
.
reserve
(
batch_size_
);
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
int
batch_size_
;
bool
discard_leftover_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
};
...
...
@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
BatchReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"batch_size"
)));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
BatchReader
>
(
underlying_reader
,
Attr
<
int
>
(
"batch_size"
),
Attr
<
bool
>
(
"discard_leftover"
)));
}
};
...
...
@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
AddAttr
<
int
>
(
"batch_size"
,
"How many instances the batch reader yields each time."
)
.
GreaterThan
(
0
);
AddAttr
<
bool
>
(
"discard_leftover"
,
"If true, the leftover instances that are not enough for a "
"new batch will be discarded."
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
CreateBatchReader Operator
...
...
@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void
BatchReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
BatchReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
buffer_
.
clear
();
buffer_
.
reserve
(
batch_size_
);
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
...
...
@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
break
;
}
}
if
(
discard_leftover_
&&
buffer_
.
size
()
<
batch_size_
)
{
buffer_
.
clear
();
}
// Concat instances
out
->
clear
();
if
(
buffer_
.
empty
())
{
...
...
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader {
source_var_names_
(
source_var_names
),
sink_var_names_
(
sink_var_names
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
const
framework
::
ProgramDesc
program_
;
...
...
@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
CustomReader
(
underlying_reader
.
Get
()
,
*
sub_block
,
Attr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
),
Attr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
)));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
CustomReader
>
(
underlying_reader
,
*
sub_block
,
Attr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
),
Attr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
)));
}
};
...
...
@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
}
};
void
CustomReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
CustomReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
out
->
clear
();
std
::
vector
<
framework
::
LoDTensor
>
underlying_outs
;
reader_
->
ReadNext
(
&
underlying_outs
);
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
static
constexpr
size_t
kCacheSize
=
5
;
static
constexpr
size_t
kCacheSize
=
3
;
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kChacheSize - 2
static
constexpr
size_t
kChannelSize
=
3
;
// kCacheSize - 2
static
constexpr
size_t
kChannelSize
=
1
;
// kCacheSize - 2
class
DoubleBufferReader
:
public
framework
::
DecoratedReader
{
public:
...
...
@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
private:
void
ShutdownImpl
()
override
{
EndPrefetcher
();
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
StartPrefetcher
();
}
void
StartPrefetcher
()
{
channel_
=
new
reader
::
BlockingQueue
<
size_t
>
(
kChannelSize
);
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
...
...
@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
}
out
->
Reset
(
new
DoubleBufferReader
(
underlying_reader
.
Get
(),
place
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
DoubleBufferReader
>
(
underlying_reader
,
place
));
}
};
...
...
@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
DoubleBufferReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
size_t
cached_tensor_id
;
if
(
channel_
->
Receive
(
&
cached_tensor_id
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
...
...
@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
}
}
void
DoubleBufferReader
::
ReInit
()
{
reader_
->
ReInit
();
EndPrefetcher
();
StartPrefetcher
();
}
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
size_t
cached_tensor_id
=
0
;
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader {
MultiPassReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
pass_num
)
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
reader_
->
ReadNext
(
out
);
if
(
out
->
empty
())
{
if
(
out
->
empty
()
&&
pass_count_
<
pass_num_
-
1
)
{
reader_
->
Shutdown
();
reader_
->
Start
();
reader_
->
ReadNext
(
out
);
++
pass_count_
;
if
(
pass_count_
<
pass_num_
)
{
reader_
->
ReInit
();
reader_
->
ReadNext
(
out
);
}
}
}
void
ReInit
()
override
{
private:
void
StartImpl
()
override
{
pass_count_
=
0
;
reader_
->
ReIni
t
();
reader_
->
Star
t
();
}
private:
int
pass_num_
;
mutable
int
pass_count_
;
};
...
...
@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
int
pass_num
=
Attr
<
int
>
(
"pass_num"
);
out
->
Reset
(
new
MultiPassReader
(
underlying_reader
.
Get
(),
pass_num
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
MultiPassReader
>
(
underlying_reader
,
pass_num
));
}
};
...
...
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -19,22 +19,27 @@ namespace paddle {
namespace
operators
{
namespace
reader
{
class
PyReader
:
public
framework
::
ReaderBase
{
class
PyReader
:
public
framework
::
FileReader
{
public:
explicit
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
)
{
explicit
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
)
:
framework
::
FileReader
()
{
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
queue_
=
queue
;
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
bool
success
;
*
out
=
queue_
->
Pop
(
&
success
);
if
(
!
success
)
out
->
clear
();
}
void
ReInit
()
override
{}
private:
void
ShutdownImpl
()
override
{
/* TODO */
}
void
StartImpl
()
override
{
/* TODO */
}
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue_
;
};
...
...
@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase {
const
std
::
string
&
queue_name
=
Input
(
"blocking_queue"
);
auto
*
queue_holder_var
=
scope
.
FindVar
(
queue_name
);
PADDLE_ENFORCE
(
queue_holder_var
!=
nullptr
,
PADDLE_ENFORCE
_NOT_NULL
(
queue_holder_var
,
"No LoDTensorBlockingQueueHolder variable with name %s found"
,
queue_name
);
auto
*
queue_holder
=
queue_holder_var
->
template
GetMutable
<
LoDTensorBlockingQueueHolder
>();
out
->
Reset
(
new
PyReader
(
queue_holder
->
GetQueue
()));
out
->
Reset
(
std
::
make_shared
<
PyReader
>
(
queue_holder
->
GetQueue
()));
}
};
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
26ae6111
...
...
@@ -19,11 +19,11 @@ namespace operators {
namespace
reader
{
template
<
typename
T
>
class
RandomDataGenerator
:
public
framework
::
ReaderBase
{
class
RandomDataGenerator
:
public
framework
::
FileReader
{
public:
RandomDataGenerator
(
const
std
::
vector
<
framework
::
DDim
>&
shapes
,
float
low
,
float
high
)
:
framework
::
ReaderBase
(),
low_
(
low
),
high_
(
high
),
shapes_
(
shapes
)
{
:
framework
::
FileReader
(),
low_
(
low
),
high_
(
high
),
shapes_
(
shapes
)
{
PADDLE_ENFORCE_LE
(
low
,
high
,
"'low' shouldn't be greater than 'high'.(%f vs %f)"
,
low
,
high
);
...
...
@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase {
dist_
=
std
::
uniform_real_distribution
<
float
>
(
low_
,
high_
);
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
out
->
clear
();
out
->
reserve
(
shapes_
.
size
());
for
(
const
framework
::
DDim
&
shape
:
shapes_
)
{
...
...
@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase {
}
}
void
ReInit
()
override
{
return
;
}
private:
float
low_
;
float
high_
;
...
...
@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RandomDataGenerator
<
T
>
(
shapes
,
Attr
<
float
>
(
"low"
),
Attr
<
float
>
(
"high"
)));
out
->
Reset
(
std
::
make_shared
<
RandomDataGenerator
<
T
>>
(
shapes
,
Attr
<
float
>
(
"low"
),
Attr
<
float
>
(
"high"
)));
}
};
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -21,10 +21,8 @@ namespace reader {
template
<
bool
ThreadSafe
>
class
RecordIOFileReader
:
public
framework
::
FileReader
{
public:
explicit
RecordIOFileReader
(
const
std
::
string
&
filename
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
:
FileReader
(
dims
),
scanner_
(
filename
),
explicit
RecordIOFileReader
(
const
std
::
string
&
filename
)
:
scanner_
(
filename
),
dev_ctx_
(
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()))
{
if
(
ThreadSafe
)
{
...
...
@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader {
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
protected:
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
ThreadSafe
)
{
...
...
@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader {
}
}
void
StartImpl
()
override
{
scanner_
.
Reset
();
}
private:
std
::
unique_ptr
<
std
::
mutex
>
mutex_
;
recordio
::
Scanner
scanner_
;
...
...
@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RecordIOFileReader
<
true
>
(
filename
,
RestoreShapes
(
shape_concat
,
ranks
)));
out
->
Reset
(
std
::
make_shared
<
RecordIOFileReader
<
true
>>
(
filename
));
}
};
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader {
ReloadBuffer
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
out
->
clear
();
if
(
iteration_pos_
>=
buffer_
.
size
())
{
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
...
...
@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
}
private:
void
ShutdownImpl
()
override
{
buffer_
.
clear
();
iteration_pos_
=
0
;
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
ReloadBuffer
();
}
void
ReloadBuffer
()
{
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
...
...
@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
ShuffleReader
(
underlying_reader
.
Get
(),
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
ShuffleReader
>
(
underlying_reader
,
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
}
};
...
...
paddle/fluid/operators/reader/create_threaded_reader_op.cc
已删除
100644 → 0
浏览文件 @
10fbb831
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
public:
explicit
ThreadedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
DecoratedReader
(
reader
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
reader_
->
ReadNext
(
out
);
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
private:
std
::
mutex
mutex_
;
};
class
CreateThreadedReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
*
out
=
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)))
.
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
ThreadedReader
(
underlying_reader
.
Get
()));
}
};
class
CreateThreadedReaderOpMaker
:
public
DecoratedReaderMakerBase
{
protected:
void
Apply
()
override
{
AddComment
(
R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC"
);
}
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
reader
=
paddle
::
operators
::
reader
;
REGISTER_DECORATED_READER_OPERATOR
(
create_threaded_reader
,
reader
::
CreateThreadedReaderOp
,
reader
::
CreateThreadedReaderOpMaker
);
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
26ae6111
...
...
@@ -23,24 +23,26 @@ namespace reader {
class
MultiFileReader
:
public
framework
::
ReaderBase
{
public:
MultiFileReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
const
std
::
vector
<
framework
::
DDim
>&
dims
,
size_t
thread_num
,
MultiFileReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
size_t
thread_num
,
size_t
buffer_size
)
:
buffer_size_
(
buffer_size
)
{
readers_
.
reserve
(
file_names
.
size
());
for
(
const
std
::
string
&
f_name
:
file_names
)
{
readers_
.
emplace_back
(
CreateReaderByFileName
(
f_name
,
dims
));
readers_
.
emplace_back
(
CreateReaderByFileName
(
f_name
));
}
prefetchers_
.
resize
(
thread_num
);
StartNewScheduler
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
~
MultiFileReader
()
{
EndScheduler
();
}
private:
void
ShutdownImpl
()
override
{
EndScheduler
();
}
void
StartImpl
()
override
{
StartNewScheduler
();
}
void
StartNewScheduler
();
void
EndScheduler
();
void
ScheduleThreadFunc
();
...
...
@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase {
reader
::
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
};
void
MultiFileReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
MultiFileReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
buffer_
->
Receive
(
out
))
{
out
->
clear
();
}
}
void
MultiFileReader
::
ReInit
()
{
EndScheduler
();
StartNewScheduler
();
}
void
MultiFileReader
::
StartNewScheduler
()
{
size_t
thread_num
=
prefetchers_
.
size
();
waiting_reader_idx_
=
new
reader
::
BlockingQueue
<
size_t
>
(
readers_
.
size
());
...
...
@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() {
}
}
}
// If users invoke
ReInit
() when scheduler is running, it will close the
// If users invoke
Shutdown
() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for
(
auto
&
p
:
prefetchers_
)
{
...
...
@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
reader
->
ReInit
();
reader
->
Shutdown
();
reader
->
Start
();
break
;
}
try
{
...
...
@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
MultiFileReader
(
file_names
,
RestoreShapes
(
shape_concat
,
ranks
),
thread_num
,
buffer_size
));
out
->
Reset
(
std
::
make_shared
<
MultiFileReader
>
(
file_names
,
thread_num
,
buffer_size
));
}
};
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
26ae6111
...
...
@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
}
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
{
const
std
::
string
&
file_name
)
{
size_t
separator_pos
=
file_name
.
find_last_of
(
kFileFormatSeparator
);
PADDLE_ENFORCE_NE
(
separator_pos
,
std
::
string
::
npos
,
"File name illegal! A legal file name should be like: "
...
...
@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
auto
itor
=
FileReaderRegistry
().
find
(
filetype
);
PADDLE_ENFORCE
(
itor
!=
FileReaderRegistry
().
end
(),
"No file reader registered for '%s' format."
,
filetype
);
framework
::
ReaderBase
*
reader
=
(
itor
->
second
)(
file_name
,
dims
);
framework
::
ReaderBase
*
reader
=
(
itor
->
second
)(
file_name
);
return
std
::
unique_ptr
<
framework
::
ReaderBase
>
(
reader
);
}
...
...
paddle/fluid/operators/reader/reader_op_registry.h
浏览文件 @
26ae6111
...
...
@@ -25,22 +25,21 @@ namespace reader {
static
constexpr
char
kFileFormatSeparator
[]
=
"."
;
using
FileReaderCreator
=
std
::
function
<
framework
::
ReaderBase
*
(
const
std
::
string
&
,
const
std
::
vector
<
framework
::
DDim
>
&
)
>
;
using
FileReaderCreator
=
std
::
function
<
framework
::
ReaderBase
*
(
const
std
::
string
&
)
>
;
std
::
unordered_map
<
std
::
string
,
FileReaderCreator
>&
FileReaderRegistry
();
template
<
typename
Reader
>
int
RegisterFileReader
(
const
std
::
string
&
filetype
)
{
FileReaderRegistry
()[
filetype
]
=
[](
const
std
::
string
&
fn
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
{
return
new
Reader
(
fn
,
dims
);
FileReaderRegistry
()[
filetype
]
=
[](
const
std
::
string
&
fn
)
{
return
new
Reader
(
fn
);
};
return
0
;
}
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
framework
::
DDim
>&
dims
);
const
std
::
string
&
file_name
);
extern
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
26ae6111
...
...
@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
);
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
Re
Init
);
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
Re
setAll
);
using
LoDTensorBlockingQueue
=
::
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
26ae6111
...
...
@@ -375,9 +375,6 @@ def open_recordio_file(filename,
if
pass_num
>
1
:
main_prog_var
=
multi_pass
(
reader
=
main_prog_var
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_var
=
parallel
(
reader
=
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
...
...
@@ -529,9 +526,6 @@ def open_files(filenames,
main_prog_reader
=
multi_pass
(
reader
=
main_prog_reader
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_reader
=
parallel
(
reader
=
main_prog_reader
)
return
monkey_patch_reader_methods
(
main_prog_reader
)
...
...
@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num):
'create_multi_pass_reader'
,
reader
,
{
'pass_num'
:
int
(
pass_num
)})
def
parallel
(
reader
):
return
__create_shared_decorated_reader__
(
'create_threaded_reader'
,
reader
,
{})
def
read_file
(
reader
):
"""
Execute the given reader and get data via it.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录