Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7d6afee5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7d6afee5
编写于
7月 10, 2018
作者:
Y
yuyang18
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into feature/exception_safe_pe
上级
0a445da6
26ae6111
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
283 addition
and
219 deletion
+283
-219
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-0
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+46
-14
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+76
-20
paddle/fluid/framework/reader_test.cc
paddle/fluid/framework/reader_test.cc
+52
-0
paddle/fluid/operators/reader/CMakeLists.txt
paddle/fluid/operators/reader/CMakeLists.txt
+0
-1
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
+18
-6
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+6
-6
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+16
-12
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
+10
-10
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+13
-8
paddle/fluid/operators/reader/create_random_data_generator_op.cc
...fluid/operators/reader/create_random_data_generator_op.cc
+5
-7
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
.../fluid/operators/reader/create_recordio_file_reader_op.cc
+5
-16
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+14
-4
paddle/fluid/operators/reader/create_threaded_reader_op.cc
paddle/fluid/operators/reader/create_threaded_reader_op.cc
+0
-79
paddle/fluid/operators/reader/open_files_op.cc
paddle/fluid/operators/reader/open_files_op.cc
+13
-16
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+2
-2
paddle/fluid/operators/reader/reader_op_registry.h
paddle/fluid/operators/reader/reader_op_registry.h
+5
-6
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-1
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+0
-11
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
7d6afee5
...
@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
...
@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test
(
lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor
)
nv_test
(
lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor
)
cc_library
(
reader SRCS reader.cc DEPS lod_tensor ddim
)
cc_library
(
reader SRCS reader.cc DEPS lod_tensor ddim
)
cc_test
(
reader_test SRCS reader_test.cc DEPS reader
)
cc_test
(
variable_test SRCS variable_test.cc
)
cc_test
(
variable_test SRCS variable_test.cc
)
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
7d6afee5
...
@@ -13,29 +13,61 @@
...
@@ -13,29 +13,61 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include <deque>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
ReaderBase
::~
ReaderBase
()
{}
FileReader
::
FileReader
(
const
std
::
vector
<
DDim
>
&
dims
)
:
dims_
(
dims
)
{}
void
ReaderBase
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
void
FileReader
::
ReadNext
(
std
::
vector
<
LoDTensor
>
*
out
)
{
PADDLE_ENFORCE_EQ
(
status_
,
ReaderStatus
::
kRunning
);
ReadNextImpl
(
out
);
ReadNextImpl
(
out
);
if
(
out
->
empty
())
{
}
return
;
}
PADDLE_ENFORCE_EQ
(
out
->
size
(),
dims_
.
size
());
void
ReaderBase
::
InsertDecoratedReader
(
for
(
size_t
i
=
0
;
i
<
dims_
.
size
();
++
i
)
{
const
std
::
shared_ptr
<
ReaderBase
>
&
decorated_reader
)
{
auto
&
actual
=
(
*
out
)[
i
].
dims
();
std
::
lock_guard
<
std
::
mutex
>
guard
(
mu_
);
auto
&
expect
=
dims_
[
i
];
decorated_readers_
.
emplace_back
(
decorated_reader
);
}
PADDLE_ENFORCE_EQ
(
actual
.
size
(),
expect
.
size
());
std
::
unordered_set
<
ReaderBase
*>
ReaderBase
::
GetEndPoints
()
{
for
(
int
j
=
0
;
j
<
actual
.
size
();
++
j
)
{
std
::
unordered_set
<
ReaderBase
*>
result
;
// PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
std
::
deque
<
ReaderBase
*>
queue
;
queue
.
emplace_back
(
this
);
while
(
!
queue
.
empty
())
{
// BFS search
auto
*
front
=
queue
.
front
();
queue
.
pop_front
();
if
(
front
->
decorated_readers_
.
empty
())
{
result
.
emplace
(
front
);
}
else
{
for
(
auto
&
reader
:
front
->
decorated_readers_
)
{
if
(
auto
*
reader_ptr
=
reader
.
lock
().
get
())
{
queue
.
emplace_back
(
reader_ptr
);
}
}
}
}
}
}
return
result
;
}
}
void
ReaderBase
::
Shutdown
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
if
(
status_
!=
ReaderStatus
::
kStopped
)
{
ShutdownImpl
();
status_
=
ReaderStatus
::
kStopped
;
}
}
void
ReaderBase
::
Start
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mu_
);
if
(
status_
!=
ReaderStatus
::
kRunning
)
{
StartImpl
();
status_
=
ReaderStatus
::
kRunning
;
}
}
ReaderBase
::~
ReaderBase
()
{
Shutdown
();
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/reader.h
浏览文件 @
7d6afee5
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include <memory>
#include <memory>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
...
@@ -24,61 +25,116 @@
...
@@ -24,61 +25,116 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
enum
ReaderStatus
{
kRunning
,
kStopped
};
class
ReaderBase
{
class
ReaderBase
{
public:
public:
virtual
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
);
void
Shutdown
();
virtual
void
ReInit
()
=
0
;
void
Start
();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std
::
unordered_set
<
ReaderBase
*>
GetEndPoints
();
virtual
~
ReaderBase
();
virtual
~
ReaderBase
();
protected:
virtual
void
ReadNextImpl
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
virtual
void
ShutdownImpl
()
{}
virtual
void
StartImpl
()
{}
ReaderStatus
status_
{
kRunning
};
mutable
std
::
mutex
mu_
;
private:
friend
class
DecoratedReader
;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void
InsertDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
decorated_reader
);
// A set of which readers that decorated this reader.
std
::
vector
<
std
::
weak_ptr
<
ReaderBase
>>
decorated_readers_
;
};
};
class
DecoratedReader
:
public
ReaderBase
{
class
DecoratedReader
:
public
ReaderBase
,
public
std
::
enable_shared_from_this
<
DecoratedReader
>
{
public:
public:
explicit
DecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
explicit
DecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
ReaderBase
(),
reader_
(
reader
)
{
:
ReaderBase
(),
reader_
(
reader
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
}
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
void
RegisterDecorateChain
()
{
reader_
->
InsertDecoratedReader
(
shared_from_this
());
}
protected:
protected:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
void
ShutdownImpl
()
override
{
reader_
->
Shutdown
();
}
};
class
FileReader
:
public
ReaderBase
{
public:
explicit
FileReader
(
const
std
::
vector
<
DDim
>&
dims
);
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
override
;
protected:
void
StartImpl
()
override
{
reader_
->
Start
();
}
virtual
void
ReadNextImpl
(
std
::
vector
<
LoDTensor
>*
out
)
=
0
;
private:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
std
::
vector
<
DDim
>
dims_
;
};
};
// FileReader is just a conceptual class.
class
FileReader
:
public
ReaderBase
{};
// The ReaderHolder is used as reader' unified wrapper,
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
// making it easier to access different type reader in Variables.
class
ReaderHolder
{
class
ReaderHolder
{
public:
public:
void
Reset
(
ReaderBase
*
reader
)
{
reader_
.
reset
(
reader
);
}
template
<
typename
T
>
void
Reset
(
const
std
::
shared_ptr
<
T
>&
reader
)
{
auto
reader_base
=
std
::
dynamic_pointer_cast
<
ReaderBase
>
(
reader
);
PADDLE_ENFORCE_NOT_NULL
(
reader_base
);
reader_
=
reader_base
;
}
std
::
shared_ptr
<
ReaderBase
>
Get
()
const
{
return
reader_
;
}
const
std
::
shared_ptr
<
ReaderBase
>&
Get
()
const
{
return
reader_
;
}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
ReadNext
(
out
);
reader_
->
ReadNext
(
out
);
}
}
void
ReInit
()
{
void
ResetAll
()
{
auto
end_readers
=
reader_
->
GetEndPoints
();
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Shutdown
();
}
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Start
();
}
}
void
Shutdown
()
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
Shutdown
();
}
void
Start
()
{
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
ReIni
t
();
reader_
->
Star
t
();
}
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
private:
private:
std
::
shared_ptr
<
ReaderBase
>
reader_
;
std
::
shared_ptr
<
ReaderBase
>
reader_
;
};
};
template
<
typename
T
,
typename
...
ARGS
>
inline
std
::
shared_ptr
<
DecoratedReader
>
MakeDecoratedReader
(
ARGS
&&
...
args
)
{
std
::
shared_ptr
<
DecoratedReader
>
reader
(
new
T
(
std
::
forward
<
ARGS
>
(
args
)...));
reader
->
RegisterDecorateChain
();
return
reader
;
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/reader_test.cc
0 → 100644
浏览文件 @
7d6afee5
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class
StubDecoratedReader
:
public
paddle
::
framework
::
DecoratedReader
{
public:
explicit
StubDecoratedReader
(
const
std
::
shared_ptr
<
ReaderBase
>
&
reader
)
:
DecoratedReader
(
reader
)
{}
void
ReadNextImpl
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
*
out
)
override
{}
};
class
StubRootReader
:
public
paddle
::
framework
::
ReaderBase
{
public:
void
ReadNextImpl
(
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
*
out
)
override
{}
};
TEST
(
READER
,
decorate_chain
)
{
auto
root
=
std
::
make_shared
<
StubRootReader
>
();
auto
end_point1
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
auto
end_point2
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
{
auto
endpoints
=
root
->
GetEndPoints
();
ASSERT_EQ
(
endpoints
.
size
(),
2U
);
ASSERT_NE
(
endpoints
.
count
(
end_point1
.
get
()),
0
);
ASSERT_NE
(
endpoints
.
count
(
end_point2
.
get
()),
0
);
}
{
auto
end_point3
=
paddle
::
framework
::
MakeDecoratedReader
<
StubDecoratedReader
>
(
root
);
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
3U
);
}
{
ASSERT_EQ
(
root
->
GetEndPoints
().
size
(),
2U
);
}
}
paddle/fluid/operators/reader/CMakeLists.txt
浏览文件 @
7d6afee5
...
@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
...
@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library
(
create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc
)
reader_library
(
create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc
)
reader_library
(
create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc
)
reader_library
(
create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc
)
reader_library
(
create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc
)
reader_library
(
create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc
)
reader_library
(
create_threaded_reader_op SRCS create_threaded_reader_op.cc
)
reader_library
(
create_custom_reader_op SRCS create_custom_reader_op.cc
)
reader_library
(
create_custom_reader_op SRCS create_custom_reader_op.cc
)
reader_library
(
create_py_reader_op SRCS create_py_reader_op.cc
)
reader_library
(
create_py_reader_op SRCS create_py_reader_op.cc
)
...
...
paddle/fluid/operators/reader/create_batch_reader_op.cc
浏览文件 @
7d6afee5
...
@@ -20,15 +20,19 @@ namespace reader {
...
@@ -20,15 +20,19 @@ namespace reader {
class
BatchReader
:
public
framework
::
DecoratedReader
{
class
BatchReader
:
public
framework
::
DecoratedReader
{
public:
public:
BatchReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
batch_size
)
BatchReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
batch_size
,
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
)
{
bool
discard_leftover
)
:
DecoratedReader
(
reader
),
batch_size_
(
batch_size
),
discard_leftover_
(
discard_leftover
)
{
buffer_
.
reserve
(
batch_size_
);
buffer_
.
reserve
(
batch_size_
);
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
private:
int
batch_size_
;
int
batch_size_
;
bool
discard_leftover_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
std
::
vector
<
std
::
vector
<
framework
::
LoDTensor
>>
buffer_
;
};
};
...
@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
...
@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
out
->
Reset
(
framework
::
MakeDecoratedReader
<
BatchReader
>
(
new
BatchReader
(
underlying_reader
.
Get
(),
Attr
<
int
>
(
"batch_size"
)));
underlying_reader
,
Attr
<
int
>
(
"batch_size"
),
Attr
<
bool
>
(
"discard_leftover"
)));
}
}
};
};
...
@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
AddAttr
<
int
>
(
"batch_size"
,
AddAttr
<
int
>
(
"batch_size"
,
"How many instances the batch reader yields each time."
)
"How many instances the batch reader yields each time."
)
.
GreaterThan
(
0
);
.
GreaterThan
(
0
);
AddAttr
<
bool
>
(
"discard_leftover"
,
"If true, the leftover instances that are not enough for a "
"new batch will be discarded."
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
CreateBatchReader Operator
CreateBatchReader Operator
...
@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
}
}
};
};
void
BatchReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
BatchReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
reserve
(
batch_size_
);
buffer_
.
reserve
(
batch_size_
);
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size_
;
++
i
)
{
...
@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
...
@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
break
;
break
;
}
}
}
}
if
(
discard_leftover_
&&
buffer_
.
size
()
<
batch_size_
)
{
buffer_
.
clear
();
}
// Concat instances
// Concat instances
out
->
clear
();
out
->
clear
();
if
(
buffer_
.
empty
())
{
if
(
buffer_
.
empty
())
{
...
...
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
7d6afee5
...
@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader {
...
@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader {
source_var_names_
(
source_var_names
),
source_var_names_
(
source_var_names
),
sink_var_names_
(
sink_var_names
)
{}
sink_var_names_
(
sink_var_names
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
private:
private:
const
framework
::
ProgramDesc
program_
;
const
framework
::
ProgramDesc
program_
;
...
@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
...
@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
}
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
out
->
Reset
(
framework
::
MakeDecoratedReader
<
CustomReader
>
(
new
CustomReader
(
underlying_reader
.
Get
()
,
*
sub_block
,
underlying_reader
,
*
sub_block
,
Attr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
),
Attr
<
std
::
vector
<
std
::
string
>>
(
"source_var_names"
),
Attr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
)));
Attr
<
std
::
vector
<
std
::
string
>>
(
"sink_var_names"
)));
}
}
};
};
...
@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
...
@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
}
}
};
};
void
CustomReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
CustomReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
out
->
clear
();
out
->
clear
();
std
::
vector
<
framework
::
LoDTensor
>
underlying_outs
;
std
::
vector
<
framework
::
LoDTensor
>
underlying_outs
;
reader_
->
ReadNext
(
&
underlying_outs
);
reader_
->
ReadNext
(
&
underlying_outs
);
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
7d6afee5
...
@@ -23,13 +23,13 @@ namespace reader {
...
@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
// time. So the kCacheSize shoul be at least 2.
static
constexpr
size_t
kCacheSize
=
5
;
static
constexpr
size_t
kCacheSize
=
3
;
// There will be two bacthes out of the channel during training:
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// subsequent operators.
// So the channel size should be kChacheSize - 2
// So the channel size should be kChacheSize - 2
static
constexpr
size_t
kChannelSize
=
3
;
// kCacheSize - 2
static
constexpr
size_t
kChannelSize
=
1
;
// kCacheSize - 2
class
DoubleBufferReader
:
public
framework
::
DecoratedReader
{
class
DoubleBufferReader
:
public
framework
::
DecoratedReader
{
public:
public:
...
@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader {
...
@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher
();
StartPrefetcher
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
~
DoubleBufferReader
()
{
EndPrefetcher
();
}
private:
private:
void
ShutdownImpl
()
override
{
EndPrefetcher
();
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
StartPrefetcher
();
}
void
StartPrefetcher
()
{
void
StartPrefetcher
()
{
channel_
=
new
reader
::
BlockingQueue
<
size_t
>
(
kChannelSize
);
channel_
=
new
reader
::
BlockingQueue
<
size_t
>
(
kChannelSize
);
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
prefetcher_
=
std
::
thread
([
this
]
{
PrefetchThreadFunc
();
});
...
@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...
@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
}
}
out
->
Reset
(
new
DoubleBufferReader
(
underlying_reader
.
Get
(),
place
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
DoubleBufferReader
>
(
underlying_reader
,
place
));
}
}
};
};
...
@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
...
@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
}
};
};
void
DoubleBufferReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
DoubleBufferReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
size_t
cached_tensor_id
;
size_t
cached_tensor_id
;
if
(
channel_
->
Receive
(
&
cached_tensor_id
))
{
if
(
channel_
->
Receive
(
&
cached_tensor_id
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
...
@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
...
@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
}
}
}
}
void
DoubleBufferReader
::
ReInit
()
{
reader_
->
ReInit
();
EndPrefetcher
();
StartPrefetcher
();
}
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
void
DoubleBufferReader
::
PrefetchThreadFunc
()
{
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
VLOG
(
5
)
<<
"A new prefetch thread starts."
;
size_t
cached_tensor_id
=
0
;
size_t
cached_tensor_id
=
0
;
...
...
paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
浏览文件 @
7d6afee5
...
@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader {
...
@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader {
MultiPassReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
pass_num
)
MultiPassReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
,
int
pass_num
)
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
:
DecoratedReader
(
reader
),
pass_num_
(
pass_num
),
pass_count_
(
0
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
reader_
->
ReadNext
(
out
);
reader_
->
ReadNext
(
out
);
if
(
out
->
empty
())
{
if
(
out
->
empty
()
&&
pass_count_
<
pass_num_
-
1
)
{
reader_
->
Shutdown
();
reader_
->
Start
();
reader_
->
ReadNext
(
out
);
++
pass_count_
;
++
pass_count_
;
if
(
pass_count_
<
pass_num_
)
{
reader_
->
ReInit
();
reader_
->
ReadNext
(
out
);
}
}
}
}
}
void
ReInit
()
override
{
private:
void
StartImpl
()
override
{
pass_count_
=
0
;
pass_count_
=
0
;
reader_
->
ReIni
t
();
reader_
->
Star
t
();
}
}
private:
int
pass_num_
;
int
pass_num_
;
mutable
int
pass_count_
;
mutable
int
pass_count_
;
};
};
...
@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
...
@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
int
pass_num
=
Attr
<
int
>
(
"pass_num"
);
int
pass_num
=
Attr
<
int
>
(
"pass_num"
);
out
->
Reset
(
new
MultiPassReader
(
underlying_reader
.
Get
(),
pass_num
));
out
->
Reset
(
framework
::
MakeDecoratedReader
<
MultiPassReader
>
(
underlying_reader
,
pass_num
));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
7d6afee5
...
@@ -19,22 +19,27 @@ namespace paddle {
...
@@ -19,22 +19,27 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
reader
{
namespace
reader
{
class
PyReader
:
public
framework
::
ReaderBase
{
class
PyReader
:
public
framework
::
FileReader
{
public:
public:
explicit
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
)
{
explicit
PyReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
)
:
framework
::
FileReader
()
{
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
PADDLE_ENFORCE
(
queue
!=
nullptr
,
"LoDTensorBlockingQueue must not be null"
);
queue_
=
queue
;
queue_
=
queue
;
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
bool
success
;
bool
success
;
*
out
=
queue_
->
Pop
(
&
success
);
*
out
=
queue_
->
Pop
(
&
success
);
if
(
!
success
)
out
->
clear
();
if
(
!
success
)
out
->
clear
();
}
}
void
ReInit
()
override
{}
private:
private:
void
ShutdownImpl
()
override
{
/* TODO */
}
void
StartImpl
()
override
{
/* TODO */
}
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue_
;
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue_
;
};
};
...
@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase {
...
@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase {
const
std
::
string
&
queue_name
=
Input
(
"blocking_queue"
);
const
std
::
string
&
queue_name
=
Input
(
"blocking_queue"
);
auto
*
queue_holder_var
=
scope
.
FindVar
(
queue_name
);
auto
*
queue_holder_var
=
scope
.
FindVar
(
queue_name
);
PADDLE_ENFORCE
(
PADDLE_ENFORCE
_NOT_NULL
(
queue_holder_var
!=
nullptr
,
queue_holder_var
,
"No LoDTensorBlockingQueueHolder variable with name %s found"
,
"No LoDTensorBlockingQueueHolder variable with name %s found"
,
queue_name
);
queue_name
);
auto
*
queue_holder
=
auto
*
queue_holder
=
queue_holder_var
->
template
GetMutable
<
LoDTensorBlockingQueueHolder
>();
queue_holder_var
->
template
GetMutable
<
LoDTensorBlockingQueueHolder
>();
out
->
Reset
(
new
PyReader
(
queue_holder
->
GetQueue
()));
out
->
Reset
(
std
::
make_shared
<
PyReader
>
(
queue_holder
->
GetQueue
()));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_random_data_generator_op.cc
浏览文件 @
7d6afee5
...
@@ -19,11 +19,11 @@ namespace operators {
...
@@ -19,11 +19,11 @@ namespace operators {
namespace
reader
{
namespace
reader
{
template
<
typename
T
>
template
<
typename
T
>
class
RandomDataGenerator
:
public
framework
::
ReaderBase
{
class
RandomDataGenerator
:
public
framework
::
FileReader
{
public:
public:
RandomDataGenerator
(
const
std
::
vector
<
framework
::
DDim
>&
shapes
,
float
low
,
RandomDataGenerator
(
const
std
::
vector
<
framework
::
DDim
>&
shapes
,
float
low
,
float
high
)
float
high
)
:
framework
::
ReaderBase
(),
low_
(
low
),
high_
(
high
),
shapes_
(
shapes
)
{
:
framework
::
FileReader
(),
low_
(
low
),
high_
(
high
),
shapes_
(
shapes
)
{
PADDLE_ENFORCE_LE
(
low
,
high
,
PADDLE_ENFORCE_LE
(
low
,
high
,
"'low' shouldn't be greater than 'high'.(%f vs %f)"
,
low
,
"'low' shouldn't be greater than 'high'.(%f vs %f)"
,
low
,
high
);
high
);
...
@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase {
...
@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase {
dist_
=
std
::
uniform_real_distribution
<
float
>
(
low_
,
high_
);
dist_
=
std
::
uniform_real_distribution
<
float
>
(
low_
,
high_
);
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
out
->
clear
();
out
->
clear
();
out
->
reserve
(
shapes_
.
size
());
out
->
reserve
(
shapes_
.
size
());
for
(
const
framework
::
DDim
&
shape
:
shapes_
)
{
for
(
const
framework
::
DDim
&
shape
:
shapes_
)
{
...
@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase {
...
@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase {
}
}
}
}
void
ReInit
()
override
{
return
;
}
private:
private:
float
low_
;
float
low_
;
float
high_
;
float
high_
;
...
@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
...
@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
std
::
vector
<
framework
::
DDim
>
shapes
=
RestoreShapes
(
shape_concat
,
ranks
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RandomDataGenerator
<
T
>
(
shapes
,
Attr
<
float
>
(
"low"
),
out
->
Reset
(
std
::
make_shared
<
RandomDataGenerator
<
T
>>
(
Attr
<
float
>
(
"high"
)));
shapes
,
Attr
<
float
>
(
"low"
),
Attr
<
float
>
(
"high"
)));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
浏览文件 @
7d6afee5
...
@@ -21,10 +21,8 @@ namespace reader {
...
@@ -21,10 +21,8 @@ namespace reader {
template
<
bool
ThreadSafe
>
template
<
bool
ThreadSafe
>
class
RecordIOFileReader
:
public
framework
::
FileReader
{
class
RecordIOFileReader
:
public
framework
::
FileReader
{
public:
public:
explicit
RecordIOFileReader
(
const
std
::
string
&
filename
,
explicit
RecordIOFileReader
(
const
std
::
string
&
filename
)
const
std
::
vector
<
framework
::
DDim
>&
dims
)
:
scanner_
(
filename
),
:
FileReader
(
dims
),
scanner_
(
filename
),
dev_ctx_
(
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
dev_ctx_
(
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()))
{
platform
::
CPUPlace
()))
{
if
(
ThreadSafe
)
{
if
(
ThreadSafe
)
{
...
@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader {
...
@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader {
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
LOG
(
INFO
)
<<
"Creating file reader"
<<
filename
;
}
}
void
ReInit
()
override
{
scanner_
.
Reset
();
}
protected:
protected:
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
if
(
ThreadSafe
)
{
if
(
ThreadSafe
)
{
...
@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader {
...
@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader {
}
}
}
}
void
StartImpl
()
override
{
scanner_
.
Reset
();
}
private:
private:
std
::
unique_ptr
<
std
::
mutex
>
mutex_
;
std
::
unique_ptr
<
std
::
mutex
>
mutex_
;
recordio
::
Scanner
scanner_
;
recordio
::
Scanner
scanner_
;
...
@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
...
@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
private:
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
const
auto
&
shape_concat
=
Attr
<
std
::
vector
<
int
>>
(
"shape_concat"
);
const
auto
&
ranks
=
Attr
<
std
::
vector
<
int
>>
(
"ranks"
);
PADDLE_ENFORCE
(
!
shape_concat
.
empty
()
&&
!
ranks
.
empty
());
PADDLE_ENFORCE_EQ
(
std
::
accumulate
(
ranks
.
begin
(),
ranks
.
end
(),
0
),
static_cast
<
int
>
(
shape_concat
.
size
()),
"The accumulate of all ranks should be equal to the "
"shape concat's length."
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
std
::
string
filename
=
Attr
<
std
::
string
>
(
"filename"
);
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
RecordIOFileReader
<
true
>
(
out
->
Reset
(
std
::
make_shared
<
RecordIOFileReader
<
true
>>
(
filename
));
filename
,
RestoreShapes
(
shape_concat
,
ranks
)));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_shuffle_reader_op.cc
浏览文件 @
7d6afee5
...
@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader {
...
@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader {
ReloadBuffer
();
ReloadBuffer
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
void
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
out
->
clear
();
out
->
clear
();
if
(
iteration_pos_
>=
buffer_
.
size
())
{
if
(
iteration_pos_
>=
buffer_
.
size
())
{
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
VLOG
(
10
)
<<
"Resetting shuffle buffer"
;
...
@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
...
@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
}
}
private:
private:
void
ShutdownImpl
()
override
{
buffer_
.
clear
();
iteration_pos_
=
0
;
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
ReloadBuffer
();
}
void
ReloadBuffer
()
{
void
ReloadBuffer
()
{
buffer_
.
clear
();
buffer_
.
clear
();
buffer_
.
reserve
(
buffer_size_
);
buffer_
.
reserve
(
buffer_size_
);
...
@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
...
@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
}
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
out
->
Reset
(
framework
::
MakeDecoratedReader
<
ShuffleReader
>
(
new
ShuffleReader
(
underlying_reader
.
Get
(),
underlying_reader
,
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
static_cast
<
size_t
>
(
Attr
<
int
>
(
"buffer_size"
))));
}
}
};
};
...
...
paddle/fluid/operators/reader/create_threaded_reader_op.cc
已删除
100644 → 0
浏览文件 @
0a445da6
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace
paddle
{
namespace
operators
{
namespace
reader
{
class
ThreadedReader
:
public
framework
::
DecoratedReader
{
public:
explicit
ThreadedReader
(
const
std
::
shared_ptr
<
ReaderBase
>&
reader
)
:
DecoratedReader
(
reader
)
{}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
reader_
->
ReadNext
(
out
);
}
void
ReInit
()
override
{
reader_
->
ReInit
();
}
private:
std
::
mutex
mutex_
;
};
class
CreateThreadedReaderOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
auto
*
out
=
detail
::
Ref
(
scope
.
FindVar
(
Output
(
"Out"
)))
.
GetMutable
<
framework
::
ReaderHolder
>
();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
out
->
Reset
(
new
ThreadedReader
(
underlying_reader
.
Get
()));
}
};
class
CreateThreadedReaderOpMaker
:
public
DecoratedReaderMakerBase
{
protected:
void
Apply
()
override
{
AddComment
(
R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC"
);
}
};
}
// namespace reader
}
// namespace operators
}
// namespace paddle
namespace
reader
=
paddle
::
operators
::
reader
;
REGISTER_DECORATED_READER_OPERATOR
(
create_threaded_reader
,
reader
::
CreateThreadedReaderOp
,
reader
::
CreateThreadedReaderOpMaker
);
paddle/fluid/operators/reader/open_files_op.cc
浏览文件 @
7d6afee5
...
@@ -23,24 +23,26 @@ namespace reader {
...
@@ -23,24 +23,26 @@ namespace reader {
class
MultiFileReader
:
public
framework
::
ReaderBase
{
class
MultiFileReader
:
public
framework
::
ReaderBase
{
public:
public:
MultiFileReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
MultiFileReader
(
const
std
::
vector
<
std
::
string
>&
file_names
,
size_t
thread_num
,
const
std
::
vector
<
framework
::
DDim
>&
dims
,
size_t
thread_num
,
size_t
buffer_size
)
size_t
buffer_size
)
:
buffer_size_
(
buffer_size
)
{
:
buffer_size_
(
buffer_size
)
{
readers_
.
reserve
(
file_names
.
size
());
readers_
.
reserve
(
file_names
.
size
());
for
(
const
std
::
string
&
f_name
:
file_names
)
{
for
(
const
std
::
string
&
f_name
:
file_names
)
{
readers_
.
emplace_back
(
CreateReaderByFileName
(
f_name
,
dims
));
readers_
.
emplace_back
(
CreateReaderByFileName
(
f_name
));
}
}
prefetchers_
.
resize
(
thread_num
);
prefetchers_
.
resize
(
thread_num
);
StartNewScheduler
();
StartNewScheduler
();
}
}
void
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReadNextImpl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
override
;
void
ReInit
()
override
;
~
MultiFileReader
()
{
EndScheduler
();
}
~
MultiFileReader
()
{
EndScheduler
();
}
private:
private:
void
ShutdownImpl
()
override
{
EndScheduler
();
}
void
StartImpl
()
override
{
StartNewScheduler
();
}
void
StartNewScheduler
();
void
StartNewScheduler
();
void
EndScheduler
();
void
EndScheduler
();
void
ScheduleThreadFunc
();
void
ScheduleThreadFunc
();
...
@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase {
...
@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase {
reader
::
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
reader
::
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>*
buffer_
;
};
};
void
MultiFileReader
::
ReadNext
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
void
MultiFileReader
::
ReadNext
Impl
(
std
::
vector
<
framework
::
LoDTensor
>*
out
)
{
if
(
!
buffer_
->
Receive
(
out
))
{
if
(
!
buffer_
->
Receive
(
out
))
{
out
->
clear
();
out
->
clear
();
}
}
}
}
void
MultiFileReader
::
ReInit
()
{
EndScheduler
();
StartNewScheduler
();
}
void
MultiFileReader
::
StartNewScheduler
()
{
void
MultiFileReader
::
StartNewScheduler
()
{
size_t
thread_num
=
prefetchers_
.
size
();
size_t
thread_num
=
prefetchers_
.
size
();
waiting_reader_idx_
=
new
reader
::
BlockingQueue
<
size_t
>
(
readers_
.
size
());
waiting_reader_idx_
=
new
reader
::
BlockingQueue
<
size_t
>
(
readers_
.
size
());
...
@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() {
...
@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() {
}
}
}
}
}
}
// If users invoke
ReInit
() when scheduler is running, it will close the
// If users invoke
Shutdown
() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
// to release their resource. So a check is needed before scheduler ends.
for
(
auto
&
p
:
prefetchers_
)
{
for
(
auto
&
p
:
prefetchers_
)
{
...
@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
...
@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std
::
vector
<
framework
::
LoDTensor
>
ins
;
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
reader
->
ReadNext
(
&
ins
);
if
(
ins
.
empty
())
{
if
(
ins
.
empty
())
{
reader
->
ReInit
();
reader
->
Shutdown
();
reader
->
Start
();
break
;
break
;
}
}
try
{
try
{
...
@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
...
@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
out
->
Reset
(
new
MultiFileReader
(
file_names
,
out
->
Reset
(
RestoreShapes
(
shape_concat
,
ranks
),
std
::
make_shared
<
MultiFileReader
>
(
file_names
,
thread_num
,
buffer_size
));
thread_num
,
buffer_size
));
}
}
};
};
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
7d6afee5
...
@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
...
@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
}
}
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
{
const
std
::
string
&
file_name
)
{
size_t
separator_pos
=
file_name
.
find_last_of
(
kFileFormatSeparator
);
size_t
separator_pos
=
file_name
.
find_last_of
(
kFileFormatSeparator
);
PADDLE_ENFORCE_NE
(
separator_pos
,
std
::
string
::
npos
,
PADDLE_ENFORCE_NE
(
separator_pos
,
std
::
string
::
npos
,
"File name illegal! A legal file name should be like: "
"File name illegal! A legal file name should be like: "
...
@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
...
@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
auto
itor
=
FileReaderRegistry
().
find
(
filetype
);
auto
itor
=
FileReaderRegistry
().
find
(
filetype
);
PADDLE_ENFORCE
(
itor
!=
FileReaderRegistry
().
end
(),
PADDLE_ENFORCE
(
itor
!=
FileReaderRegistry
().
end
(),
"No file reader registered for '%s' format."
,
filetype
);
"No file reader registered for '%s' format."
,
filetype
);
framework
::
ReaderBase
*
reader
=
(
itor
->
second
)(
file_name
,
dims
);
framework
::
ReaderBase
*
reader
=
(
itor
->
second
)(
file_name
);
return
std
::
unique_ptr
<
framework
::
ReaderBase
>
(
reader
);
return
std
::
unique_ptr
<
framework
::
ReaderBase
>
(
reader
);
}
}
...
...
paddle/fluid/operators/reader/reader_op_registry.h
浏览文件 @
7d6afee5
...
@@ -25,22 +25,21 @@ namespace reader {
...
@@ -25,22 +25,21 @@ namespace reader {
static
constexpr
char
kFileFormatSeparator
[]
=
"."
;
static
constexpr
char
kFileFormatSeparator
[]
=
"."
;
using
FileReaderCreator
=
std
::
function
<
framework
::
ReaderBase
*
(
using
FileReaderCreator
=
const
std
::
string
&
,
const
std
::
vector
<
framework
::
DDim
>
&
)
>
;
std
::
function
<
framework
::
ReaderBase
*
(
const
std
::
string
&
)
>
;
std
::
unordered_map
<
std
::
string
,
FileReaderCreator
>&
FileReaderRegistry
();
std
::
unordered_map
<
std
::
string
,
FileReaderCreator
>&
FileReaderRegistry
();
template
<
typename
Reader
>
template
<
typename
Reader
>
int
RegisterFileReader
(
const
std
::
string
&
filetype
)
{
int
RegisterFileReader
(
const
std
::
string
&
filetype
)
{
FileReaderRegistry
()[
filetype
]
=
[](
FileReaderRegistry
()[
filetype
]
=
[](
const
std
::
string
&
fn
)
{
const
std
::
string
&
fn
,
const
std
::
vector
<
framework
::
DDim
>&
dims
)
{
return
new
Reader
(
fn
);
return
new
Reader
(
fn
,
dims
);
};
};
return
0
;
return
0
;
}
}
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
std
::
unique_ptr
<
framework
::
ReaderBase
>
CreateReaderByFileName
(
const
std
::
string
&
file_name
,
const
std
::
vector
<
framework
::
DDim
>&
dims
);
const
std
::
string
&
file_name
);
extern
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
extern
std
::
vector
<
framework
::
DDim
>
RestoreShapes
(
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
);
const
std
::
vector
<
int
>&
shape_concat
,
const
std
::
vector
<
int
>&
ranks
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
7d6afee5
...
@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py
::
return_value_policy
::
reference
);
py
::
return_value_policy
::
reference
);
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
py
::
class_
<
framework
::
ReaderHolder
>
(
m
,
"Reader"
,
""
)
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
Re
Init
);
.
def
(
"reset"
,
&
framework
::
ReaderHolder
::
Re
setAll
);
using
LoDTensorBlockingQueue
=
using
LoDTensorBlockingQueue
=
::
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
::
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
7d6afee5
...
@@ -375,9 +375,6 @@ def open_recordio_file(filename,
...
@@ -375,9 +375,6 @@ def open_recordio_file(filename,
if
pass_num
>
1
:
if
pass_num
>
1
:
main_prog_var
=
multi_pass
(
reader
=
main_prog_var
,
pass_num
=
pass_num
)
main_prog_var
=
multi_pass
(
reader
=
main_prog_var
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_var
=
parallel
(
reader
=
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
return
monkey_patch_reader_methods
(
main_prog_var
)
...
@@ -529,9 +526,6 @@ def open_files(filenames,
...
@@ -529,9 +526,6 @@ def open_files(filenames,
main_prog_reader
=
multi_pass
(
main_prog_reader
=
multi_pass
(
reader
=
main_prog_reader
,
pass_num
=
pass_num
)
reader
=
main_prog_reader
,
pass_num
=
pass_num
)
if
for_parallel
:
main_prog_reader
=
parallel
(
reader
=
main_prog_reader
)
return
monkey_patch_reader_methods
(
main_prog_reader
)
return
monkey_patch_reader_methods
(
main_prog_reader
)
...
@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num):
...
@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num):
'create_multi_pass_reader'
,
reader
,
{
'pass_num'
:
int
(
pass_num
)})
'create_multi_pass_reader'
,
reader
,
{
'pass_num'
:
int
(
pass_num
)})
def
parallel
(
reader
):
return
__create_shared_decorated_reader__
(
'create_threaded_reader'
,
reader
,
{})
def
read_file
(
reader
):
def
read_file
(
reader
):
"""
"""
Execute the given reader and get data via it.
Execute the given reader and get data via it.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录