Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c48c586a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c48c586a
编写于
7月 08, 2018
作者:
Y
yuyang18
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use weak_ptr to implement DecoratedReaderChain
上级
2bbe5f77
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
62 addition
and
44 deletion
+62
-44
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+8
-11
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+25
-9
paddle/fluid/framework/reader_test.cc
paddle/fluid/framework/reader_test.cc
+8
-5
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
+2
-2
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+4
-4
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+2
-1
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+2
-1
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+1
-1
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+2
-2
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+1
-1
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+2
-3
paddle/fluid/operators/reader/create_threaded_reader_op.cc
paddle/fluid/operators/reader/create_threaded_reader_op.cc
+2
-1
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+3
-3
未找到文件。
paddle/fluid/framework/reader.cc
浏览文件 @
c48c586a
...
@@ -19,15 +19,11 @@ namespace paddle {
...
@@ -19,15 +19,11 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
ReaderBase
::~
ReaderBase
()
{}
ReaderBase
::~
ReaderBase
()
{}
void
ReaderBase
::
InsertDecoratedReader
(
ReaderBase
*
decorated_reader
)
{
void
ReaderBase
::
InsertDecoratedReader
(
decorated_readers_
.
emplace
(
decorated_reader
);
const
std
::
shared_ptr
<
ReaderBase
>
&
decorated_reader
)
{
}
decorated_readers_
.
emplace_back
(
decorated_reader
);
void
ReaderBase
::
EraseDecoratedReader
(
ReaderBase
*
decorated_reader
)
{
auto
it
=
decorated_readers_
.
find
(
decorated_reader
);
PADDLE_ENFORCE
(
it
!=
decorated_readers_
.
end
(),
"Cannot find the decorated reader to erase"
);
decorated_readers_
.
erase
(
it
);
}
}
std
::
unordered_set
<
ReaderBase
*>
ReaderBase
::
GetEndPoints
()
{
std
::
unordered_set
<
ReaderBase
*>
ReaderBase
::
GetEndPoints
()
{
std
::
unordered_set
<
ReaderBase
*>
result
;
std
::
unordered_set
<
ReaderBase
*>
result
;
std
::
deque
<
ReaderBase
*>
queue
;
std
::
deque
<
ReaderBase
*>
queue
;
...
@@ -38,8 +34,10 @@ std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
...
@@ -38,8 +34,10 @@ std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
if
(
front
->
decorated_readers_
.
empty
())
{
if
(
front
->
decorated_readers_
.
empty
())
{
result
.
emplace
(
front
);
result
.
emplace
(
front
);
}
else
{
}
else
{
for
(
ReaderBase
*
reader
:
front
->
decorated_readers_
)
{
for
(
auto
&
reader
:
front
->
decorated_readers_
)
{
queue
.
emplace_back
(
reader
);
if
(
auto
*
reader_ptr
=
reader
.
lock
().
get
())
{
queue
.
emplace_back
(
reader_ptr
);
}
}
}
}
}
}
}
...
@@ -66,6 +64,5 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
...
@@ -66,6 +64,5 @@ void FileReader::ReadNext(std::vector<LoDTensor> *out) {
}
}
}
}
}
}
DecoratedReader
::~
DecoratedReader
()
{
reader_
->
EraseDecoratedReader
(
this
);
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/reader.h
浏览文件 @
c48c586a
...
@@ -41,24 +41,26 @@ class ReaderBase {
...
@@ -41,24 +41,26 @@ class ReaderBase {
friend
class
DecoratedReader
;
friend
class
DecoratedReader
;
// These methods can be only invoked inside DecoratedReader to record the
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
// decorating chain.
void
InsertDecoratedReader
(
ReaderBase
*
decorated_reader
);
void
InsertDecoratedReader
(
void
EraseDecoratedReader
(
ReaderBase
*
decorated_reader
);
const
std
::
shared_ptr
<
ReaderBase
>&
decorated_reader
);
// A set of which readers that decorated this reader.
// A set of which readers that decorated this reader.
std
::
unordered_set
<
ReaderBase
*
>
decorated_readers_
;
std
::
vector
<
std
::
weak_ptr
<
ReaderBase
>
>
decorated_readers_
;
};
};
class
DecoratedReader
:
public
ReaderBase
{
class
DecoratedReader
:
public
ReaderBase
,
public
std
::
enable_shared_from_this
<
DecoratedReader
>
{
public:
public:
explicit
DecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
explicit
DecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
ReaderBase
(),
reader_
(
reader
)
{
:
ReaderBase
(),
reader_
(
reader
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
InsertDecoratedReader
(
this
);
}
}
~
DecoratedReader
();
void
ReInit
()
override
{
reader_
->
ReInit
();
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
void
RegisterDecorateChain
()
{
reader_
->
InsertDecoratedReader
(
shared_from_this
());
}
protected:
protected:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
};
...
@@ -80,9 +82,14 @@ class FileReader : public ReaderBase {
...
@@ -80,9 +82,14 @@ class FileReader : public ReaderBase {
// making it easier to access different type reader in Variables.
// making it easier to access different type reader in Variables.
class
ReaderHolder
{
class
ReaderHolder
{
public:
public:
void
Reset
(
ReaderBase
*
reader
)
{
reader_
.
reset
(
reader
);
}
template
<
typename
T
>
void
Reset
(
const
std
::
shared_ptr
<
T
>&
reader
)
{
auto
reader_base
=
std
::
dynamic_pointer_cast
<
ReaderBase
>
(
reader
);
PADDLE_ENFORCE_NOT_NULL
(
reader_base
);
reader_
=
reader_base
;
}
std
::
shared_ptr
<
ReaderBase
>
Get
()
const
{
return
reader_
;
}
const
std
::
shared_ptr
<
ReaderBase
>&
Get
()
const
{
return
reader_
;
}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
...
@@ -93,9 +100,18 @@ class ReaderHolder {
...
@@ -93,9 +100,18 @@ class ReaderHolder {
reader_
->
ReInit
();
reader_
->
ReInit
();
}
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
private:
private:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
};
template
<
typename
T
,
typename
...
ARGS
>
inline
std
::
shared_ptr
<
DecoratedReader
>
MakeDecoratedReader
(
ARGS
&&
...
args
)
{
std
::
shared_ptr
<
DecoratedReader
>
reader
(
new
T
(
std
::
forward
<
ARGS
>
(
args
)...));
reader
->
RegisterDecorateChain
();
return
reader
;
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/reader_test.cc
浏览文件 @
c48c586a
...
@@ -32,18 +32,21 @@ class StubRootReader : public paddle::framework::ReaderBase {
...
@@ -32,18 +32,21 @@ class StubRootReader : public paddle::framework::ReaderBase {
TEST
(
READER
,
decorate_chain
)
{
TEST
(
READER
,
decorate_chain
)
{
auto
root
=
std
::
make_shared
<
StubRootReader
>
();
auto
root
=
std
::
make_shared
<
StubRootReader
>
();
auto
end_point1
=
StubDecoratedReader
(
root
);
auto
end_point1
=
auto
end_point2
=
StubDecoratedReader
(
root
);
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
auto
end_point2
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
{
{
auto
endpoints
=
root
->
GetEndPoints
();
auto
endpoints
=
root
->
GetEndPoints
();
ASSERT_EQ
(
endpoints
.
size
(),
2U
);
ASSERT_EQ
(
endpoints
.
size
(),
2U
);
ASSERT_NE
(
endpoints
.
count
(
&
end_point1
),
0
);
ASSERT_NE
(
endpoints
.
count
(
end_point1
.
get
()
),
0
);
ASSERT_NE
(
endpoints
.
count
(
&
end_point2
),
0
);
ASSERT_NE
(
endpoints
.
count
(
end_point2
.
get
()
),
0
);
}
}
{
{
auto
end_point3
=
StubDecoratedReader
(
root
);
auto
end_point3
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
3U
);
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
3U
);
}
}
{
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
2U
);
}
{
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
2U
);
}
...
...
paddle/fluid/operators/reader/create_batch_reader_op.cc
浏览文件 @
c48c586a
...
@@ -46,8 +46,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
...
@@ -46,8 +46,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
out
->
Reset
(
framework
::
MakeDecoratedReader
<
BatchReader
>
(
new
BatchReader
(
underlying_reader
.
Get
()
,
Attr
<
int
>
(
"batch_size"
)));
underlying_reader
,
Attr
<
int
>
(
"batch_size"
)));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
c48c586a
...
@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
...
@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
}
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
out
->
Reset
(
framework
::
MakeDecoratedReader
<
CustomReader
>
(
new
CustomReader
(
underlying_reader
.
Get
()
,
*
sub_block
,
underlying_reader
,
*
sub_block
,
Attr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
),
Attr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
),
Attr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
)));
Attr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
)));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
c48c586a
...
@@ -109,7 +109,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...
@@ -109,7 +109,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
}
}
out
->
Reset
(
new
DoubleBufferReader
(
underlying_reader
.
Get
(),
place
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
DoubleBufferReader
>
(
underlying_reader
,
place
));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
c48c586a
...
@@ -60,7 +60,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
...
@@ -60,7 +60,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
int
pass_num
=
Attr
<
int
>
(
"pass_num"
);
int
pass_num
=
Attr
<
int
>
(
"pass_num"
);
out
->
Reset
(
new
MultiPassReader
(
underlying_reader
.
Get
(),
pass_num
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
MultiPassReader
>
(
underlying_reader
,
pass_num
));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
c48c586a
...
@@ -58,7 +58,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
...
@@ -58,7 +58,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto
*
queue_holder
=
auto
*
queue_holder
=
queue_holder_var
->
template
GetMutable
<
LoDTensorBlockingQueueHolder
>();
queue_holder_var
->
template
GetMutable
<
LoDTensorBlockingQueueHolder
>();
out
->
Reset
(
new
PyReader
(
queue_holder
->
GetQueue
()));
out
->
Reset
(
std
::
make_shared
<
PyReader
>
(
queue_holder
->
GetQueue
()));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
c48c586a
...
@@ -79,8 +79,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
...
@@ -79,8 +79,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RandomDataGenerator
<
T
>
(
shapes
,
Attr
<
float
>
(
"low"
),
out
->
Reset
(
std
::
make_shared
<
RandomDataGenerator
<
T
>>
(
Attr
<
float
>
(
"high"
)));
shapes
,
Attr
<
float
>
(
"low"
),
Attr
<
float
>
(
"high"
)));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
c48c586a
...
@@ -70,7 +70,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
...
@@ -70,7 +70,7 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RecordIOFileReader
<
true
>
(
out
->
Reset
(
std
::
make_shared
<
RecordIOFileReader
<
true
>
>
(
filename
,
RestoreShapes
(
shape_concat
,
ranks
)));
filename
,
RestoreShapes
(
shape_concat
,
ranks
)));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
c48c586a
...
@@ -86,9 +86,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
...
@@ -86,9 +86,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
}
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
out
->
Reset
(
framework
::
MakeDecoratedReader
<
ShuffleReader
>
(
new
ShuffleReader
(
underlying_reader
.
Get
(),
underlying_reader
,
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_threaded_reader_op.cc
浏览文件 @
c48c586a
...
@@ -49,7 +49,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
...
@@ -49,7 +49,8 @@ class CreateThreadedReaderOp : public framework::OperatorBase {
}
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
ThreadedReader
(
underlying_reader
.
Get
()));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
ThreadedReader
>
(
underlying_reader
));
}
}
};
};
...
...
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
c48c586a
...
@@ -180,9 +180,9 @@ class OpenFilesOp : public framework::OperatorBase {
...
@@ -180,9 +180,9 @@ class OpenFilesOp : public framework::OperatorBase {
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
MultiFileReader
(
file_names
,
out
->
Reset
(
std
::
make_shared
<
MultiFileReader
>
(
RestoreShapes
(
shape_concat
,
ranks
)
,
file_names
,
RestoreShapes
(
shape_concat
,
ranks
),
thread_num
,
thread_num
,
buffer_size
));
buffer_size
));
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录