Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ce72c3ff
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ce72c3ff
编写于
5月 10, 2018
作者:
C
chengduo
提交者:
GitHub
5月 10, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10476 from chengduoZH/refine_parallel_exe
Clean Parallel exe
上级
61343fbf
a89cd467
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
191 addition
and
142 deletion
+191
-142
paddle/fluid/framework/details/broadcast_op_handle.cc
paddle/fluid/framework/details/broadcast_op_handle.cc
+32
-28
paddle/fluid/framework/details/broadcast_op_handle.h
paddle/fluid/framework/details/broadcast_op_handle.h
+3
-1
paddle/fluid/framework/details/computation_op_handle.cc
paddle/fluid/framework/details/computation_op_handle.cc
+8
-8
paddle/fluid/framework/details/computation_op_handle.h
paddle/fluid/framework/details/computation_op_handle.h
+2
-0
paddle/fluid/framework/details/fetch_op_handle.cc
paddle/fluid/framework/details/fetch_op_handle.cc
+12
-9
paddle/fluid/framework/details/fetch_op_handle.h
paddle/fluid/framework/details/fetch_op_handle.h
+3
-1
paddle/fluid/framework/details/gather_op_handle.cc
paddle/fluid/framework/details/gather_op_handle.cc
+1
-12
paddle/fluid/framework/details/gather_op_handle.h
paddle/fluid/framework/details/gather_op_handle.h
+0
-1
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+1
-6
paddle/fluid/framework/details/op_handle_base.cc
paddle/fluid/framework/details/op_handle_base.cc
+25
-3
paddle/fluid/framework/details/op_handle_base.h
paddle/fluid/framework/details/op_handle_base.h
+13
-1
paddle/fluid/framework/details/reduce_op_handle.cc
paddle/fluid/framework/details/reduce_op_handle.cc
+12
-21
paddle/fluid/framework/details/reduce_op_handle.h
paddle/fluid/framework/details/reduce_op_handle.h
+0
-2
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
+1
-0
paddle/fluid/framework/details/send_op_handle.cc
paddle/fluid/framework/details/send_op_handle.cc
+2
-1
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+60
-48
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+16
-0
未找到文件。
paddle/fluid/framework/details/broadcast_op_handle.cc
浏览文件 @
ce72c3ff
...
...
@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() {
out_var_handles
.
size
(),
places_
.
size
(),
"The number of output should equal to the number of places."
);
// Wait input done, this Wait is asynchronous operation platform::Place
// &in_place;
WaitInputVarGenerated
(
*
in_var_handle
);
WaitInputVarGenerated
();
std
::
vector
<
const
Scope
*>
var_scopes
;
for
(
auto
*
s
:
local_scopes_
)
{
...
...
@@ -50,29 +48,9 @@ void BroadcastOpHandle::RunImpl() {
auto
*
in_var
=
var_scopes
.
at
(
in_var_handle
->
scope_idx_
)
->
FindVar
(
in_var_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
in_var
);
Tensor
&
in_tensor
=
VariableVisitor
::
GetMutableTensor
(
in_var
);
// NOTE: The tensors' Place of input and output must be all on GPU or all on
// CPU.
for
(
auto
*
out_var_handle
:
out_var_handles
)
{
if
(
out_var_handle
->
IsTheSameVar
(
*
in_var_handle
))
{
continue
;
}
auto
t_out_p
=
out_var_handle
->
place_
;
auto
*
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
if
(
platform
::
is_gpu_place
(
in_tensor
.
place
()))
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
t_out_p
),
"Places of input and output must be all on GPU."
);
}
else
{
t_out_p
=
platform
::
CPUPlace
();
}
VariableVisitor
::
ShareDimsAndLoD
(
*
in_var
,
out_var
);
VariableVisitor
::
GetMutableTensor
(
out_var
).
mutable_data
(
t_out_p
,
in_tensor
.
type
());
}
InitOutputValue
(
*
in_var_handle
,
out_var_handles
);
if
(
platform
::
is_cpu_place
(
in_tensor
.
place
()))
{
for
(
auto
*
out_var_handle
:
out_var_handles
)
{
...
...
@@ -147,11 +125,37 @@ void BroadcastOpHandle::RunImpl() {
}
}
void
BroadcastOpHandle
::
WaitInputVarGenerated
(
const
VarHandle
&
in_var
)
{
if
(
in_var
.
generated_op_
)
{
for
(
auto
&
pair
:
dev_ctxes_
)
{
in_var
.
generated_op_
->
Wait
(
pair
.
second
);
void
BroadcastOpHandle
::
InitOutputValue
(
const
VarHandle
&
in_var_handle
,
const
std
::
vector
<
VarHandle
*>
&
out_var_handles
)
const
{
std
::
vector
<
const
Scope
*>
var_scopes
;
for
(
auto
*
s
:
local_scopes_
)
{
var_scopes
.
emplace_back
(
s
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
());
}
auto
*
in_var
=
var_scopes
.
at
(
in_var_handle
.
scope_idx_
)
->
FindVar
(
in_var_handle
.
name_
);
Tensor
&
in_tensor
=
VariableVisitor
::
GetMutableTensor
(
in_var
);
// NOTE: The tensors' Place of input and output must be all on GPU or all on
// CPU.
for
(
auto
*
out_var_handle
:
out_var_handles
)
{
if
(
out_var_handle
->
IsTheSameVar
(
in_var_handle
))
{
continue
;
}
auto
t_out_p
=
out_var_handle
->
place_
;
auto
*
out_var
=
var_scopes
.
at
(
out_var_handle
->
scope_idx_
)
->
FindVar
(
out_var_handle
->
name_
);
PADDLE_ENFORCE_NOT_NULL
(
out_var
);
if
(
is_gpu_place
(
in_tensor
.
place
()))
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
t_out_p
),
"Places of input and output must be all on GPU."
);
}
else
{
t_out_p
=
platform
::
CPUPlace
();
}
VariableVisitor
::
ShareDimsAndLoD
(
*
in_var
,
out_var
);
VariableVisitor
::
GetMutableTensor
(
out_var
).
mutable_data
(
t_out_p
,
in_tensor
.
type
());
}
}
...
...
paddle/fluid/framework/details/broadcast_op_handle.h
浏览文件 @
ce72c3ff
...
...
@@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
void
WaitInputVarGenerated
(
const
VarHandle
&
in_var
);
private:
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
...
...
@@ -65,6 +64,9 @@ struct BroadcastOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA
const
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
void
InitOutputValue
(
const
VarHandle
&
in_var_handle
,
const
std
::
vector
<
VarHandle
*>
&
out_var_handles
)
const
;
};
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/computation_op_handle.cc
浏览文件 @
ce72c3ff
...
...
@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
place_
(
place
)
{}
void
ComputationOpHandle
::
RunImpl
()
{
auto
*
cur_ctx
=
dev_ctxes_
[
place_
];
for
(
auto
*
in
:
inputs_
)
{
bool
need_wait
=
in
->
generated_op_
&&
in
->
generated_op_
->
DeviceContext
(
place_
)
!=
cur_ctx
;
if
(
need_wait
)
{
in
->
generated_op_
->
Wait
(
cur_ctx
);
}
}
WaitInputVarGenerated
(
place_
);
this
->
RunAndRecordEvent
([
this
]
{
op_
->
Run
(
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
(),
place_
);
});
}
bool
ComputationOpHandle
::
NeedWait
(
VarHandleBase
*
in_var
)
{
bool
need_wait
=
in_var
&&
in_var
->
generated_op_
&&
in_var
->
generated_op_
->
DeviceContext
(
place_
)
!=
dev_ctxes_
[
place_
];
return
need_wait
;
}
std
::
string
ComputationOpHandle
::
Name
()
const
{
return
op_
->
Type
();
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/computation_op_handle.h
浏览文件 @
ce72c3ff
...
...
@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
virtual
bool
NeedWait
(
VarHandleBase
*
in_var
);
private:
std
::
unique_ptr
<
OperatorBase
>
op_
;
Scope
*
scope_
;
...
...
paddle/fluid/framework/details/fetch_op_handle.cc
浏览文件 @
ce72c3ff
...
...
@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() {
}
}
void
FetchOpHandle
::
Wait
(
platform
::
DeviceContext
*
waited_dev
)
{
void
FetchOpHandle
::
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
)
{
PADDLE_THROW
(
"Nobody should wait FetchOp. Unexpceted Error"
);
}
...
...
@@ -45,14 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
}
void
FetchOpHandle
::
RunImpl
()
{
auto
cpu_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
for
(
auto
*
input
:
inputs_
)
{
auto
*
var
=
static_cast
<
VarHandle
*>
(
input
);
if
(
var
->
generated_op_
)
{
var
->
generated_op_
->
Wait
(
cpu_ctx
);
}
}
WaitInputVarGenerated
(
platform
::
CPUPlace
());
tensors_
.
resize
(
inputs_
.
size
());
auto
*
var_handle
=
static_cast
<
VarHandle
*>
(
inputs_
[
0
]);
auto
&
var_name
=
var_handle
->
name_
;
...
...
@@ -79,6 +73,15 @@ void FetchOpHandle::RunImpl() {
this
->
WaitAndMergeCPUTensors
();
}
void
FetchOpHandle
::
WaitInputVarGenerated
(
const
platform
::
Place
&
place
)
{
auto
cpu_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
for
(
auto
*
input
:
inputs_
)
{
if
(
input
->
generated_op_
)
{
input
->
generated_op_
->
RecordWaitEventOnCtx
(
cpu_ctx
);
}
}
}
std
::
string
FetchOpHandle
::
Name
()
const
{
return
"Fetch"
;
}
}
// namespace details
...
...
paddle/fluid/framework/details/fetch_op_handle.h
浏览文件 @
ce72c3ff
...
...
@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase {
~
FetchOpHandle
();
void
Wait
(
platform
::
DeviceContext
*
waited_dev
)
override
;
void
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
)
override
;
void
WaitAndMergeCPUTensors
()
const
;
...
...
@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
virtual
void
WaitInputVarGenerated
(
const
platform
::
Place
&
place
);
private:
FeedFetchList
*
data_
;
size_t
offset_
;
...
...
paddle/fluid/framework/details/gather_op_handle.cc
浏览文件 @
ce72c3ff
...
...
@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() {
"Currently, gather_op only can gather SelectedRows."
);
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated
(
in_var_handles
);
WaitInputVarGenerated
();
auto
&
pre_in_value
=
pre_in_var
->
Get
<
framework
::
SelectedRows
>
();
std
::
vector
<
int64_t
>
out_rows
;
...
...
@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() {
});
}
void
GatherOpHandle
::
WaitInputVarGenerated
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
)
{
for
(
auto
*
in
:
in_var_handles
)
{
if
(
in
->
generated_op_
)
{
for
(
auto
pair
:
dev_ctxes_
)
{
in
->
generated_op_
->
Wait
(
pair
.
second
);
}
}
}
}
std
::
string
GatherOpHandle
::
Name
()
const
{
return
"gather"
;
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/gather_op_handle.h
浏览文件 @
ce72c3ff
...
...
@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
void
WaitInputVarGenerated
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
);
private:
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
...
...
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
浏览文件 @
ce72c3ff
...
...
@@ -34,12 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
return
;
// No need to all reduce when GPU count = 1;
}
else
{
// Wait input done
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
if
(
in
->
generated_op_
)
{
in
->
generated_op_
->
Wait
(
dev_ctxes_
[
p
]);
}
}
WaitInputVarGenerated
();
auto
&
var_name
=
static_cast
<
VarHandle
*>
(
this
->
inputs_
[
0
])
->
name_
;
int
dtype
=
-
1
;
...
...
paddle/fluid/framework/details/op_handle_base.cc
浏览文件 @
ce72c3ff
...
...
@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) {
RunImpl
();
}
void
OpHandleBase
::
Wait
(
platform
::
DeviceContext
*
waited_dev
)
{
void
OpHandleBase
::
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
)
{
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_cpu_place
(
waited_
dev
->
GetPlace
())
||
events_
.
empty
())
{
if
(
platform
::
is_cpu_place
(
waited_
ctx
->
GetPlace
())
||
events_
.
empty
())
{
for
(
auto
&
dev_ctx
:
dev_ctxes_
)
{
dev_ctx
.
second
->
Wait
();
}
}
else
{
auto
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
waited_
dev
)
->
stream
();
static_cast
<
platform
::
CUDADeviceContext
*>
(
waited_
ctx
)
->
stream
();
for
(
auto
&
ev
:
events_
)
{
PADDLE_ENFORCE
(
cudaStreamWaitEvent
(
stream
,
ev
.
second
,
0
));
}
...
...
@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
out
->
generated_op_
=
this
;
}
void
OpHandleBase
::
WaitInputVarGenerated
()
{
for
(
auto
in_var
:
inputs_
)
{
if
(
NeedWait
(
in_var
))
{
for
(
auto
&
pair
:
dev_ctxes_
)
{
in_var
->
generated_op_
->
RecordWaitEventOnCtx
(
pair
.
second
);
}
}
}
}
void
OpHandleBase
::
WaitInputVarGenerated
(
const
platform
::
Place
&
place
)
{
for
(
auto
*
in
:
inputs_
)
{
if
(
NeedWait
(
in
))
{
in
->
generated_op_
->
RecordWaitEventOnCtx
(
dev_ctxes_
[
place
]);
}
}
}
bool
OpHandleBase
::
NeedWait
(
VarHandleBase
*
in_var
)
{
return
in_var
&&
in_var
->
generated_op_
;
}
void
OpHandleBase
::
RunAndRecordEvent
(
const
std
::
function
<
void
()
>
&
callback
)
{
#ifdef PADDLE_WITH_CUDA
if
(
!
events_
.
empty
())
{
// Use event
...
...
paddle/fluid/framework/details/op_handle_base.h
浏览文件 @
ce72c3ff
...
...
@@ -38,12 +38,24 @@ class OpHandleBase {
void
Run
(
bool
use_event
);
virtual
void
Wait
(
platform
::
DeviceContext
*
waited_dev
);
virtual
void
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
);
void
AddInput
(
VarHandleBase
*
in
);
void
AddOutput
(
VarHandleBase
*
out
);
// This method adds the wait events of all the input on all the device
// context.
// NODE: This Wait is asynchronous operation.
virtual
void
WaitInputVarGenerated
();
// This method adds the wait events of all the input on the specified device
// context.
// NODE: This Wait is asynchronous operation.
virtual
void
WaitInputVarGenerated
(
const
platform
::
Place
&
place
);
virtual
bool
NeedWait
(
VarHandleBase
*
in_var
);
// If the Op involves data transfer of multiple devices that
// will likely block other computations.
virtual
bool
IsMultiDeviceTransfer
()
{
return
false
;
}
...
...
paddle/fluid/framework/details/reduce_op_handle.cc
浏览文件 @
ce72c3ff
...
...
@@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() {
PADDLE_ENFORCE_NOT_NULL
(
pre_in_var
);
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated
(
in_var_handles
);
WaitInputVarGenerated
();
// NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std
::
vector
<
platform
::
Place
>
in_places
;
// used to get dev_ctx
...
...
@@ -80,19 +80,21 @@ void ReduceOpHandle::RunImpl() {
}
if
(
pre_in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
vector
<
const
SelectedRows
*>
in_selected_rows
=
GetInputValues
<
SelectedRows
>
(
in_var_handles
,
var_scopes
);
GatherSelectedRows
(
in_selected_rows
,
in_places
,
dev_ctxes_
,
t_out_p
,
out_var
->
GetMutable
<
framework
::
SelectedRows
>
());
this
->
RunAndRecordEvent
([
&
]
{
std
::
vector
<
const
SelectedRows
*>
in_selected_rows
=
GetInputValues
<
SelectedRows
>
(
in_var_handles
,
var_scopes
);
GatherSelectedRows
(
in_selected_rows
,
in_places
,
dev_ctxes_
,
t_out_p
,
out_var
->
GetMutable
<
framework
::
SelectedRows
>
());
});
}
else
{
std
::
vector
<
const
LoDTensor
*>
lod_tensors
=
GetInputValues
<
LoDTensor
>
(
in_var_handles
,
var_scopes
);
if
(
paddle
::
platform
::
is_cpu_place
(
lod_tensors
[
0
]
->
place
()))
{
ReduceLoDTensor
func
(
lod_tensors
,
out_var
->
GetMutable
<
framework
::
LoDTensor
>
());
VisitDataType
(
ToDataType
(
lod_tensors
[
0
]
->
type
()),
func
);
this
->
RunAndRecordEvent
([
&
]
{
ReduceLoDTensor
func
(
lod_tensors
,
out_var
->
GetMutable
<
framework
::
LoDTensor
>
());
VisitDataType
(
ToDataType
(
lod_tensors
[
0
]
->
type
()),
func
);
});
}
else
if
(
paddle
::
platform
::
is_gpu_place
(
lod_tensors
[
0
]
->
place
()))
{
#ifdef PADDLE_WITH_CUDA
auto
pre_in
=
pre_in_var
->
Get
<
framework
::
LoDTensor
>
();
...
...
@@ -157,17 +159,6 @@ std::vector<const T *> ReduceOpHandle::GetInputValues(
return
in_selected_rows
;
}
void
ReduceOpHandle
::
WaitInputVarGenerated
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
)
{
for
(
auto
*
in
:
in_var_handles
)
{
if
(
in
->
generated_op_
)
{
for
(
auto
pair
:
dev_ctxes_
)
{
in
->
generated_op_
->
Wait
(
pair
.
second
);
}
}
}
}
std
::
string
ReduceOpHandle
::
Name
()
const
{
return
"reduce"
;
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/reduce_op_handle.h
浏览文件 @
ce72c3ff
...
...
@@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
void
WaitInputVarGenerated
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
);
template
<
typename
T
>
std
::
vector
<
const
T
*>
GetInputValues
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
,
...
...
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
浏览文件 @
ce72c3ff
...
...
@@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
ScaleLossGradOpHandle
::~
ScaleLossGradOpHandle
()
{}
void
ScaleLossGradOpHandle
::
RunImpl
()
{
// Doesn't wait any event
std
::
string
var_name
=
static_cast
<
VarHandle
*>
(
this
->
outputs_
[
0
])
->
name_
;
auto
&
local_scope
=
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
...
...
paddle/fluid/framework/details/send_op_handle.cc
浏览文件 @
ce72c3ff
...
...
@@ -26,6 +26,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
place_
(
place
)
{}
void
SendOpHandle
::
RunImpl
()
{
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
...
...
@@ -33,7 +34,7 @@ void SendOpHandle::RunImpl() {
continue
;
}
if
(
in
->
generated_op_
)
{
in
->
generated_op_
->
Wait
(
dev_ctxes_
[
p
]);
in
->
generated_op_
->
RecordWaitEventOnCtx
(
dev_ctxes_
[
p
]);
}
}
auto
&
tmp_scope
=
local_scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
ce72c3ff
...
...
@@ -14,8 +14,6 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
...
...
@@ -45,73 +43,33 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Should revisit it if overlapping is available.
std
::
unordered_set
<
OpHandleBase
*>
delayed_ops
;
auto
InsertPendingVar
=
[
&
pending_vars
,
&
ready_vars
](
VarHandleBase
&
var
)
{
pending_vars
.
insert
(
&
var
);
if
(
var
.
generated_op_
==
nullptr
)
{
ready_vars
.
Push
(
&
var
);
}
};
auto
InsertPendingOp
=
[
&
pending_ops
](
OpHandleBase
&
op_instance
)
{
pending_ops
.
insert
({
&
op_instance
,
op_instance
.
Inputs
().
size
()});
};
// Transform SSAGraph to pending_ops & pending_vars
for
(
auto
&
var_map
:
graph_
->
vars_
)
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
InsertPendingVar
(
*
version_pair
);
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
version_pair
.
get
()
);
}
}
}
for
(
auto
&
var
:
graph_
->
dep_vars_
)
{
InsertPendingVar
(
*
var
);
InsertPendingVar
(
&
pending_vars
,
&
ready_vars
,
var
.
get
()
);
}
for
(
auto
&
op
:
graph_
->
ops_
)
{
if
(
op
->
Inputs
().
empty
())
{
// Special case, Op has no input.
ready_ops
.
insert
(
op
.
get
());
}
else
{
InsertPendingOp
(
*
op
);
InsertPendingOp
(
&
pending_ops
,
op
.
get
()
);
}
}
// Step 2. Insert FetchOps
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>
fetch_ops
;
FeedFetchList
fetch_data
(
fetch_tensors
.
size
());
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
vars_
)
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
}
}
}
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
fetch_dependencies
;
for
(
size_t
i
=
0
;
i
<
fetch_tensors
.
size
();
++
i
)
{
auto
&
var_name
=
fetch_tensors
[
i
];
auto
&
vars
=
fetched_vars
.
at
(
var_name
);
auto
*
op
=
new
FetchOpHandle
(
&
fetch_data
,
i
,
&
local_scopes_
);
fetch_ops
.
emplace_back
(
op
);
for
(
auto
&
p
:
places_
)
{
op
->
SetDeviceContext
(
p
,
fetch_ctxs_
.
Get
(
p
));
}
for
(
auto
*
var
:
vars
)
{
op
->
AddInput
(
var
);
}
FeedFetchList
fetch_data
(
fetch_tensors
.
size
());
auto
*
fetch_dummy
=
new
DummyVarHandle
();
op
->
AddOutput
(
fetch_dummy
);
fetch_dependencies
.
emplace
(
fetch_dummy
);
InsertPendingVar
(
*
fetch_dummy
);
InsertPendingOp
(
*
op
);
}
InsertFetchOps
(
fetch_tensors
,
&
fetch_ops
,
&
fetch_dependencies
,
&
pending_ops
,
&
pending_vars
,
&
ready_vars
,
&
fetch_data
);
auto
run_all_ops
=
[
&
](
std
::
unordered_set
<
OpHandleBase
*>
&
set
)
{
for
(
auto
*
op
:
set
)
{
...
...
@@ -174,6 +132,60 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
return
fetch_data
;
}
void
ThreadedSSAGraphExecutor
::
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>
*
fetch_ops
,
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
*
fetch_dependencies
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
FeedFetchList
*
fetch_data
)
{
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
for
(
auto
&
fetch_var_name
:
fetch_tensors
)
{
for
(
auto
&
var_map
:
graph_
->
vars_
)
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
());
}
}
}
for
(
size_t
i
=
0
;
i
<
fetch_tensors
.
size
();
++
i
)
{
auto
&
var_name
=
fetch_tensors
[
i
];
auto
&
vars
=
fetched_vars
.
at
(
var_name
);
auto
*
op
=
new
FetchOpHandle
(
fetch_data
,
i
,
&
local_scopes_
);
fetch_ops
->
emplace_back
(
op
);
for
(
auto
&
p
:
places_
)
{
op
->
SetDeviceContext
(
p
,
fetch_ctxs_
.
Get
(
p
));
}
for
(
auto
*
var
:
vars
)
{
op
->
AddInput
(
var
);
}
auto
*
fetch_dummy
=
new
DummyVarHandle
();
op
->
AddOutput
(
fetch_dummy
);
fetch_dependencies
->
emplace
(
fetch_dummy
);
this
->
InsertPendingVar
(
pending_vars
,
ready_vars
,
fetch_dummy
);
this
->
InsertPendingOp
(
pending_ops
,
op
);
}
}
void
ThreadedSSAGraphExecutor
::
InsertPendingOp
(
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
OpHandleBase
*
op_instance
)
const
{
pending_ops
->
insert
({
op_instance
,
op_instance
->
Inputs
().
size
()});
}
void
ThreadedSSAGraphExecutor
::
InsertPendingVar
(
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
VarHandleBase
*
var
)
const
{
pending_vars
->
insert
(
var
);
if
(
var
->
generated_op_
==
nullptr
)
{
ready_vars
->
Push
(
var
);
}
}
void
ThreadedSSAGraphExecutor
::
RunOp
(
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
浏览文件 @
ce72c3ff
...
...
@@ -23,6 +23,7 @@
#include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
namespace
paddle
{
...
...
@@ -58,6 +59,21 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std
::
unique_ptr
<
platform
::
EnforceNotMet
>
exception_
;
std
::
atomic
<
int
>
running_ops_
;
bool
allow_op_delay_
;
void
InsertPendingOp
(
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
OpHandleBase
*
op_instance
)
const
;
void
InsertPendingVar
(
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
VarHandleBase
*
var
)
const
;
void
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
std
::
vector
<
std
::
unique_ptr
<
FetchOpHandle
>>
*
fetch_ops
,
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
*
fetch_dependencies
,
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
*
pending_ops
,
std
::
unordered_set
<
VarHandleBase
*>
*
pending_vars
,
BlockingQueue
<
VarHandleBase
*>
*
ready_vars
,
FeedFetchList
*
fetch_data
);
};
}
// namespace details
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录