Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6d6f49cd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d6f49cd
编写于
7月 09, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'yuyang/feature/decorated_reader_chain' into dev_reader_ResetAll
上级
611716e9
62c1133f
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
138 addition
and
19 deletion
+138
-19
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-0
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+28
-0
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+36
-3
paddle/fluid/framework/reader_test.cc
paddle/fluid/framework/reader_test.cc
+53
-0
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
+3
-2
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+4
-4
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+2
-1
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+2
-1
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+1
-1
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+2
-2
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+2
-1
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+2
-3
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+2
-1
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
6d6f49cd
...
...
@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test
(
lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor
)
cc_library
(
reader SRCS reader.cc DEPS lod_tensor ddim
)
cc_test
(
reader_test SRCS reader_test.cc DEPS reader
)
cc_test
(
variable_test SRCS variable_test.cc
)
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
6d6f49cd
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <deque>
namespace
paddle
{
namespace
framework
{
...
...
@@ -23,6 +24,33 @@ void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl
(
out
);
}
void
ReaderBase
::
InsertDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>
&
decorated_reader
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mu_
));
decorated_readers_
.
emplace_back
(
decorated_reader
);
}
std
::
unordered_set
<
ReaderBase
*>
ReaderBase
::
GetEndPoints
()
{
std
::
unordered_set
<
ReaderBase
*>
result
;
std
::
deque
<
ReaderBase
*>
queue
;
queue
.
emplace_back
(
this
);
while
(
!
queue
.
empty
())
{
// BFS search
auto
*
front
=
queue
.
front
();
queue
.
pop_front
();
if
(
front
->
decorated_readers_
.
empty
())
{
result
.
emplace
(
front
);
}
else
{
for
(
auto
&
reader
:
front
->
decorated_readers_
)
{
if
(
auto
*
reader_ptr
=
reader
.
lock
().
get
())
{
queue
.
emplace_back
(
reader_ptr
);
}
}
}
}
return
result
;
}
void
ReaderBase
::
Shutdown
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
if
(
status_
!=
ReaderStatus
::
kStopped
)
{
...
...
paddle/fluid/framework/reader.h
浏览文件 @
6d6f49cd
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <memory>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
...
...
@@ -34,6 +35,10 @@ class ReaderBase {
void
Start
();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std
::
unordered_set
<
ReaderBase
*>
GetEndPoints
();
virtual
~
ReaderBase
();
protected:
...
...
@@ -46,15 +51,29 @@ class ReaderBase {
ReaderStatus
status_
{
kRunning
};
mutable
std
::
mutex
mu_
;
private:
friend
class
DecoratedReader
;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void
InsertDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
decorated_reader
);
// A set of which readers that decorated this reader.
std
::
vector
<
std
::
weak_ptr
<
ReaderBase
>>
decorated_readers_
;
};
class
DecoratedReader
:
public
ReaderBase
{
class
DecoratedReader
:
public
ReaderBase
,
public
std
::
enable_shared_from_this
<
DecoratedReader
>
{
public:
explicit
DecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
ReaderBase
(),
reader_
(
reader
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
}
void
RegisterDecorateChain
()
{
reader_
->
InsertDecoratedReader
(
shared_from_this
());
}
protected:
void
ShutdownImpl
()
override
{
reader_
->
Shutdown
();
}
...
...
@@ -70,9 +89,14 @@ class FileReader : public ReaderBase {};
// making it easier to access different type reader in Variables.
class
ReaderHolder
{
public:
void
Reset
(
ReaderBase
*
reader
)
{
reader_
.
reset
(
reader
);
}
template
<
typename
T
>
void
Reset
(
const
std
::
shared_ptr
<
T
>&
reader
)
{
auto
reader_base
=
std
::
dynamic_pointer_cast
<
ReaderBase
>
(
reader
);
PADDLE_ENFORCE_NOT_NULL
(
reader_base
);
reader_
=
reader_base
;
}
std
::
shared_ptr
<
ReaderBase
>
Get
()
const
{
return
reader_
;
}
const
std
::
shared_ptr
<
ReaderBase
>&
Get
()
const
{
return
reader_
;
}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
...
...
@@ -93,9 +117,18 @@ class ReaderHolder {
reader_
->
Start
();
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
private:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
template
<
typename
T
,
typename
...
ARGS
>
inline
std
::
shared_ptr
<
DecoratedReader
>
MakeDecoratedReader
(
ARGS
&&
...
args
)
{
std
::
shared_ptr
<
DecoratedReader
>
reader
(
new
T
(
std
::
forward
<
ARGS
>
(
args
)...));
reader
->
RegisterDecorateChain
();
return
reader
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/reader_test.cc
0 → 100644
浏览文件 @
6d6f49cd
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class
StubDecoratedReader
:
public
paddle
::
framework
::
DecoratedReader
{
public:
explicit
StubDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>
&
reader
)
:
DecoratedReader
(
reader
)
{}
void
ReadNext
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
*
out
)
override
{}
};
class
StubRootReader
:
public
paddle
::
framework
::
ReaderBase
{
public:
void
ReadNext
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
*
out
)
override
{}
void
ReInit
()
override
{}
};
TEST
(
READER
,
decorate_chain
)
{
auto
root
=
std
::
make_shared
<
StubRootReader
>
();
auto
end_point1
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
auto
end_point2
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
{
auto
endpoints
=
root
->
GetEndPoints
();
ASSERT_EQ
(
endpoints
.
size
(),
2U
);
ASSERT_NE
(
endpoints
.
count
(
end_point1
.
get
()),
0
);
ASSERT_NE
(
endpoints
.
count
(
end_point2
.
get
()),
0
);
}
{
auto
end_point3
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
3U
);
}
{
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
2U
);
}
}
paddle/fluid/operators/reader/create_batch_reader_op.cc
浏览文件 @
6d6f49cd
...
...
@@ -50,7 +50,8 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
BatchReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"batch_size"
),
out
->
Reset
(
framework
::
MakeDecoratedReader
<
BatchReader
>
(
underlying_reader
,
Attr
<
int
>
(
"batch_size"
),
Attr
<
bool
>
(
"discard_leftover"
)));
}
};
...
...
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
6d6f49cd
...
...
@@ -60,8 +60,8 @@ class CreateCustomReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
CustomReader
(
underlying_reader
.
Get
()
,
*
sub_block
,
out
->
Reset
(
framework
::
MakeDecoratedReader
<
CustomReader
>
(
underlying_reader
,
*
sub_block
,
Attr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
),
Attr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
)));
}
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
6d6f49cd
...
...
@@ -118,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
}
out
->
Reset
(
new
DoubleBufferReader
(
underlying_reader
.
Get
(),
place
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
DoubleBufferReader
>
(
underlying_reader
,
place
));
}
};
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
6d6f49cd
...
...
@@ -59,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
int
pass_num
=
Attr
<
int
>
(
"pass_num"
);
out
->
Reset
(
new
MultiPassReader
(
underlying_reader
.
Get
(),
pass_num
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
MultiPassReader
>
(
underlying_reader
,
pass_num
));
}
};
...
...
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
6d6f49cd
...
...
@@ -63,7 +63,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto
*
queue_holder
=
queue_holder_var
->
template
GetMutable
<
LoDTensorBlockingQueueHolder
>();
out
->
Reset
(
new
PyReader
(
queue_holder
->
GetQueue
()));
out
->
Reset
(
std
::
make_shared
<
PyReader
>
(
queue_holder
->
GetQueue
()));
}
};
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
6d6f49cd
...
...
@@ -77,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RandomDataGenerator
<
T
>
(
shapes
,
Attr
<
float
>
(
"low"
),
Attr
<
float
>
(
"high"
)));
out
->
Reset
(
std
::
make_shared
<
RandomDataGenerator
<
T
>>
(
shapes
,
Attr
<
float
>
(
"low"
),
Attr
<
float
>
(
"high"
)));
}
};
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
6d6f49cd
...
...
@@ -59,7 +59,8 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RecordIOFileReader
<
true
>
(
filename
));
out
->
Reset
(
std
::
make_shared
<
RecordIOFileReader
<
true
>>
(
filename
));
}
};
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
6d6f49cd
...
...
@@ -97,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
ShuffleReader
(
underlying_reader
.
Get
(),
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
ShuffleReader
>
(
underlying_reader
,
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
}
};
...
...
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
6d6f49cd
...
...
@@ -178,7 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
MultiFileReader
(
file_names
,
thread_num
,
buffer_size
));
out
->
Reset
(
std
::
make_shared
<
MultiFileReader
>
(
file_names
,
thread_num
,
buffer_size
));
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录