Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
26ae6111
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
26ae6111
编写于
7月 10, 2018
作者:
F
fengjiayi
提交者:
GitHub
7月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12051 from JiayiFeng/dev_reader_ResetAll
[WIP] Dev reader reset all
上级
10fbb831
d55919c6
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
283 addition
and
219 deletion
+283
-219
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-0
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+46
-14
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+76
-20
paddle/fluid/framework/reader_test.cc
paddle/fluid/framework/reader_test.cc
+52
-0
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+0
-1
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
+18
-6
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+6
-6
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+16
-12
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+10
-10
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+13
-8
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+5
-7
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+5
-16
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+14
-4
paddle/fluid/operators/reader/create_threaded_reader_op.cc
paddle/fluid/operators/reader/create_threaded_reader_op.cc
+0
-79
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+13
-16
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+2
-2
paddle/fluid/operators/reader/reader_op_registry.h
paddle/fluid/operators/reader/reader_op_registry.h
+5
-6
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+0
-11
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
26ae6111
...
...
@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test
(
lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor
)
cc_library
(
reader SRCS reader.cc DEPS lod_tensor ddim
)
cc_test
(
reader_test SRCS reader_test.cc DEPS reader
)
cc_test
(
variable_test SRCS variable_test.cc
)
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
26ae6111
...
...
@@ -13,29 +13,61 @@
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <deque>
namespace
paddle
{
namespace
framework
{
ReaderBase
::~
ReaderBase
()
{}
FileReader
::
FileReader
(
const
std
::
vector
<
DDim
>
&
dims
)
:
dims_
(
dims
)
{}
void
FileReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
void
ReaderBase
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
PADDLE_ENFORCE_EQ
(
status_
,
ReaderStatus
::
kRunning
);
ReadNextImpl
(
out
);
if
(
out
->
empty
())
{
return
;
}
}
PADDLE_ENFORCE_EQ
(
out
->
size
(),
dims_
.
size
());
for
(
size_t
i
=
0
;
i
<
dims_
.
size
();
++
i
)
{
auto
&
actual
=
(
*
out
)[
i
].
dims
();
auto
&
expect
=
dims_
[
i
];
void
ReaderBase
::
InsertDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>
&
decorated_reader
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mu_
);
decorated_readers_
.
emplace_back
(
decorated_reader
);
}
PADDLE_ENFORCE_EQ
(
actual
.
size
(),
expect
.
size
());
for
(
int
j
=
0
;
j
<
actual
.
size
();
++
j
)
{
// PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
std
::
unordered_set
<
ReaderBase
*>
ReaderBase
::
GetEndPoints
()
{
std
::
unordered_set
<
ReaderBase
*>
result
;
std
::
deque
<
ReaderBase
*>
queue
;
queue
.
emplace_back
(
this
);
while
(
!
queue
.
empty
())
{
// BFS search
auto
*
front
=
queue
.
front
();
queue
.
pop_front
();
if
(
front
->
decorated_readers_
.
empty
())
{
result
.
emplace
(
front
);
}
else
{
for
(
auto
&
reader
:
front
->
decorated_readers_
)
{
if
(
auto
*
reader_ptr
=
reader
.
lock
().
get
())
{
queue
.
emplace_back
(
reader_ptr
);
}
}
}
}
return
result
;
}
void
ReaderBase
::
Shutdown
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
if
(
status_
!=
ReaderStatus
::
kStopped
)
{
ShutdownImpl
();
status_
=
ReaderStatus
::
kStopped
;
}
}
void
ReaderBase
::
Start
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
if
(
status_
!=
ReaderStatus
::
kRunning
)
{
StartImpl
();
status_
=
ReaderStatus
::
kRunning
;
}
}
ReaderBase
::~
ReaderBase
()
{
Shutdown
();
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/reader.h
浏览文件 @
26ae6111
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <memory>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
...
...
@@ -24,61 +25,116 @@
namespace
paddle
{
namespace
framework
{
enum
ReaderStatus
{
kRunning
,
kStopped
};
class
ReaderBase
{
public:
virtual
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
);
void
Shutdown
();
virtual
void
ReInit
()
=
0
;
void
Start
();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std
::
unordered_set
<
ReaderBase
*>
GetEndPoints
();
virtual
~
ReaderBase
();
protected:
virtual
void
ReadNextImpl
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
virtual
void
ShutdownImpl
()
{}
virtual
void
StartImpl
()
{}
ReaderStatus
status_
{
kRunning
};
mutable
std
::
mutex
mu_
;
private:
friend
class
DecoratedReader
;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void
InsertDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
decorated_reader
);
// A set of which readers that decorated this reader.
std
::
vector
<
std
::
weak_ptr
<
ReaderBase
>>
decorated_readers_
;
};
class
DecoratedReader
:
public
ReaderBase
{
class
DecoratedReader
:
public
ReaderBase
,
public
std
::
enable_shared_from_this
<
DecoratedReader
>
{
public:
explicit
DecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
ReaderBase
(),
reader_
(
reader
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
void
RegisterDecorateChain
()
{
reader_
->
InsertDecoratedReader
(
shared_from_this
());
}
protected:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
class
FileReader
:
public
ReaderBase
{
public:
explicit
FileReader
(
const
std
::
vector
<
DDim
>&
dims
);
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
override
;
void
ShutdownImpl
()
override
{
reader_
->
Shutdown
();
}
protected:
virtual
void
ReadNextImpl
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
void
StartImpl
()
override
{
reader_
->
Start
();
}
private:
std
::
vector
<
DDim
>
dims_
;
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
// FileReader is just a conceptual class.
class
FileReader
:
public
ReaderBase
{};
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class
ReaderHolder
{
public:
void
Reset
(
ReaderBase
*
reader
)
{
reader_
.
reset
(
reader
);
}
template
<
typename
T
>
void
Reset
(
const
std
::
shared_ptr
<
T
>&
reader
)
{
auto
reader_base
=
std
::
dynamic_pointer_cast
<
ReaderBase
>
(
reader
);
PADDLE_ENFORCE_NOT_NULL
(
reader_base
);
reader_
=
reader_base
;
}
std
::
shared_ptr
<
ReaderBase
>
Get
()
const
{
return
reader_
;
}
const
std
::
shared_ptr
<
ReaderBase
>&
Get
()
const
{
return
reader_
;
}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
ReadNext
(
out
);
}
void
ReInit
()
{
void
ResetAll
()
{
auto
end_readers
=
reader_
->
GetEndPoints
();
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Shutdown
();
}
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Start
();
}
}
void
Shutdown
()
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
Shutdown
();
}
void
Start
()
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
ReIni
t
();
reader_
->
Star
t
();
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
private:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
template
<
typename
T
,
typename
...
ARGS
>
inline
std
::
shared_ptr
<
DecoratedReader
>
MakeDecoratedReader
(
ARGS
&&
...
args
)
{
std
::
shared_ptr
<
DecoratedReader
>
reader
(
new
T
(
std
::
forward
<
ARGS
>
(
args
)...));
reader
->
RegisterDecorateChain
();
return
reader
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/reader_test.cc
0 → 100644
浏览文件 @
26ae6111
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class
StubDecoratedReader
:
public
paddle
::
framework
::
DecoratedReader
{
public:
explicit
StubDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>
&
reader
)
:
DecoratedReader
(
reader
)
{}
void
ReadNextImpl
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
*
out
)
override
{}
};
class
StubRootReader
:
public
paddle
::
framework
::
ReaderBase
{
public:
void
ReadNextImpl
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
*
out
)
override
{}
};
TEST
(
READER
,
decorate_chain
)
{
auto
root
=
std
::
make_shared
<
StubRootReader
>
();
auto
end_point1
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
auto
end_point2
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
{
auto
endpoints
=
root
->
GetEndPoints
();
ASSERT_EQ
(
endpoints
.
size
(),
2U
);
ASSERT_NE
(
endpoints
.
count
(
end_point1
.
get
()),
0
);
ASSERT_NE
(
endpoints
.
count
(
end_point2
.
get
()),
0
);
}
{
auto
end_point3
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
3U
);
}
{
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
2U
);
}
}
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
26ae6111
...
...
@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library
(
create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc
)
reader_library
(
create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc
)
reader_library
(
create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc
)
reader_library
(
create_threaded_reader_op SRCS create_threaded_reader_op.cc
)
reader_library
(
create_custom_reader_op SRCS create_custom_reader_op.cc
)
reader_library
(
create_py_reader_op SRCS create_py_reader_op.cc
)
...
...
paddle/fluid/operators/reader/create_batch_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -20,15 +20,19 @@ namespace reader {
class
BatchReader
:
public
framework
::
DecoratedReader
{
public:
BatchReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
batch_size
)
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
)
{
BatchReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
batch_size
,
bool
discard_leftover
)
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
),
discard_leftover_
(
discard_leftover
)
{
buffer_
.
reserve
(
batch_size_
);
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
int
batch_size_
;
bool
discard_leftover_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
};
...
...
@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
BatchReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"batch_size"
)));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
BatchReader
>
(
underlying_reader
,
Attr
<
int
>
(
"batch_size"
),
Attr
<
bool
>
(
"discard_leftover"
)));
}
};
...
...
@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
AddAttr
<
int
>
(
"batch_size"
,
"How many instances the batch reader yields each time."
)
.
GreaterThan
(
0
);
AddAttr
<
bool
>
(
"discard_leftover"
,
"If true, the leftover instances that are not enough for a "
"new batch will be discarded."
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
CreateBatchReader Operator
...
...
@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void
BatchReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
BatchReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
buffer_
.
clear
();
buffer_
.
reserve
(
batch_size_
);
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
...
...
@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
break
;
}
}
if
(
discard_leftover_
&&
buffer_
.
size
()
<
batch_size_
)
{
buffer_
.
clear
();
}
// Concat instances
out
->
clear
();
if
(
buffer_
.
empty
())
{
...
...
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader {
source_var_names_
(
source_var_names
),
sink_var_names_
(
sink_var_names
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
const
framework
::
ProgramDesc
program_
;
...
...
@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
CustomReader
(
underlying_reader
.
Get
()
,
*
sub_block
,
Attr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
),
Attr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
)));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
CustomReader
>
(
underlying_reader
,
*
sub_block
,
Attr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
),
Attr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
)));
}
};
...
...
@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
}
};
void
CustomReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
CustomReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
out
->
clear
();
std
::
vector
<
framework
::
LoDTensor
>
underlying_outs
;
reader_
->
ReadNext
(
&
underlying_outs
);
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
static
constexpr
size_t
kCacheSize
=
5
;
static
constexpr
size_t
kCacheSize
=
3
;
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kChacheSize - 2
static
constexpr
size_t
kChannelSize
=
3
;
// kCacheSize - 2
static
constexpr
size_t
kChannelSize
=
1
;
// kCacheSize - 2
class
DoubleBufferReader
:
public
framework
::
DecoratedReader
{
public:
...
...
@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
private:
void
ShutdownImpl
()
override
{
EndPrefetcher
();
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
StartPrefetcher
();
}
void
StartPrefetcher
()
{
channel_
=
new
reader
::
BlockingQueue
<
size_t
>
(
kChannelSize
);
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
...
...
@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
}
out
->
Reset
(
new
DoubleBufferReader
(
underlying_reader
.
Get
(),
place
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
DoubleBufferReader
>
(
underlying_reader
,
place
));
}
};
...
...
@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
DoubleBufferReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
size_t
cached_tensor_id
;
if
(
channel_
->
Receive
(
&
cached_tensor_id
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
...
...
@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
}
}
void
DoubleBufferReader
::
ReInit
()
{
reader_
->
ReInit
();
EndPrefetcher
();
StartPrefetcher
();
}
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
size_t
cached_tensor_id
=
0
;
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader {
MultiPassReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
pass_num
)
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
reader_
->
ReadNext
(
out
);
if
(
out
->
empty
())
{
if
(
out
->
empty
()
&&
pass_count_
<
pass_num_
-
1
)
{
reader_
->
Shutdown
();
reader_
->
Start
();
reader_
->
ReadNext
(
out
);
++
pass_count_
;
if
(
pass_count_
<
pass_num_
)
{
reader_
->
ReInit
();
reader_
->
ReadNext
(
out
);
}
}
}
void
ReInit
()
override
{
private:
void
StartImpl
()
override
{
pass_count_
=
0
;
reader_
->
ReIni
t
();
reader_
->
Star
t
();
}
private:
int
pass_num_
;
mutable
int
pass_count_
;
};
...
...
@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
int
pass_num
=
Attr
<
int
>
(
"pass_num"
);
out
->
Reset
(
new
MultiPassReader
(
underlying_reader
.
Get
(),
pass_num
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
MultiPassReader
>
(
underlying_reader
,
pass_num
));
}
};
...
...
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -19,22 +19,27 @@ namespace paddle {
namespace
operators
{
namespace
reader
{
class
PyReader
:
public
framework
::
ReaderBase
{
class
PyReader
:
public
framework
::
FileReader
{
public:
explicit
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
)
{
explicit
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
)
:
framework
::
FileReader
()
{
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
queue_
=
queue
;
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
bool
success
;
*
out
=
queue_
->
Pop
(
&
success
);
if
(
!
success
)
out
->
clear
();
}
void
ReInit
()
override
{}
private:
void
ShutdownImpl
()
override
{
/* TODO */
}
void
StartImpl
()
override
{
/* TODO */
}
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue_
;
};
...
...
@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase {
const
std
::
string
&
queue_name
=
Input
(
"blocking_queue"
);
auto
*
queue_holder_var
=
scope
.
FindVar
(
queue_name
);
PADDLE_ENFORCE
(
queue_holder_var
!=
nullptr
,
PADDLE_ENFORCE
_NOT_NULL
(
queue_holder_var
,
"No LoDTensorBlockingQueueHolder variable with name %s found"
,
queue_name
);
auto
*
queue_holder
=
queue_holder_var
->
template
GetMutable
<
LoDTensorBlockingQueueHolder
>();
out
->
Reset
(
new
PyReader
(
queue_holder
->
GetQueue
()));
out
->
Reset
(
std
::
make_shared
<
PyReader
>
(
queue_holder
->
GetQueue
()));
}
};
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
26ae6111
...
...
@@ -19,11 +19,11 @@ namespace operators {
namespace
reader
{
template
<
typename
T
>
class
RandomDataGenerator
:
public
framework
::
ReaderBase
{
class
RandomDataGenerator
:
public
framework
::
FileReader
{
public:
RandomDataGenerator
(
const
std
::
vector
<
framework
::
DDim
>&
shapes
,
float
low
,
float
high
)
:
framework
::
ReaderBase
(),
low_
(
low
),
high_
(
high
),
shapes_
(
shapes
)
{
:
framework
::
FileReader
(),
low_
(
low
),
high_
(
high
),
shapes_
(
shapes
)
{
PADDLE_ENFORCE_LE
(
low
,
high
,
"'low' shouldn't be greater than 'high'.(%f vs %f)"
,
low
,
high
);
...
...
@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase {
dist_
=
std
::
uniform_real_distribution
<
float
>
(
low_
,
high_
);
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
out
->
clear
();
out
->
reserve
(
shapes_
.
size
());
for
(
const
framework
::
DDim
&
shape
:
shapes_
)
{
...
...
@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase {
}
}
void
ReInit
()
override
{
return
;
}
private:
float
low_
;
float
high_
;
...
...
@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RandomDataGenerator
<
T
>
(
shapes
,
Attr
<
float
>
(
"low"
),
Attr
<
float
>
(
"high"
)));
out
->
Reset
(
std
::
make_shared
<
RandomDataGenerator
<
T
>>
(
shapes
,
Attr
<
float
>
(
"low"
),
Attr
<
float
>
(
"high"
)));
}
};
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -21,10 +21,8 @@ namespace reader {
template
<
bool
ThreadSafe
>
class
RecordIOFileReader
:
public
framework
::
FileReader
{
public:
explicit
RecordIOFileReader
(
const
std
::
string
&
filename
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
:
FileReader
(
dims
),
scanner_
(
filename
),
explicit
RecordIOFileReader
(
const
std
::
string
&
filename
)
:
scanner_
(
filename
),
dev_ctx_
(
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()))
{
if
(
ThreadSafe
)
{
...
...
@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader {
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
protected:
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
ThreadSafe
)
{
...
...
@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader {
}
}
void
StartImpl
()
override
{
scanner_
.
Reset
();
}
private:
std
::
unique_ptr
<
std
::
mutex
>
mutex_
;
recordio
::
Scanner
scanner_
;
...
...
@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RecordIOFileReader
<
true
>
(
filename
,
RestoreShapes
(
shape_concat
,
ranks
)));
out
->
Reset
(
std
::
make_shared
<
RecordIOFileReader
<
true
>>
(
filename
));
}
};
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
26ae6111
...
...
@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader {
ReloadBuffer
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
out
->
clear
();
if
(
iteration_pos_
>=
buffer_
.
size
())
{
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
...
...
@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
}
private:
void
ShutdownImpl
()
override
{
buffer_
.
clear
();
iteration_pos_
=
0
;
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
ReloadBuffer
();
}
void
ReloadBuffer
()
{
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
...
...
@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
ShuffleReader
(
underlying_reader
.
Get
(),
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
ShuffleReader
>
(
underlying_reader
,
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
}
};
...
...
paddle/fluid/operators/reader/create_threaded_reader_op.cc
已删除
100644 → 0
浏览文件 @
10fbb831
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
public:
explicit
ThreadedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
DecoratedReader
(
reader
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
reader_
->
ReadNext
(
out
);
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
private:
std
::
mutex
mutex_
;
};
class
CreateThreadedReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
*
out
=
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)))
.
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
ThreadedReader
(
underlying_reader
.
Get
()));
}
};
class
CreateThreadedReaderOpMaker
:
public
DecoratedReaderMakerBase
{
protected:
void
Apply
()
override
{
AddComment
(
R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC"
);
}
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
reader
=
paddle
::
operators
::
reader
;
REGISTER_DECORATED_READER_OPERATOR
(
create_threaded_reader
,
reader
::
CreateThreadedReaderOp
,
reader
::
CreateThreadedReaderOpMaker
);
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
26ae6111
...
...
@@ -23,24 +23,26 @@ namespace reader {
class
MultiFileReader
:
public
framework
::
ReaderBase
{
public:
MultiFileReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
const
std
::
vector
<
framework
::
DDim
>&
dims
,
size_t
thread_num
,
MultiFileReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
size_t
thread_num
,
size_t
buffer_size
)
:
buffer_size_
(
buffer_size
)
{
readers_
.
reserve
(
file_names
.
size
());
for
(
const
std
::
string
&
f_name
:
file_names
)
{
readers_
.
emplace_back
(
CreateReaderByFileName
(
f_name
,
dims
));
readers_
.
emplace_back
(
CreateReaderByFileName
(
f_name
));
}
prefetchers_
.
resize
(
thread_num
);
StartNewScheduler
();
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
~
MultiFileReader
()
{
EndScheduler
();
}
private:
void
ShutdownImpl
()
override
{
EndScheduler
();
}
void
StartImpl
()
override
{
StartNewScheduler
();
}
void
StartNewScheduler
();
void
EndScheduler
();
void
ScheduleThreadFunc
();
...
...
@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase {
reader
::
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
};
void
MultiFileReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
MultiFileReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
buffer_
->
Receive
(
out
))
{
out
->
clear
();
}
}
void
MultiFileReader
::
ReInit
()
{
EndScheduler
();
StartNewScheduler
();
}
void
MultiFileReader
::
StartNewScheduler
()
{
size_t
thread_num
=
prefetchers_
.
size
();
waiting_reader_idx_
=
new
reader
::
BlockingQueue
<
size_t
>
(
readers_
.
size
());
...
...
@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() {
}
}
}
// If users invoke
ReInit
() when scheduler is running, it will close the
// If users invoke
Shutdown
() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for
(
auto
&
p
:
prefetchers_
)
{
...
...
@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
reader
->
ReInit
();
reader
->
Shutdown
();
reader
->
Start
();
break
;
}
try
{
...
...
@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
MultiFileReader
(
file_names
,
RestoreShapes
(
shape_concat
,
ranks
),
thread_num
,
buffer_size
));
out
->
Reset
(
std
::
make_shared
<
MultiFileReader
>
(
file_names
,
thread_num
,
buffer_size
));
}
};
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
26ae6111
...
...
@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
}
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
{
const
std
::
string
&
file_name
)
{
size_t
separator_pos
=
file_name
.
find_last_of
(
kFileFormatSeparator
);
PADDLE_ENFORCE_NE
(
separator_pos
,
std
::
string
::
npos
,
"File name illegal! A legal file name should be like: "
...
...
@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
auto
itor
=
FileReaderRegistry
().
find
(
filetype
);
PADDLE_ENFORCE
(
itor
!=
FileReaderRegistry
().
end
(),
"No file reader registered for '%s' format."
,
filetype
);
framework
::
ReaderBase
*
reader
=
(
itor
->
second
)(
file_name
,
dims
);
framework
::
ReaderBase
*
reader
=
(
itor
->
second
)(
file_name
);
return
std
::
unique_ptr
<
framework
::
ReaderBase
>
(
reader
);
}
...
...
paddle/fluid/operators/reader/reader_op_registry.h
浏览文件 @
26ae6111
...
...
@@ -25,22 +25,21 @@ namespace reader {
static
constexpr
char
kFileFormatSeparator
[]
=
"."
;
using
FileReaderCreator
=
std
::
function
<
framework
::
ReaderBase
*
(
const
std
::
string
&
,
const
std
::
vector
<
framework
::
DDim
>
&
)
>
;
using
FileReaderCreator
=
std
::
function
<
framework
::
ReaderBase
*
(
const
std
::
string
&
)
>
;
std
::
unordered_map
<
std
::
string
,
FileReaderCreator
>&
FileReaderRegistry
();
template
<
typename
Reader
>
int
RegisterFileReader
(
const
std
::
string
&
filetype
)
{
FileReaderRegistry
()[
filetype
]
=
[](
const
std
::
string
&
fn
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
{
return
new
Reader
(
fn
,
dims
);
FileReaderRegistry
()[
filetype
]
=
[](
const
std
::
string
&
fn
)
{
return
new
Reader
(
fn
);
};
return
0
;
}
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
framework
::
DDim
>&
dims
);
const
std
::
string
&
file_name
);
extern
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
26ae6111
...
...
@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
);
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
Re
Init
);
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
Re
setAll
);
using
LoDTensorBlockingQueue
=
::
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
26ae6111
...
...
@@ -375,9 +375,6 @@ def open_recordio_file(filename,
if
pass_num
>
1
:
main_prog_var
=
multi_pass
(
reader
=
main_prog_var
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_var
=
parallel
(
reader
=
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
...
...
@@ -529,9 +526,6 @@ def open_files(filenames,
main_prog_reader
=
multi_pass
(
reader
=
main_prog_reader
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_reader
=
parallel
(
reader
=
main_prog_reader
)
return
monkey_patch_reader_methods
(
main_prog_reader
)
...
...
@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num):
'create_multi_pass_reader'
,
reader
,
{
'pass_num'
:
int
(
pass_num
)})
def
parallel
(
reader
):
return
__create_shared_decorated_reader__
(
'create_threaded_reader'
,
reader
,
{})
def
read_file
(
reader
):
"""
Execute the given reader and get data via it.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录