Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
847e4f4e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
847e4f4e
编写于
3月 01, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pure async mode train
上级
f768fbf7
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
148 addition
and
63 deletion
+148
-63
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+79
-35
paddle/fluid/framework/details/async_ssa_graph_executor.h
paddle/fluid/framework/details/async_ssa_graph_executor.h
+12
-0
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+2
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+5
-3
paddle/fluid/framework/reader.cc
paddle/fluid/framework/reader.cc
+4
-1
paddle/fluid/framework/reader.h
paddle/fluid/framework/reader.h
+9
-1
paddle/fluid/operators/reader/blocking_queue.h
paddle/fluid/operators/reader/blocking_queue.h
+2
-1
paddle/fluid/operators/reader/buffered_reader.cc
paddle/fluid/operators/reader/buffered_reader.cc
+3
-0
paddle/fluid/operators/reader/create_py_reader_op.cc
paddle/fluid/operators/reader/create_py_reader_op.cc
+5
-2
paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
+4
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-0
python/paddle/fluid/tests/unittests/test_async_ssa_graph_executor_mnist.py
...id/tests/unittests/test_async_ssa_graph_executor_mnist.py
+22
-19
未找到文件。
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
847e4f4e
...
@@ -14,10 +14,31 @@
...
@@ -14,10 +14,31 @@
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/variable_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
inline
void
NewTempScopeAndInitVars
(
const
std
::
vector
<
VarInfo
>
&
var_infos
,
Scope
*
scope
)
{
Scope
&
local_scope
=
scope
->
NewScope
();
*
scope
->
Var
(
details
::
kLocalExecScopeName
)
->
GetMutable
<
Scope
*>
()
=
&
local_scope
;
for
(
auto
&
info
:
var_infos
)
{
if
(
scope
->
FindVar
(
info
.
name_
)
!=
nullptr
)
{
continue
;
}
if
(
info
.
persistable_
)
{
// Persistable
InitializeVariable
(
scope
->
Var
(
info
.
name_
),
info
.
type_
);
}
else
{
InitializeVariable
(
local_scope
.
Var
(
info
.
name_
),
info
.
type_
);
}
}
}
AsyncSSAGraphExecutor
::
AsyncSSAGraphExecutor
(
AsyncSSAGraphExecutor
::
AsyncSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
vector
<
ir
::
Graph
*>
graphs
)
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
vector
<
ir
::
Graph
*>
graphs
)
...
@@ -39,58 +60,81 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
...
@@ -39,58 +60,81 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
executors_
.
emplace_back
(
new
details
::
ThreadedSSAGraphExecutor
(
executors_
.
emplace_back
(
new
details
::
ThreadedSSAGraphExecutor
(
strategy_
,
{
local_scopes_
[
i
]},
{
places_
[
i
]},
graphs_
[
i
]));
strategy_
,
{
local_scopes_
[
i
]},
{
places_
[
i
]},
graphs_
[
i
]));
}
}
}
FeedFetchList
AsyncSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
std
::
vector
<
std
::
future
<
FeedFetchList
>>
run_futures
;
std
::
vector
<
FeedFetchList
>
fetch_data
;
for
(
auto
&
node
:
graphs_
[
0
]
->
Nodes
())
{
FeedFetchList
ret
;
if
(
node
->
IsVar
()
&&
!
node
->
IsCtrlVar
()
&&
node
->
Var
())
{
var_infos_
.
emplace_back
();
fetch_data
.
reserve
(
places_
.
size
());
var_infos_
.
back
().
name_
=
node
->
Var
()
->
Name
();
ret
.
reserve
(
fetch_tensors
.
size
());
var_infos_
.
back
().
type_
=
node
->
Var
()
->
GetType
();
exception_holder_
.
Clear
();
var_infos_
.
back
().
persistable_
=
node
->
Var
()
->
Persistable
();
}
}
for
(
auto
*
scope
:
local_scopes_
)
{
NewTempScopeAndInitVars
(
var_infos_
,
scope
);
}
}
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
void
AsyncSSAGraphExecutor
::
StartOffPythonTrainLoop
()
{
auto
call
=
[
this
,
i
,
&
fetch_tensors
]()
->
FeedFetchList
{
VLOG
(
3
)
<<
"StartOffPythonTrainLoop size = "
<<
places_
.
size
();
for
(
size_t
i
=
1
;
i
<
places_
.
size
();
++
i
)
{
auto
call
=
[
this
,
i
]()
->
void
{
VLOG
(
3
)
<<
"start off python thread "
<<
i
;
try
{
try
{
return
executors_
[
i
]
->
Run
(
fetch_tensors
);
while
(
true
)
{
executors_
[
i
]
->
Run
({});
}
}
catch
(...)
{
}
catch
(...)
{
exception_holder_
.
Catch
(
std
::
current_exception
());
exception_holder_
.
Catch
(
std
::
current_exception
());
VLOG
(
3
)
<<
"get exception type = "
<<
exception_holder_
.
Type
();
}
}
return
FeedFetchList
()
;
VLOG
(
3
)
<<
"thread "
<<
i
<<
" exited!"
;
};
};
run_futures_
.
emplace_back
(
pool_
->
enqueue
(
std
::
move
(
call
)));
if
(
pool_
)
{
run_futures
.
emplace_back
(
pool_
->
enqueue
(
std
::
move
(
call
)));
}
else
{
fetch_data
.
emplace_back
(
std
::
move
(
call
()));
}
}
}
}
if
(
pool_
)
{
void
AsyncSSAGraphExecutor
::
HandleException
()
{
for
(
auto
&
f
:
run_futures
)
{
if
(
exception_holder_
.
IsCaught
())
{
if
(
exception_holder_
.
IsCaught
())
{
for
(
auto
&
f
:
run_futures_
)
{
VLOG
(
3
)
<<
"wait future"
;
f
.
wait
();
f
.
wait
();
}
else
{
fetch_data
.
emplace_back
(
std
::
move
(
f
.
get
()));
}
}
}
}
if
(
exception_holder_
.
IsCaught
())
{
VLOG
(
3
)
<<
"caught exception "
<<
exception_holder_
.
Type
()
VLOG
(
3
)
<<
"caught exception "
<<
exception_holder_
.
Type
()
<<
", rethrow it"
;
<<
", rethrow it"
;
run_futures_
.
clear
();
exception_holder_
.
ReThrow
();
exception_holder_
.
ReThrow
();
}
}
}
FeedFetchList
AsyncSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
// init once
if
(
run_futures_
.
size
()
==
0
&&
places_
.
size
()
>
1
)
{
exception_holder_
.
Clear
();
StartOffPythonTrainLoop
();
}
if
(
places_
.
size
()
==
1
)
{
exception_holder_
.
Clear
();
}
else
{
HandleException
();
}
FeedFetchList
fetch_data
;
fetch_data
.
reserve
(
fetch_tensors
.
size
());
try
{
fetch_data
=
executors_
[
0
]
->
Run
(
fetch_tensors
);
}
catch
(...)
{
exception_holder_
.
Catch
(
std
::
current_exception
());
}
HandleException
();
FeedFetchList
ret
;
for
(
size_t
fetch_idx
=
0
;
fetch_idx
<
fetch_tensors
.
size
();
++
fetch_idx
)
{
for
(
size_t
fetch_idx
=
0
;
fetch_idx
<
fetch_tensors
.
size
();
++
fetch_idx
)
{
std
::
vector
<
const
LoDTensor
*>
lodtensor_ptrs
;
std
::
vector
<
const
LoDTensor
*>
lodtensor_ptrs
;
lodtensor_ptrs
.
reserve
(
local_scopes_
.
size
());
lodtensor_ptrs
.
push_back
(
&
fetch_data
.
at
(
fetch_idx
));
for
(
size_t
scope_idx
=
0
;
scope_idx
<
local_scopes_
.
size
();
++
scope_idx
)
{
lodtensor_ptrs
.
push_back
(
&
fetch_data
.
at
(
scope_idx
).
at
(
fetch_idx
));
}
ret
.
emplace_back
();
ret
.
emplace_back
();
ret
.
back
().
MergeLoDTensor
(
lodtensor_ptrs
,
platform
::
CPUPlace
());
ret
.
back
().
MergeLoDTensor
(
lodtensor_ptrs
,
platform
::
CPUPlace
());
}
}
...
...
paddle/fluid/framework/details/async_ssa_graph_executor.h
浏览文件 @
847e4f4e
...
@@ -24,6 +24,12 @@ namespace paddle {
...
@@ -24,6 +24,12 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
struct
VarInfo
{
std
::
string
name_
;
proto
::
VarType
::
Type
type_
;
bool
persistable_
;
};
class
AsyncSSAGraphExecutor
:
public
SSAGraphExecutor
{
class
AsyncSSAGraphExecutor
:
public
SSAGraphExecutor
{
public:
public:
AsyncSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
AsyncSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
...
@@ -35,6 +41,10 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -35,6 +41,10 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
private:
void
StartOffPythonTrainLoop
();
void
HandleException
();
private:
private:
ExecutionStrategy
strategy_
;
ExecutionStrategy
strategy_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_scopes_
;
...
@@ -44,6 +54,8 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
...
@@ -44,6 +54,8 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
std
::
vector
<
std
::
unique_ptr
<
details
::
ThreadedSSAGraphExecutor
>>
executors_
;
std
::
vector
<
std
::
unique_ptr
<
details
::
ThreadedSSAGraphExecutor
>>
executors_
;
ExceptionHolder
exception_holder_
;
ExceptionHolder
exception_holder_
;
std
::
vector
<
std
::
future
<
void
>>
run_futures_
;
std
::
vector
<
VarInfo
>
var_infos_
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
847e4f4e
...
@@ -119,6 +119,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
...
@@ -119,6 +119,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
if
(
timeout
)
{
if
(
timeout
)
{
if
(
exception_holder_
.
IsCaught
())
{
if
(
exception_holder_
.
IsCaught
())
{
VLOG
(
3
)
<<
"caught exception "
<<
exception_holder_
.
Type
()
<<
", rethrow it"
;
for
(
auto
&
run_op_future
:
run_op_futures_
)
{
for
(
auto
&
run_op_future
:
run_op_futures_
)
{
run_op_future
.
wait
();
run_op_future
.
wait
();
}
}
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
847e4f4e
...
@@ -379,9 +379,11 @@ ParallelExecutor::ParallelExecutor(
...
@@ -379,9 +379,11 @@ ParallelExecutor::ParallelExecutor(
}
}
VLOG
(
3
)
<<
"use ScopeBufferedSSAGraphExecutor"
;
VLOG
(
3
)
<<
"use ScopeBufferedSSAGraphExecutor"
;
if
(
!
build_strategy
.
async_mode_
)
{
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
std
::
move
(
var_infos
),
exec_strategy
,
member_
->
local_scopes_
,
std
::
move
(
var_infos
),
member_
->
places_
,
std
::
move
(
member_
->
executor_
)));
member_
->
places_
,
std
::
move
(
member_
->
executor_
)));
}
}
}
void
ParallelExecutor
::
BCastParamsToDevices
(
void
ParallelExecutor
::
BCastParamsToDevices
(
...
...
paddle/fluid/framework/reader.cc
浏览文件 @
847e4f4e
...
@@ -69,6 +69,9 @@ void ReaderBase::Start() {
...
@@ -69,6 +69,9 @@ void ReaderBase::Start() {
ReaderBase
::~
ReaderBase
()
{}
ReaderBase
::~
ReaderBase
()
{}
DecoratedReader
::~
DecoratedReader
()
{
reader_
->
Shutdown
();
}
DecoratedReader
::~
DecoratedReader
()
{
VLOG
(
1
)
<<
"~DecoratedReader"
;
reader_
->
Shutdown
();
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/reader.h
浏览文件 @
847e4f4e
...
@@ -77,7 +77,10 @@ class DecoratedReader : public ReaderBase,
...
@@ -77,7 +77,10 @@ class DecoratedReader : public ReaderBase,
~
DecoratedReader
();
~
DecoratedReader
();
protected:
protected:
void
ShutdownImpl
()
override
{
reader_
->
Shutdown
();
}
void
ShutdownImpl
()
override
{
VLOG
(
1
)
<<
"ShutdownImpl"
;
reader_
->
Shutdown
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
}
void
StartImpl
()
override
{
reader_
->
Start
();
}
...
@@ -98,6 +101,8 @@ class ReaderHolder {
...
@@ -98,6 +101,8 @@ class ReaderHolder {
reader_
=
reader_base
;
reader_
=
reader_base
;
}
}
~
ReaderHolder
()
{
VLOG
(
1
)
<<
"~ReaderHolder"
;
}
const
std
::
shared_ptr
<
ReaderBase
>&
Get
()
const
{
return
reader_
;
}
const
std
::
shared_ptr
<
ReaderBase
>&
Get
()
const
{
return
reader_
;
}
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
void
ReadNext
(
std
::
vector
<
LoDTensor
>*
out
)
{
...
@@ -106,6 +111,7 @@ class ReaderHolder {
...
@@ -106,6 +111,7 @@ class ReaderHolder {
}
}
void
ResetAll
()
{
void
ResetAll
()
{
VLOG
(
1
)
<<
"ResetAll"
;
auto
end_readers
=
reader_
->
GetEndPoints
();
auto
end_readers
=
reader_
->
GetEndPoints
();
for
(
auto
*
reader
:
end_readers
)
{
for
(
auto
*
reader
:
end_readers
)
{
reader
->
Shutdown
();
reader
->
Shutdown
();
...
@@ -116,11 +122,13 @@ class ReaderHolder {
...
@@ -116,11 +122,13 @@ class ReaderHolder {
}
}
void
Shutdown
()
{
void
Shutdown
()
{
VLOG
(
1
)
<<
"Shutdown"
;
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
Shutdown
();
reader_
->
Shutdown
();
}
}
void
Start
()
{
void
Start
()
{
VLOG
(
1
)
<<
"start"
;
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
PADDLE_ENFORCE_NOT_NULL
(
reader_
);
reader_
->
Start
();
reader_
->
Start
();
}
}
...
...
paddle/fluid/operators/reader/blocking_queue.h
浏览文件 @
847e4f4e
...
@@ -86,6 +86,7 @@ class BlockingQueue {
...
@@ -86,6 +86,7 @@ class BlockingQueue {
void
ReOpen
()
{
void
ReOpen
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
VLOG
(
1
)
<<
"reopen queue"
;
closed_
=
false
;
closed_
=
false
;
std
::
deque
<
T
>
new_deque
;
std
::
deque
<
T
>
new_deque
;
queue_
.
swap
(
new_deque
);
queue_
.
swap
(
new_deque
);
...
@@ -95,7 +96,7 @@ class BlockingQueue {
...
@@ -95,7 +96,7 @@ class BlockingQueue {
void
Close
()
{
void
Close
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
VLOG
(
3
)
<<
"close queue"
;
VLOG
(
1
)
<<
"close queue"
;
closed_
=
true
;
closed_
=
true
;
send_cv_
.
notify_all
();
send_cv_
.
notify_all
();
receive_cv_
.
notify_all
();
receive_cv_
.
notify_all
();
...
...
paddle/fluid/operators/reader/buffered_reader.cc
浏览文件 @
847e4f4e
...
@@ -20,6 +20,7 @@ namespace paddle {
...
@@ -20,6 +20,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
reader
{
namespace
reader
{
BufferedReader
::~
BufferedReader
()
{
BufferedReader
::~
BufferedReader
()
{
VLOG
(
1
)
<<
"~BufferedReader"
;
reader_
->
Shutdown
();
reader_
->
Shutdown
();
while
(
!
position_
.
empty
())
{
while
(
!
position_
.
empty
())
{
position_
.
front
().
wait
();
position_
.
front
().
wait
();
...
@@ -41,6 +42,7 @@ BufferedReader::BufferedReader(
...
@@ -41,6 +42,7 @@ BufferedReader::BufferedReader(
thread_pool_
(
1
),
thread_pool_
(
1
),
place_
(
place
),
place_
(
place
),
buffer_size_
(
buffer_size
)
{
buffer_size_
(
buffer_size
)
{
VLOG
(
1
)
<<
"BufferedReader"
;
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
place_
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
platform
::
SetDeviceId
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
).
device
);
platform
::
SetDeviceId
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
).
device
);
...
@@ -121,6 +123,7 @@ void BufferedReader::ReadAsync(size_t i) {
...
@@ -121,6 +123,7 @@ void BufferedReader::ReadAsync(size_t i) {
}
}
void
BufferedReader
::
ShutdownImpl
()
{
void
BufferedReader
::
ShutdownImpl
()
{
VLOG
(
1
)
<<
"ShutdownImpl"
;
reader_
->
Shutdown
();
reader_
->
Shutdown
();
while
(
!
position_
.
empty
())
{
while
(
!
position_
.
empty
())
{
position_
.
pop
();
position_
.
pop
();
...
...
paddle/fluid/operators/reader/create_py_reader_op.cc
浏览文件 @
847e4f4e
...
@@ -33,10 +33,13 @@ class PyReader : public framework::FileReader {
...
@@ -33,10 +33,13 @@ class PyReader : public framework::FileReader {
if
(
!
success
)
out
->
clear
();
if
(
!
success
)
out
->
clear
();
}
}
~
PyReader
()
{
queue_
->
Close
();
}
~
PyReader
()
{
VLOG
(
1
)
<<
"~PyReader"
;
queue_
->
Close
();
}
void
Shutdown
()
override
{
void
Shutdown
()
override
{
VLOG
(
3
)
<<
"PyReader shutdown!"
;
VLOG
(
1
)
<<
"PyReader shutdown!"
;
queue_
->
Close
();
queue_
->
Close
();
}
}
...
...
paddle/fluid/operators/reader/lod_tensor_blocking_queue.h
浏览文件 @
847e4f4e
...
@@ -57,7 +57,10 @@ class LoDTensorBlockingQueue {
...
@@ -57,7 +57,10 @@ class LoDTensorBlockingQueue {
inline
void
ReOpen
()
{
queue_
.
ReOpen
();
}
inline
void
ReOpen
()
{
queue_
.
ReOpen
();
}
inline
void
Close
()
{
queue_
.
Close
();
}
inline
void
Close
()
{
VLOG
(
1
)
<<
"LoDTensorBlockingQueue close"
;
queue_
.
Close
();
}
inline
bool
IsClosed
()
const
{
return
queue_
.
IsClosed
();
}
inline
bool
IsClosed
()
const
{
return
queue_
.
IsClosed
();
}
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
847e4f4e
...
@@ -557,6 +557,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -557,6 +557,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"init_lod_tensor_blocking_queue"
,
m
.
def
(
"init_lod_tensor_blocking_queue"
,
[](
Variable
&
var
,
[](
Variable
&
var
,
size_t
capacity
)
->
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
{
size_t
capacity
)
->
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
{
VLOG
(
1
)
<<
"init_lod_tensor_blocking_queue"
;
auto
*
holder
=
var
.
GetMutable
<
LoDTensorBlockingQueueHolder
>
();
auto
*
holder
=
var
.
GetMutable
<
LoDTensorBlockingQueueHolder
>
();
holder
->
InitOnce
(
capacity
,
FLAGS_reader_queue_speed_test_mode
);
holder
->
InitOnce
(
capacity
,
FLAGS_reader_queue_speed_test_mode
);
return
holder
->
GetQueue
();
return
holder
->
GetQueue
();
...
...
python/paddle/fluid/tests/unittests/test_async_ssa_graph_executor_mnist.py
浏览文件 @
847e4f4e
...
@@ -36,7 +36,7 @@ def convolutional_neural_network(use_py_reader):
...
@@ -36,7 +36,7 @@ def convolutional_neural_network(use_py_reader):
capacity
=
64
,
capacity
=
64
,
feed_list
=
[
img
,
label
],
feed_list
=
[
img
,
label
],
name
=
'py_reader'
,
name
=
'py_reader'
,
use_double_buffer
=
Tru
e
)
use_double_buffer
=
Fals
e
)
img
,
label
=
fluid
.
layers
.
read_file
(
py_reader
)
img
,
label
=
fluid
.
layers
.
read_file
(
py_reader
)
conv_pool_1
=
fluid
.
nets
.
simple_img_conv_pool
(
conv_pool_1
=
fluid
.
nets
.
simple_img_conv_pool
(
...
@@ -139,19 +139,20 @@ def train(use_cuda, thread_num, cpu_num):
...
@@ -139,19 +139,20 @@ def train(use_cuda, thread_num, cpu_num):
exec_strategy
=
exec_strategy
)
exec_strategy
=
exec_strategy
)
py_reader
.
decorate_paddle_reader
(
train_reader
)
py_reader
.
decorate_paddle_reader
(
train_reader
)
py_reader
.
start
()
for
pass_id
in
range
(
2
):
step
=
0
step
=
0
py_reader
.
start
()
try
:
try
:
while
True
:
while
True
:
loss_val
=
pe
.
run
(
fetch_list
=
[
avg_loss
.
name
])
loss_val
=
pe
.
run
(
fetch_list
=
[
avg_loss
.
name
])
loss_val
=
numpy
.
mean
(
loss_val
)
loss_val
=
numpy
.
mean
(
loss_val
)
if
step
%
10
0
==
0
:
if
step
%
1
0
==
0
:
print
(
"
Batch %d, Cost %f, queue size %d"
%
print
(
"Pass %d,
Batch %d, Cost %f, queue size %d"
%
(
step
,
loss_val
,
py_reader
.
queue
.
size
()))
(
pass_id
,
step
,
loss_val
,
py_reader
.
queue
.
size
()))
step
+=
1
step
+=
1
except
fluid
.
core
.
EOFException
:
except
fluid
.
core
.
EOFException
:
print
(
"train end"
)
print
(
"train end pass = "
+
str
(
pass_id
)
)
py_reader
.
reset
()
py_reader
.
reset
()
return
step
return
step
...
@@ -161,10 +162,11 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase):
...
@@ -161,10 +162,11 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase):
def
test_check_async_ssa_exe_train
(
self
):
def
test_check_async_ssa_exe_train
(
self
):
step_list
=
[]
step_list
=
[]
for
cpu_num
in
[
1
,
2
,
4
]:
for
cpu_num
in
[
1
,
2
,
4
]:
scope
=
fluid
.
core
.
Scope
(
)
print
(
"run cpu_num -> "
+
str
(
cpu_num
)
)
with
fluid
.
scope_guard
(
scope
):
with
fluid
.
scope_guard
(
fluid
.
core
.
Scope
()
):
with
fluid
.
program_guard
(
with
fluid
.
program_guard
(
fluid
.
Program
(),
startup_program
=
fluid
.
Program
()):
main_program
=
fluid
.
Program
(),
startup_program
=
fluid
.
Program
()):
start_time
=
time
.
time
()
start_time
=
time
.
time
()
step
=
train
(
step
=
train
(
use_cuda
=
False
,
thread_num
=
cpu_num
,
cpu_num
=
cpu_num
)
use_cuda
=
False
,
thread_num
=
cpu_num
,
cpu_num
=
cpu_num
)
...
@@ -173,7 +175,8 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase):
...
@@ -173,7 +175,8 @@ class TestAsyncSSAGraphExecutor(unittest.TestCase):
print
(
"cpu_num -> "
+
str
(
cpu_num
)
+
" step -> "
+
str
(
step
)
+
print
(
"cpu_num -> "
+
str
(
cpu_num
)
+
" step -> "
+
str
(
step
)
+
" time -> "
+
str
(
end_time
-
start_time
))
" time -> "
+
str
(
end_time
-
start_time
))
with
fluid
.
program_guard
(
with
fluid
.
program_guard
(
fluid
.
Program
(),
startup_program
=
fluid
.
Program
()):
main_program
=
fluid
.
Program
(),
startup_program
=
fluid
.
Program
()):
test
()
test
()
assert
int
(
step_list
[
0
]
/
2
)
==
int
(
step_list
[
1
])
assert
int
(
step_list
[
0
]
/
2
)
==
int
(
step_list
[
1
])
assert
int
(
step_list
[
1
]
/
2
)
==
int
(
step_list
[
2
])
assert
int
(
step_list
[
1
]
/
2
)
==
int
(
step_list
[
2
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录