Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6dadb5de
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6dadb5de
编写于
2月 10, 2020
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix iterable=False reset bug, add some logs and polish code, test=develop
上级
60d18a8f
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
123 addition
and
77 deletion
+123
-77
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+3
-0
paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_count_pass.cc
.../multi_devices_graph_pass/set_reader_device_count_pass.cc
+20
-0
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+6
-0
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+13
-3
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+6
-13
paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
+67
-56
paddle/fluid/pybind/reader_py.cc
paddle/fluid/pybind/reader_py.cc
+1
-1
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+7
-4
未找到文件。
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
6dadb5de
...
@@ -402,6 +402,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
...
@@ -402,6 +402,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
}
else
if
(
pass
->
Type
()
==
"set_reader_device_count_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"set_reader_device_count_pass"
)
{
pass
->
Erase
(
kPlaces
);
pass
->
Erase
(
kPlaces
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
Erase
(
kLocalScopes
);
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
,
&
local_scopes
);
}
}
VLOG
(
1
)
<<
"Start Apply Pass "
<<
pass
->
Type
();
VLOG
(
1
)
<<
"Start Apply Pass "
<<
pass
->
Type
();
graph
=
pass
->
Apply
(
graph
);
graph
=
pass
->
Apply
(
graph
);
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_count_pass.cc
浏览文件 @
6dadb5de
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -29,6 +30,8 @@ class SetReaderDeviceCountPass : public Pass {
...
@@ -29,6 +30,8 @@ class SetReaderDeviceCountPass : public Pass {
int
GetDeviceCount
()
const
;
int
GetDeviceCount
()
const
;
std
::
unordered_set
<
std
::
string
>
ReaderOpSet
()
const
;
std
::
unordered_set
<
std
::
string
>
ReaderOpSet
()
const
;
const
Scope
*
GlobalScope
()
const
;
};
};
int
SetReaderDeviceCountPass
::
GetDeviceCount
()
const
{
int
SetReaderDeviceCountPass
::
GetDeviceCount
()
const
{
...
@@ -40,9 +43,14 @@ std::unordered_set<std::string> SetReaderDeviceCountPass::ReaderOpSet() const {
...
@@ -40,9 +43,14 @@ std::unordered_set<std::string> SetReaderDeviceCountPass::ReaderOpSet() const {
return
{
"create_py_reader"
};
return
{
"create_py_reader"
};
}
}
const
Scope
*
SetReaderDeviceCountPass
::
GlobalScope
()
const
{
return
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
)[
0
];
}
void
SetReaderDeviceCountPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
void
SetReaderDeviceCountPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
auto
dev_cnt
=
GetDeviceCount
();
auto
dev_cnt
=
GetDeviceCount
();
auto
reader_ops
=
ReaderOpSet
();
auto
reader_ops
=
ReaderOpSet
();
auto
scope
=
GlobalScope
();
size_t
found_op_num
=
0
;
size_t
found_op_num
=
0
;
for
(
auto
&
node
:
graph
->
Nodes
())
{
for
(
auto
&
node
:
graph
->
Nodes
())
{
...
@@ -61,6 +69,18 @@ void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
...
@@ -61,6 +69,18 @@ void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
op_base_attrs
[
"device_index"
]
=
dev_idx
;
op_base_attrs
[
"device_index"
]
=
dev_idx
;
op_base_attrs
[
"device_count"
]
=
dev_cnt
;
op_base_attrs
[
"device_count"
]
=
dev_cnt
;
auto
queue_name
=
op_handle
.
GetOp
()
->
Input
(
"blocking_queue"
);
auto
var
=
scope
->
FindVar
(
queue_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"Blocking queue of DataLoader not found"
));
using
QueueHolder
=
operators
::
reader
::
OrderedMultiDeviceLoDTensorBlockingQueueHolder
;
if
(
var
->
IsType
<
QueueHolder
>
())
{
var
->
GetMutable
<
QueueHolder
>
()
->
GetQueue
()
->
SetDeviceCount
(
dev_cnt
);
}
++
found_op_num
;
++
found_op_num
;
VLOG
(
10
)
<<
"Found op "
<<
op_desc
->
Type
()
<<
" on device "
<<
dev_idx
;
VLOG
(
10
)
<<
"Found op "
<<
op_desc
->
Type
()
<<
" on device "
<<
dev_idx
;
}
}
...
...
paddle/fluid/framework/reader.h
浏览文件 @
6dadb5de
...
@@ -117,6 +117,10 @@ class DecoratedReader : public ReaderBase,
...
@@ -117,6 +117,10 @@ class DecoratedReader : public ReaderBase,
~
DecoratedReader
();
~
DecoratedReader
();
const
std
::
shared_ptr
<
ReaderBase
>&
UnderlyingReader
()
const
{
return
reader_
;
}
protected:
protected:
void
ShutdownImpl
()
override
{
void
ShutdownImpl
()
override
{
VLOG
(
1
)
<<
"ShutdownImpl"
;
VLOG
(
1
)
<<
"ShutdownImpl"
;
...
@@ -190,6 +194,8 @@ class ReaderHolder {
...
@@ -190,6 +194,8 @@ class ReaderHolder {
return
reader_
->
NeedCheckFeed
();
return
reader_
->
NeedCheckFeed
();
}
}
void
Clear
()
{
reader_
.
reset
();
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
private:
private:
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
6dadb5de
...
@@ -27,12 +27,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...
@@ -27,12 +27,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
if
(
out
->
Get
()
!=
nullptr
)
{
auto
*
decorated_reader
=
dynamic_cast
<
framework
::
DecoratedReader
*>
(
out
->
Get
().
get
());
PADDLE_ENFORCE_NOT_NULL
(
decorated_reader
,
platform
::
errors
::
NotFound
(
"Not inited with DecoratedReader"
));
if
(
decorated_reader
->
UnderlyingReader
()
==
underlying_reader
.
Get
())
{
return
;
}
}
auto
place_str
=
Attr
<
std
::
string
>
(
"place"
);
auto
place_str
=
Attr
<
std
::
string
>
(
"place"
);
platform
::
Place
place
;
platform
::
Place
place
;
if
(
place_str
==
"AUTO"
)
{
if
(
place_str
==
"AUTO"
)
{
...
@@ -47,6 +55,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...
@@ -47,6 +55,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
}
}
VLOG
(
10
)
<<
"Create new double buffer reader on "
<<
place
;
out
->
Reset
(
framework
::
MakeDecoratedReader
<
BufferedReader
>
(
underlying_reader
,
out
->
Reset
(
framework
::
MakeDecoratedReader
<
BufferedReader
>
(
underlying_reader
,
place
,
2
));
place
,
2
));
}
}
...
...
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
6dadb5de
...
@@ -40,6 +40,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
...
@@ -40,6 +40,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
queue_name
);
queue_name
);
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
;
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
;
std
::
shared_ptr
<
OrderedMultiDeviceLoDTensorBlockingQueue
>
ordered_queue
;
std
::
shared_ptr
<
OrderedMultiDeviceLoDTensorBlockingQueue
>
ordered_queue
;
int
dev_idx
=
-
1
;
if
(
queue_holder_var
->
IsType
<
LoDTensorBlockingQueueHolder
>
())
{
if
(
queue_holder_var
->
IsType
<
LoDTensorBlockingQueueHolder
>
())
{
queue
=
queue_holder_var
->
Get
<
LoDTensorBlockingQueueHolder
>
().
GetQueue
();
queue
=
queue_holder_var
->
Get
<
LoDTensorBlockingQueueHolder
>
().
GetQueue
();
}
else
if
(
queue_holder_var
}
else
if
(
queue_holder_var
...
@@ -47,10 +48,9 @@ class CreatePyReaderOp : public framework::OperatorBase {
...
@@ -47,10 +48,9 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto
*
queue_holder
=
auto
*
queue_holder
=
queue_holder_var
queue_holder_var
->
GetMutable
<
OrderedMultiDeviceLoDTensorBlockingQueueHolder
>
();
->
GetMutable
<
OrderedMultiDeviceLoDTensorBlockingQueueHolder
>
();
auto
dev_cnt
=
Attr
<
int
>
(
"device_count"
);
dev_idx
=
Attr
<
int
>
(
"device_index"
);
auto
dev_idx
=
static_cast
<
size_t
>
(
Attr
<
int
>
(
"device_index"
));
ordered_queue
=
queue_holder
->
GetQueue
();
ordered_queue
=
queue_holder
->
GetQueue
();
ordered_queue
->
InitOnce
(
dev_cnt
);
ordered_queue
->
SetDeviceCount
(
Attr
<
int
>
(
"device_count"
)
);
queue
=
ordered_queue
->
GetQueue
(
dev_idx
);
queue
=
ordered_queue
->
GetQueue
(
dev_idx
);
}
}
...
@@ -87,15 +87,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
...
@@ -87,15 +87,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto
py_reader
=
auto
py_reader
=
std
::
make_shared
<
PyReader
>
(
queue
,
dims
,
var_types
,
need_check_feed
);
std
::
make_shared
<
PyReader
>
(
queue
,
dims
,
var_types
,
need_check_feed
);
if
(
ordered_queue
)
{
if
(
ordered_queue
)
{
ordered_queue
->
AddResetMethod
([
py_reader
]
{
ordered_queue
->
SetResetMethod
(
dev_idx
,
[
out
]
{
out
->
Clear
();
});
auto
end_readers
=
py_reader
->
GetEndPoints
();
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Shutdown
();
}
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Start
();
}
});
}
}
out
->
Reset
(
py_reader
);
out
->
Reset
(
py_reader
);
}
}
...
@@ -109,8 +101,9 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase {
...
@@ -109,8 +101,9 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase {
AddAttr
<
int
>
(
"device_index"
,
"The device index this reader offers data"
)
AddAttr
<
int
>
(
"device_index"
,
"The device index this reader offers data"
)
.
SetDefault
(
0
);
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"device_count"
,
AddAttr
<
int
>
(
"device_count"
,
"The total
number of devices the
reader offers data"
)
"The total
device number this
reader offers data"
)
.
SetDefault
(
1
);
.
SetDefault
(
1
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
...
...
paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
浏览文件 @
6dadb5de
...
@@ -32,6 +32,8 @@ class LoDTensorBlockingQueue {
...
@@ -32,6 +32,8 @@ class LoDTensorBlockingQueue {
explicit
LoDTensorBlockingQueue
(
size_t
capacity
,
bool
speed_test_mode
=
false
)
explicit
LoDTensorBlockingQueue
(
size_t
capacity
,
bool
speed_test_mode
=
false
)
:
queue_
(
capacity
,
speed_test_mode
)
{}
:
queue_
(
capacity
,
speed_test_mode
)
{}
~
LoDTensorBlockingQueue
()
{
VLOG
(
10
)
<<
"Destruct LoDTensorBlockingQueue"
;
}
bool
Push
(
const
std
::
vector
<
framework
::
LoDTensor
>&
lod_tensor_vec
)
{
bool
Push
(
const
std
::
vector
<
framework
::
LoDTensor
>&
lod_tensor_vec
)
{
return
queue_
.
Send
(
lod_tensor_vec
);
return
queue_
.
Send
(
lod_tensor_vec
);
}
}
...
@@ -62,7 +64,7 @@ class LoDTensorBlockingQueue {
...
@@ -62,7 +64,7 @@ class LoDTensorBlockingQueue {
inline
void
Kill
()
{
queue_
.
Kill
();
}
inline
void
Kill
()
{
queue_
.
Kill
();
}
inline
bool
WaitForInited
()
{
return
true
;
}
inline
bool
WaitForInited
(
size_t
)
{
return
true
;
}
private:
private:
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>
queue_
;
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>
queue_
;
...
@@ -74,47 +76,47 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
...
@@ -74,47 +76,47 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
bool
speed_test_mode
=
false
)
bool
speed_test_mode
=
false
)
:
capacity_
(
capacity
),
speed_test_mode_
(
speed_test_mode
)
{}
:
capacity_
(
capacity
),
speed_test_mode_
(
speed_test_mode
)
{}
inline
bool
WaitForInited
()
{
~
OrderedMultiDeviceLoDTensorBlockingQueue
()
{
VLOG
(
10
)
<<
"Destruct OrderedMultiDeviceLoDTensorBlockingQueue"
;
}
bool
WaitForInited
(
size_t
milliseconds
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
init_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
init_mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
queues_
!=
nullptr
||
is_closing_
;
});
return
cv_
.
wait_for
(
lock
,
std
::
chrono
::
milliseconds
(
milliseconds
),
is_closing_
=
false
;
[
this
]
{
return
!
queues_
.
empty
();
});
return
queues_
!=
nullptr
;
}
}
inline
void
InitOnce
(
size_t
dev_cnt
)
{
void
SetDeviceCount
(
size_t
dev_cnt
)
{
PADDLE_ENFORCE_GE
(
dev_cnt
,
1
,
platform
::
errors
::
InvalidArgument
(
"Device count to init "
"OrderedMultiDeviceLoDTensorBlockingQueue"
" must be larger than 1"
));
VLOG
(
3
)
<<
"Ordered queue init start"
;
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
init_mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
init_mutex_
);
if
(
queues_
)
{
PADDLE_ENFORCE_GE
(
dev_cnt
,
1
,
PADDLE_ENFORCE_EQ
(
queues_
->
size
(),
dev_cnt
,
platform
::
errors
::
InvalidArgument
(
"Device count to init "
"OrderedMultiDeviceLoDTensorBlockingQueue"
" must be larger than 1"
));
if
(
!
queues_
.
empty
())
{
PADDLE_ENFORCE_EQ
(
queues_
.
size
(),
dev_cnt
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Device count to init queue must be equal"
));
"queues should be only inited once"
));
}
else
{
return
;
queues_
.
reset
(
}
new
std
::
vector
<
std
::
shared_ptr
<
LoDTensorBlockingQueue
>>
(
dev_cnt
));
for
(
auto
&
item
:
*
queues_
)
{
VLOG
(
1
)
<<
"Init queue with size "
<<
dev_cnt
;
auto
cap
=
(
capacity_
+
dev_cnt
-
1
)
/
dev_cnt
;
queues_
.
resize
(
dev_cnt
);
item
.
reset
(
new
LoDTensorBlockingQueue
(
cap
,
speed_test_mode_
));
for
(
auto
&
item
:
queues_
)
{
}
auto
cap
=
(
capacity_
+
dev_cnt
-
1
)
/
dev_cnt
;
item
.
reset
(
new
LoDTensorBlockingQueue
(
cap
,
speed_test_mode_
));
}
}
}
}
VLOG
(
3
)
<<
"Ordered queue init finish"
;
cv_
.
notify_all
();
cv_
.
notify_all
();
}
}
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
GetQueue
(
size_t
idx
)
const
{
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
GetQueue
(
size_t
idx
)
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
init_mutex_
);
EnforceIsInited
();
PADDLE_ENFORCE_NOT_NULL
(
queues_
,
platform
::
errors
::
NotFound
(
"Queues must be inited first before getting"
));
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
idx
,
queues_
->
size
(),
idx
,
queues_
.
size
(),
platform
::
errors
::
OutOfRange
(
"The queue index is out of range"
));
platform
::
errors
::
OutOfRange
(
"The queue index is out of range"
));
return
(
*
queues_
)
[
idx
];
return
queues_
[
idx
];
}
}
bool
Push
(
const
std
::
vector
<
framework
::
LoDTensor
>&
lod_tensor_vec
)
{
bool
Push
(
const
std
::
vector
<
framework
::
LoDTensor
>&
lod_tensor_vec
)
{
...
@@ -123,65 +125,74 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
...
@@ -123,65 +125,74 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
inline
size_t
Size
()
const
{
inline
size_t
Size
()
const
{
size_t
size
=
0
;
size_t
size
=
0
;
if
(
queues_
)
{
for
(
auto
&
item
:
queues_
)
{
for
(
auto
&
item
:
*
queues_
)
{
size
+=
item
->
Size
();
size
+=
item
->
Size
();
}
}
}
return
size
;
return
size
;
}
}
inline
void
Close
()
{
inline
void
Close
()
{
{
for
(
auto
&
item
:
queues_
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
init_mutex_
);
item
->
Close
();
if
(
queues_
==
nullptr
)
{
is_closing_
=
true
;
}
}
cv_
.
notify_all
();
if
(
queues_
)
{
for
(
auto
&
item
:
*
queues_
)
{
item
->
Close
();
}
}
}
data_index_
=
0
;
}
}
inline
void
Kill
()
{
inline
void
Kill
()
{
if
(
queues_
)
{
for
(
auto
&
item
:
queues_
)
{
for
(
auto
&
item
:
*
queues_
)
{
item
->
Kill
();
item
->
Kill
();
}
}
}
}
}
inline
void
Reset
()
{
inline
void
Reset
()
{
std
::
lock_guard
<
std
::
mutex
>
reset_lock
(
reset_mutex_
);
{
for
(
auto
&
method
:
reset_methods_
)
{
std
::
lock_guard
<
std
::
mutex
>
reset_lock
(
reset_mutex_
);
method
();
for
(
auto
&
method
:
reset_methods_
)
{
if
(
method
)
method
();
}
}
auto
dev_cnt
=
queues_
.
size
();
for
(
auto
&
item
:
queues_
)
{
auto
cap
=
(
capacity_
+
dev_cnt
-
1
)
/
dev_cnt
;
item
.
reset
(
new
LoDTensorBlockingQueue
(
cap
,
speed_test_mode_
));
}
}
data_index_
=
0
;
}
}
inline
void
AddResetMethod
(
const
std
::
function
<
void
()
>&
reset_method
)
{
inline
void
SetResetMethod
(
size_t
idx
,
const
std
::
function
<
void
()
>&
reset_method
)
{
std
::
lock_guard
<
std
::
mutex
>
reset_lock
(
reset_mutex_
);
std
::
lock_guard
<
std
::
mutex
>
reset_lock
(
reset_mutex_
);
reset_methods_
.
emplace_back
(
reset_method
);
EnforceIsInited
();
if
(
reset_methods_
.
size
()
<=
idx
)
{
reset_methods_
.
resize
(
idx
+
1
);
}
reset_methods_
[
idx
]
=
reset_method
;
}
}
private:
private:
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
CurQueue
()
{
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
CurQueue
()
{
return
(
*
queues_
)[
data_index_
.
fetch_add
(
1
)
%
queues_
->
size
()];
EnforceIsInited
();
return
queues_
[
data_index_
.
fetch_add
(
1
)
%
queues_
.
size
()];
}
private:
void
EnforceIsInited
()
const
{
PADDLE_ENFORCE_EQ
(
queues_
.
empty
(),
false
,
platform
::
errors
::
NotFound
(
"queue has not been inited"
));
}
}
private:
private:
std
::
unique_ptr
<
std
::
vector
<
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
>>
queues_
;
std
::
vector
<
std
::
shared_ptr
<
LoDTensorBlockingQueue
>>
queues_
;
mutable
std
::
atomic
<
uint64_t
>
data_index_
{
0
};
mutable
std
::
atomic
<
uint64_t
>
data_index_
{
0
};
size_t
dev_cnt_
{
0
};
const
size_t
capacity_
;
const
size_t
capacity_
;
const
bool
speed_test_mode_
;
const
bool
speed_test_mode_
;
bool
is_closed_
{
false
};
std
::
vector
<
std
::
function
<
void
()
>>
reset_methods_
;
std
::
vector
<
std
::
function
<
void
()
>>
reset_methods_
;
mutable
std
::
mutex
reset_mutex_
;
mutable
std
::
mutex
reset_mutex_
;
bool
is_closing_
{
false
};
mutable
std
::
mutex
init_mutex_
;
mutable
std
::
mutex
init_mutex_
;
mutable
std
::
condition_variable
cv_
;
mutable
std
::
condition_variable
cv_
;
};
};
...
...
paddle/fluid/pybind/reader_py.cc
浏览文件 @
6dadb5de
...
@@ -354,7 +354,7 @@ void BindReader(py::module *module) {
...
@@ -354,7 +354,7 @@ void BindReader(py::module *module) {
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
bool
use_double_buffer
)
{
bool
use_double_buffer
)
{
queue
->
InitOnce
(
dst_places
.
size
());
queue
->
SetDeviceCount
(
dst_places
.
size
());
return
new
MultiDeviceFeedReader
<
return
new
MultiDeviceFeedReader
<
reader
::
OrderedMultiDeviceLoDTensorBlockingQueue
>
(
reader
::
OrderedMultiDeviceLoDTensorBlockingQueue
>
(
queue
,
names
,
shapes
,
dtypes
,
need_check_feed
,
dst_places
,
queue
,
names
,
shapes
,
dtypes
,
need_check_feed
,
dst_places
,
...
...
python/paddle/fluid/reader.py
浏览文件 @
6dadb5de
...
@@ -347,7 +347,6 @@ class DygraphGeneratorLoader(DataLoaderBase):
...
@@ -347,7 +347,6 @@ class DygraphGeneratorLoader(DataLoaderBase):
self
.
_batch_reader
=
None
self
.
_batch_reader
=
None
self
.
_places
=
None
self
.
_places
=
None
self
.
_feed_list
=
feed_list
self
.
_feed_list
=
feed_list
self
.
_keep_order
=
True
if
not
capacity
:
if
not
capacity
:
raise
ValueError
(
"Please give value to capacity."
)
raise
ValueError
(
"Please give value to capacity."
)
...
@@ -420,7 +419,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
...
@@ -420,7 +419,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
self
.
_dtypes
=
[]
self
.
_dtypes
=
[]
self
.
_need_check_feed
=
[]
self
.
_need_check_feed
=
[]
self
.
_blocking_queue
=
core
.
init_lod_tensor_blocking_queue
(
self
.
_blocking_queue
=
core
.
init_lod_tensor_blocking_queue
(
core
.
Variable
(),
self
.
_capacity
,
self
.
_keep_order
)
core
.
Variable
(),
self
.
_capacity
,
False
)
self
.
_reader
=
core
.
create_py_reader
(
self
.
_reader
=
core
.
create_py_reader
(
self
.
queue
,
self
.
_var_names
,
self
.
_shapes
,
self
.
_dtypes
,
self
.
queue
,
self
.
_var_names
,
self
.
_shapes
,
self
.
_dtypes
,
self
.
_need_check_feed
,
self
.
_places
,
self
.
_use_double_buffer
)
self
.
_need_check_feed
,
self
.
_places
,
self
.
_use_double_buffer
)
...
@@ -635,6 +634,7 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -635,6 +634,7 @@ class GeneratorLoader(DataLoaderBase):
self
.
_thread
=
None
self
.
_thread
=
None
self
.
_queue
=
None
self
.
_queue
=
None
self
.
_feed_list
=
feed_list
self
.
_feed_list
=
feed_list
self
.
_exited
=
False
if
not
capacity
:
if
not
capacity
:
raise
ValueError
(
"Please give value to capacity."
)
raise
ValueError
(
"Please give value to capacity."
)
self
.
_iterable
=
iterable
self
.
_iterable
=
iterable
...
@@ -798,8 +798,9 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -798,8 +798,9 @@ class GeneratorLoader(DataLoaderBase):
def
_start
(
self
):
def
_start
(
self
):
def
__thread_main__
():
def
__thread_main__
():
try
:
try
:
if
not
self
.
_queue
.
wait_for_inited
():
while
not
self
.
_queue
.
wait_for_inited
(
1
):
return
if
self
.
_exited
:
return
for
tensors
in
self
.
_tensor_reader
():
for
tensors
in
self
.
_tensor_reader
():
array
=
core
.
LoDTensorArray
()
array
=
core
.
LoDTensorArray
()
...
@@ -829,10 +830,12 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -829,10 +830,12 @@ class GeneratorLoader(DataLoaderBase):
def
_reset
(
self
):
def
_reset
(
self
):
self
.
_queue
.
close
()
self
.
_queue
.
close
()
self
.
_exited
=
True
thread
=
self
.
_thread
thread
=
self
.
_thread
if
thread
is
not
None
:
if
thread
is
not
None
:
thread
.
join
()
thread
.
join
()
self
.
_exited
=
False
self
.
_reader
.
reset
()
self
.
_reader
.
reset
()
def
set_sample_generator
(
self
,
def
set_sample_generator
(
self
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录