Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6dadb5de
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6dadb5de
编写于
2月 10, 2020
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix iterable=False reset bug, add some logs and polish code, test=develop
上级
60d18a8f
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
123 addition
and
77 deletion
+123
-77
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+3
-0
paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_count_pass.cc
.../multi_devices_graph_pass/set_reader_device_count_pass.cc
+20
-0
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+6
-0
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+13
-3
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+6
-13
paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
+67
-56
paddle/fluid/pybind/reader_py.cc
paddle/fluid/pybind/reader_py.cc
+1
-1
python/paddle/fluid/reader.py
python/paddle/fluid/reader.py
+7
-4
未找到文件。
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
6dadb5de
...
@@ -402,6 +402,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
...
@@ -402,6 +402,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
}
else
if
(
pass
->
Type
()
==
"set_reader_device_count_pass"
)
{
}
else
if
(
pass
->
Type
()
==
"set_reader_device_count_pass"
)
{
pass
->
Erase
(
kPlaces
);
pass
->
Erase
(
kPlaces
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
SetNotOwned
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
,
&
places
);
pass
->
Erase
(
kLocalScopes
);
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
,
&
local_scopes
);
}
}
VLOG
(
1
)
<<
"Start Apply Pass "
<<
pass
->
Type
();
VLOG
(
1
)
<<
"Start Apply Pass "
<<
pass
->
Type
();
graph
=
pass
->
Apply
(
graph
);
graph
=
pass
->
Apply
(
graph
);
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_count_pass.cc
浏览文件 @
6dadb5de
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -29,6 +30,8 @@ class SetReaderDeviceCountPass : public Pass {
...
@@ -29,6 +30,8 @@ class SetReaderDeviceCountPass : public Pass {
int
GetDeviceCount
()
const
;
int
GetDeviceCount
()
const
;
std
::
unordered_set
<
std
::
string
>
ReaderOpSet
()
const
;
std
::
unordered_set
<
std
::
string
>
ReaderOpSet
()
const
;
const
Scope
*
GlobalScope
()
const
;
};
};
int
SetReaderDeviceCountPass
::
GetDeviceCount
()
const
{
int
SetReaderDeviceCountPass
::
GetDeviceCount
()
const
{
...
@@ -40,9 +43,14 @@ std::unordered_set<std::string> SetReaderDeviceCountPass::ReaderOpSet() const {
...
@@ -40,9 +43,14 @@ std::unordered_set<std::string> SetReaderDeviceCountPass::ReaderOpSet() const {
return
{
"create_py_reader"
};
return
{
"create_py_reader"
};
}
}
const
Scope
*
SetReaderDeviceCountPass
::
GlobalScope
()
const
{
return
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
)[
0
];
}
void
SetReaderDeviceCountPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
void
SetReaderDeviceCountPass
::
ApplyImpl
(
Graph
*
graph
)
const
{
auto
dev_cnt
=
GetDeviceCount
();
auto
dev_cnt
=
GetDeviceCount
();
auto
reader_ops
=
ReaderOpSet
();
auto
reader_ops
=
ReaderOpSet
();
auto
scope
=
GlobalScope
();
size_t
found_op_num
=
0
;
size_t
found_op_num
=
0
;
for
(
auto
&
node
:
graph
->
Nodes
())
{
for
(
auto
&
node
:
graph
->
Nodes
())
{
...
@@ -61,6 +69,18 @@ void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
...
@@ -61,6 +69,18 @@ void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
op_base_attrs
[
"device_index"
]
=
dev_idx
;
op_base_attrs
[
"device_index"
]
=
dev_idx
;
op_base_attrs
[
"device_count"
]
=
dev_cnt
;
op_base_attrs
[
"device_count"
]
=
dev_cnt
;
auto
queue_name
=
op_handle
.
GetOp
()
->
Input
(
"blocking_queue"
);
auto
var
=
scope
->
FindVar
(
queue_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"Blocking queue of DataLoader not found"
));
using
QueueHolder
=
operators
::
reader
::
OrderedMultiDeviceLoDTensorBlockingQueueHolder
;
if
(
var
->
IsType
<
QueueHolder
>
())
{
var
->
GetMutable
<
QueueHolder
>
()
->
GetQueue
()
->
SetDeviceCount
(
dev_cnt
);
}
++
found_op_num
;
++
found_op_num
;
VLOG
(
10
)
<<
"Found op "
<<
op_desc
->
Type
()
<<
" on device "
<<
dev_idx
;
VLOG
(
10
)
<<
"Found op "
<<
op_desc
->
Type
()
<<
" on device "
<<
dev_idx
;
}
}
...
...
paddle/fluid/framework/reader.h
浏览文件 @
6dadb5de
...
@@ -117,6 +117,10 @@ class DecoratedReader : public ReaderBase,
...
@@ -117,6 +117,10 @@ class DecoratedReader : public ReaderBase,
~
DecoratedReader
();
~
DecoratedReader
();
const
std
::
shared_ptr
<
ReaderBase
>&
UnderlyingReader
()
const
{
return
reader_
;
}
protected:
protected:
void
ShutdownImpl
()
override
{
void
ShutdownImpl
()
override
{
VLOG
(
1
)
<<
"ShutdownImpl"
;
VLOG
(
1
)
<<
"ShutdownImpl"
;
...
@@ -190,6 +194,8 @@ class ReaderHolder {
...
@@ -190,6 +194,8 @@ class ReaderHolder {
return
reader_
->
NeedCheckFeed
();
return
reader_
->
NeedCheckFeed
();
}
}
void
Clear
()
{
reader_
.
reset
();
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
operator
const
std
::
shared_ptr
<
ReaderBase
>&
()
const
{
return
this
->
reader_
;
}
private:
private:
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
6dadb5de
...
@@ -27,12 +27,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...
@@ -27,12 +27,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
auto
*
out
=
scope
.
FindVar
(
Output
(
"Out"
))
->
template
GetMutable
<
framework
::
ReaderHolder
>();
->
template
GetMutable
<
framework
::
ReaderHolder
>();
if
(
out
->
Get
()
!=
nullptr
)
{
return
;
}
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
const
auto
&
underlying_reader
=
scope
.
FindVar
(
Input
(
"UnderlyingReader"
))
->
Get
<
framework
::
ReaderHolder
>
();
->
Get
<
framework
::
ReaderHolder
>
();
if
(
out
->
Get
()
!=
nullptr
)
{
auto
*
decorated_reader
=
dynamic_cast
<
framework
::
DecoratedReader
*>
(
out
->
Get
().
get
());
PADDLE_ENFORCE_NOT_NULL
(
decorated_reader
,
platform
::
errors
::
NotFound
(
"Not inited with DecoratedReader"
));
if
(
decorated_reader
->
UnderlyingReader
()
==
underlying_reader
.
Get
())
{
return
;
}
}
auto
place_str
=
Attr
<
std
::
string
>
(
"place"
);
auto
place_str
=
Attr
<
std
::
string
>
(
"place"
);
platform
::
Place
place
;
platform
::
Place
place
;
if
(
place_str
==
"AUTO"
)
{
if
(
place_str
==
"AUTO"
)
{
...
@@ -47,6 +55,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
...
@@ -47,6 +55,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
place
=
platform
::
CUDAPlace
(
static_cast
<
int
>
(
num
));
}
}
VLOG
(
10
)
<<
"Create new double buffer reader on "
<<
place
;
out
->
Reset
(
framework
::
MakeDecoratedReader
<
BufferedReader
>
(
underlying_reader
,
out
->
Reset
(
framework
::
MakeDecoratedReader
<
BufferedReader
>
(
underlying_reader
,
place
,
2
));
place
,
2
));
}
}
...
...
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
6dadb5de
...
@@ -40,6 +40,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
...
@@ -40,6 +40,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
queue_name
);
queue_name
);
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
;
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
queue
;
std
::
shared_ptr
<
OrderedMultiDeviceLoDTensorBlockingQueue
>
ordered_queue
;
std
::
shared_ptr
<
OrderedMultiDeviceLoDTensorBlockingQueue
>
ordered_queue
;
int
dev_idx
=
-
1
;
if
(
queue_holder_var
->
IsType
<
LoDTensorBlockingQueueHolder
>
())
{
if
(
queue_holder_var
->
IsType
<
LoDTensorBlockingQueueHolder
>
())
{
queue
=
queue_holder_var
->
Get
<
LoDTensorBlockingQueueHolder
>
().
GetQueue
();
queue
=
queue_holder_var
->
Get
<
LoDTensorBlockingQueueHolder
>
().
GetQueue
();
}
else
if
(
queue_holder_var
}
else
if
(
queue_holder_var
...
@@ -47,10 +48,9 @@ class CreatePyReaderOp : public framework::OperatorBase {
...
@@ -47,10 +48,9 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto
*
queue_holder
=
auto
*
queue_holder
=
queue_holder_var
queue_holder_var
->
GetMutable
<
OrderedMultiDeviceLoDTensorBlockingQueueHolder
>
();
->
GetMutable
<
OrderedMultiDeviceLoDTensorBlockingQueueHolder
>
();
auto
dev_cnt
=
Attr
<
int
>
(
"device_count"
);
dev_idx
=
Attr
<
int
>
(
"device_index"
);
auto
dev_idx
=
static_cast
<
size_t
>
(
Attr
<
int
>
(
"device_index"
));
ordered_queue
=
queue_holder
->
GetQueue
();
ordered_queue
=
queue_holder
->
GetQueue
();
ordered_queue
->
InitOnce
(
dev_cnt
);
ordered_queue
->
SetDeviceCount
(
Attr
<
int
>
(
"device_count"
)
);
queue
=
ordered_queue
->
GetQueue
(
dev_idx
);
queue
=
ordered_queue
->
GetQueue
(
dev_idx
);
}
}
...
@@ -87,15 +87,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
...
@@ -87,15 +87,7 @@ class CreatePyReaderOp : public framework::OperatorBase {
auto
py_reader
=
auto
py_reader
=
std
::
make_shared
<
PyReader
>
(
queue
,
dims
,
var_types
,
need_check_feed
);
std
::
make_shared
<
PyReader
>
(
queue
,
dims
,
var_types
,
need_check_feed
);
if
(
ordered_queue
)
{
if
(
ordered_queue
)
{
ordered_queue
->
AddResetMethod
([
py_reader
]
{
ordered_queue
->
SetResetMethod
(
dev_idx
,
[
out
]
{
out
->
Clear
();
});
auto
end_readers
=
py_reader
->
GetEndPoints
();
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Shutdown
();
}
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Start
();
}
});
}
}
out
->
Reset
(
py_reader
);
out
->
Reset
(
py_reader
);
}
}
...
@@ -109,8 +101,9 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase {
...
@@ -109,8 +101,9 @@ class CreatePyReaderOpMaker : public FileReaderMakerBase {
AddAttr
<
int
>
(
"device_index"
,
"The device index this reader offers data"
)
AddAttr
<
int
>
(
"device_index"
,
"The device index this reader offers data"
)
.
SetDefault
(
0
);
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"device_count"
,
AddAttr
<
int
>
(
"device_count"
,
"The total
number of devices the
reader offers data"
)
"The total
device number this
reader offers data"
)
.
SetDefault
(
1
);
.
SetDefault
(
1
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
...
...
paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
浏览文件 @
6dadb5de
...
@@ -32,6 +32,8 @@ class LoDTensorBlockingQueue {
...
@@ -32,6 +32,8 @@ class LoDTensorBlockingQueue {
explicit
LoDTensorBlockingQueue
(
size_t
capacity
,
bool
speed_test_mode
=
false
)
explicit
LoDTensorBlockingQueue
(
size_t
capacity
,
bool
speed_test_mode
=
false
)
:
queue_
(
capacity
,
speed_test_mode
)
{}
:
queue_
(
capacity
,
speed_test_mode
)
{}
~
LoDTensorBlockingQueue
()
{
VLOG
(
10
)
<<
"Destruct LoDTensorBlockingQueue"
;
}
bool
Push
(
const
std
::
vector
<
framework
::
LoDTensor
>&
lod_tensor_vec
)
{
bool
Push
(
const
std
::
vector
<
framework
::
LoDTensor
>&
lod_tensor_vec
)
{
return
queue_
.
Send
(
lod_tensor_vec
);
return
queue_
.
Send
(
lod_tensor_vec
);
}
}
...
@@ -62,7 +64,7 @@ class LoDTensorBlockingQueue {
...
@@ -62,7 +64,7 @@ class LoDTensorBlockingQueue {
inline
void
Kill
()
{
queue_
.
Kill
();
}
inline
void
Kill
()
{
queue_
.
Kill
();
}
inline
bool
WaitForInited
()
{
return
true
;
}
inline
bool
WaitForInited
(
size_t
)
{
return
true
;
}
private:
private:
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>
queue_
;
BlockingQueue
<
std
::
vector
<
framework
::
LoDTensor
>>
queue_
;
...
@@ -74,47 +76,47 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
...
@@ -74,47 +76,47 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
bool
speed_test_mode
=
false
)
bool
speed_test_mode
=
false
)
:
capacity_
(
capacity
),
speed_test_mode_
(
speed_test_mode
)
{}
:
capacity_
(
capacity
),
speed_test_mode_
(
speed_test_mode
)
{}
inline
bool
WaitForInited
()
{
~
OrderedMultiDeviceLoDTensorBlockingQueue
()
{
VLOG
(
10
)
<<
"Destruct OrderedMultiDeviceLoDTensorBlockingQueue"
;
}
bool
WaitForInited
(
size_t
milliseconds
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
init_mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
init_mutex_
);
cv_
.
wait
(
lock
,
[
this
]
{
return
queues_
!=
nullptr
||
is_closing_
;
});
return
cv_
.
wait_for
(
lock
,
std
::
chrono
::
milliseconds
(
milliseconds
),
is_closing_
=
false
;
[
this
]
{
return
!
queues_
.
empty
();
});
return
queues_
!=
nullptr
;
}
}
inline
void
InitOnce
(
size_t
dev_cnt
)
{
void
SetDeviceCount
(
size_t
dev_cnt
)
{
PADDLE_ENFORCE_GE
(
dev_cnt
,
1
,
platform
::
errors
::
InvalidArgument
(
"Device count to init "
"OrderedMultiDeviceLoDTensorBlockingQueue"
" must be larger than 1"
));
VLOG
(
3
)
<<
"Ordered queue init start"
;
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
init_mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
init_mutex_
);
if
(
queues_
)
{
PADDLE_ENFORCE_GE
(
dev_cnt
,
1
,
PADDLE_ENFORCE_EQ
(
queues_
->
size
(),
dev_cnt
,
platform
::
errors
::
InvalidArgument
(
"Device count to init "
"OrderedMultiDeviceLoDTensorBlockingQueue"
" must be larger than 1"
));
if
(
!
queues_
.
empty
())
{
PADDLE_ENFORCE_EQ
(
queues_
.
size
(),
dev_cnt
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Device count to init queue must be equal"
));
"queues should be only inited once"
));
}
else
{
return
;
queues_
.
reset
(
}
new
std
::
vector
<
std
::
shared_ptr
<
LoDTensorBlockingQueue
>>
(
dev_cnt
));
for
(
auto
&
item
:
*
queues_
)
{
VLOG
(
1
)
<<
"Init queue with size "
<<
dev_cnt
;
auto
cap
=
(
capacity_
+
dev_cnt
-
1
)
/
dev_cnt
;
queues_
.
resize
(
dev_cnt
);
item
.
reset
(
new
LoDTensorBlockingQueue
(
cap
,
speed_test_mode_
));
for
(
auto
&
item
:
queues_
)
{
}
auto
cap
=
(
capacity_
+
dev_cnt
-
1
)
/
dev_cnt
;
item
.
reset
(
new
LoDTensorBlockingQueue
(
cap
,
speed_test_mode_
));
}
}
}
}
VLOG
(
3
)
<<
"Ordered queue init finish"
;
cv_
.
notify_all
();
cv_
.
notify_all
();
}
}
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
GetQueue
(
size_t
idx
)
const
{
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
GetQueue
(
size_t
idx
)
const
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
init_mutex_
);
EnforceIsInited
();
PADDLE_ENFORCE_NOT_NULL
(
queues_
,
platform
::
errors
::
NotFound
(
"Queues must be inited first before getting"
));
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
idx
,
queues_
->
size
(),
idx
,
queues_
.
size
(),
platform
::
errors
::
OutOfRange
(
"The queue index is out of range"
));
platform
::
errors
::
OutOfRange
(
"The queue index is out of range"
));
return
(
*
queues_
)
[
idx
];
return
queues_
[
idx
];
}
}
bool
Push
(
const
std
::
vector
<
framework
::
LoDTensor
>&
lod_tensor_vec
)
{
bool
Push
(
const
std
::
vector
<
framework
::
LoDTensor
>&
lod_tensor_vec
)
{
...
@@ -123,65 +125,74 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
...
@@ -123,65 +125,74 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
inline
size_t
Size
()
const
{
inline
size_t
Size
()
const
{
size_t
size
=
0
;
size_t
size
=
0
;
if
(
queues_
)
{
for
(
auto
&
item
:
queues_
)
{
for
(
auto
&
item
:
*
queues_
)
{
size
+=
item
->
Size
();
size
+=
item
->
Size
();
}
}
}
return
size
;
return
size
;
}
}
inline
void
Close
()
{
inline
void
Close
()
{
{
for
(
auto
&
item
:
queues_
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
init_mutex_
);
item
->
Close
();
if
(
queues_
==
nullptr
)
{
is_closing_
=
true
;
}
}
cv_
.
notify_all
();
if
(
queues_
)
{
for
(
auto
&
item
:
*
queues_
)
{
item
->
Close
();
}
}
}
data_index_
=
0
;
}
}
inline
void
Kill
()
{
inline
void
Kill
()
{
if
(
queues_
)
{
for
(
auto
&
item
:
queues_
)
{
for
(
auto
&
item
:
*
queues_
)
{
item
->
Kill
();
item
->
Kill
();
}
}
}
}
}
inline
void
Reset
()
{
inline
void
Reset
()
{
std
::
lock_guard
<
std
::
mutex
>
reset_lock
(
reset_mutex_
);
{
for
(
auto
&
method
:
reset_methods_
)
{
std
::
lock_guard
<
std
::
mutex
>
reset_lock
(
reset_mutex_
);
method
();
for
(
auto
&
method
:
reset_methods_
)
{
if
(
method
)
method
();
}
}
auto
dev_cnt
=
queues_
.
size
();
for
(
auto
&
item
:
queues_
)
{
auto
cap
=
(
capacity_
+
dev_cnt
-
1
)
/
dev_cnt
;
item
.
reset
(
new
LoDTensorBlockingQueue
(
cap
,
speed_test_mode_
));
}
}
data_index_
=
0
;
}
}
inline
void
AddResetMethod
(
const
std
::
function
<
void
()
>&
reset_method
)
{
inline
void
SetResetMethod
(
size_t
idx
,
const
std
::
function
<
void
()
>&
reset_method
)
{
std
::
lock_guard
<
std
::
mutex
>
reset_lock
(
reset_mutex_
);
std
::
lock_guard
<
std
::
mutex
>
reset_lock
(
reset_mutex_
);
reset_methods_
.
emplace_back
(
reset_method
);
EnforceIsInited
();
if
(
reset_methods_
.
size
()
<=
idx
)
{
reset_methods_
.
resize
(
idx
+
1
);
}
reset_methods_
[
idx
]
=
reset_method
;
}
}
private:
private:
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
CurQueue
()
{
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
CurQueue
()
{
return
(
*
queues_
)[
data_index_
.
fetch_add
(
1
)
%
queues_
->
size
()];
EnforceIsInited
();
return
queues_
[
data_index_
.
fetch_add
(
1
)
%
queues_
.
size
()];
}
private:
void
EnforceIsInited
()
const
{
PADDLE_ENFORCE_EQ
(
queues_
.
empty
(),
false
,
platform
::
errors
::
NotFound
(
"queue has not been inited"
));
}
}
private:
private:
std
::
unique_ptr
<
std
::
vector
<
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
>>
queues_
;
std
::
vector
<
std
::
shared_ptr
<
LoDTensorBlockingQueue
>>
queues_
;
mutable
std
::
atomic
<
uint64_t
>
data_index_
{
0
};
mutable
std
::
atomic
<
uint64_t
>
data_index_
{
0
};
size_t
dev_cnt_
{
0
};
const
size_t
capacity_
;
const
size_t
capacity_
;
const
bool
speed_test_mode_
;
const
bool
speed_test_mode_
;
bool
is_closed_
{
false
};
std
::
vector
<
std
::
function
<
void
()
>>
reset_methods_
;
std
::
vector
<
std
::
function
<
void
()
>>
reset_methods_
;
mutable
std
::
mutex
reset_mutex_
;
mutable
std
::
mutex
reset_mutex_
;
bool
is_closing_
{
false
};
mutable
std
::
mutex
init_mutex_
;
mutable
std
::
mutex
init_mutex_
;
mutable
std
::
condition_variable
cv_
;
mutable
std
::
condition_variable
cv_
;
};
};
...
...
paddle/fluid/pybind/reader_py.cc
浏览文件 @
6dadb5de
...
@@ -354,7 +354,7 @@ void BindReader(py::module *module) {
...
@@ -354,7 +354,7 @@ void BindReader(py::module *module) {
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
bool
>
&
need_check_feed
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
const
std
::
vector
<
platform
::
Place
>
&
dst_places
,
bool
use_double_buffer
)
{
bool
use_double_buffer
)
{
queue
->
InitOnce
(
dst_places
.
size
());
queue
->
SetDeviceCount
(
dst_places
.
size
());
return
new
MultiDeviceFeedReader
<
return
new
MultiDeviceFeedReader
<
reader
::
OrderedMultiDeviceLoDTensorBlockingQueue
>
(
reader
::
OrderedMultiDeviceLoDTensorBlockingQueue
>
(
queue
,
names
,
shapes
,
dtypes
,
need_check_feed
,
dst_places
,
queue
,
names
,
shapes
,
dtypes
,
need_check_feed
,
dst_places
,
...
...
python/paddle/fluid/reader.py
浏览文件 @
6dadb5de
...
@@ -347,7 +347,6 @@ class DygraphGeneratorLoader(DataLoaderBase):
...
@@ -347,7 +347,6 @@ class DygraphGeneratorLoader(DataLoaderBase):
self
.
_batch_reader
=
None
self
.
_batch_reader
=
None
self
.
_places
=
None
self
.
_places
=
None
self
.
_feed_list
=
feed_list
self
.
_feed_list
=
feed_list
self
.
_keep_order
=
True
if
not
capacity
:
if
not
capacity
:
raise
ValueError
(
"Please give value to capacity."
)
raise
ValueError
(
"Please give value to capacity."
)
...
@@ -420,7 +419,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
...
@@ -420,7 +419,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
self
.
_dtypes
=
[]
self
.
_dtypes
=
[]
self
.
_need_check_feed
=
[]
self
.
_need_check_feed
=
[]
self
.
_blocking_queue
=
core
.
init_lod_tensor_blocking_queue
(
self
.
_blocking_queue
=
core
.
init_lod_tensor_blocking_queue
(
core
.
Variable
(),
self
.
_capacity
,
self
.
_keep_order
)
core
.
Variable
(),
self
.
_capacity
,
False
)
self
.
_reader
=
core
.
create_py_reader
(
self
.
_reader
=
core
.
create_py_reader
(
self
.
queue
,
self
.
_var_names
,
self
.
_shapes
,
self
.
_dtypes
,
self
.
queue
,
self
.
_var_names
,
self
.
_shapes
,
self
.
_dtypes
,
self
.
_need_check_feed
,
self
.
_places
,
self
.
_use_double_buffer
)
self
.
_need_check_feed
,
self
.
_places
,
self
.
_use_double_buffer
)
...
@@ -635,6 +634,7 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -635,6 +634,7 @@ class GeneratorLoader(DataLoaderBase):
self
.
_thread
=
None
self
.
_thread
=
None
self
.
_queue
=
None
self
.
_queue
=
None
self
.
_feed_list
=
feed_list
self
.
_feed_list
=
feed_list
self
.
_exited
=
False
if
not
capacity
:
if
not
capacity
:
raise
ValueError
(
"Please give value to capacity."
)
raise
ValueError
(
"Please give value to capacity."
)
self
.
_iterable
=
iterable
self
.
_iterable
=
iterable
...
@@ -798,8 +798,9 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -798,8 +798,9 @@ class GeneratorLoader(DataLoaderBase):
def
_start
(
self
):
def
_start
(
self
):
def
__thread_main__
():
def
__thread_main__
():
try
:
try
:
if
not
self
.
_queue
.
wait_for_inited
():
while
not
self
.
_queue
.
wait_for_inited
(
1
):
return
if
self
.
_exited
:
return
for
tensors
in
self
.
_tensor_reader
():
for
tensors
in
self
.
_tensor_reader
():
array
=
core
.
LoDTensorArray
()
array
=
core
.
LoDTensorArray
()
...
@@ -829,10 +830,12 @@ class GeneratorLoader(DataLoaderBase):
...
@@ -829,10 +830,12 @@ class GeneratorLoader(DataLoaderBase):
def
_reset
(
self
):
def
_reset
(
self
):
self
.
_queue
.
close
()
self
.
_queue
.
close
()
self
.
_exited
=
True
thread
=
self
.
_thread
thread
=
self
.
_thread
if
thread
is
not
None
:
if
thread
is
not
None
:
thread
.
join
()
thread
.
join
()
self
.
_exited
=
False
self
.
_reader
.
reset
()
self
.
_reader
.
reset
()
def
set_sample_generator
(
self
,
def
set_sample_generator
(
self
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录