Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9eec2c75
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9eec2c75
编写于
5月 09, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine pe
上级
f4851f14
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
65 addition
and
70 deletion
+65
-70
paddle/fluid/framework/details/broadcast_op_handle.cc
paddle/fluid/framework/details/broadcast_op_handle.cc
+1
-11
paddle/fluid/framework/details/broadcast_op_handle.h
paddle/fluid/framework/details/broadcast_op_handle.h
+0
-1
paddle/fluid/framework/details/computation_op_handle.cc
paddle/fluid/framework/details/computation_op_handle.cc
+8
-8
paddle/fluid/framework/details/computation_op_handle.h
paddle/fluid/framework/details/computation_op_handle.h
+2
-0
paddle/fluid/framework/details/fetch_op_handle.cc
paddle/fluid/framework/details/fetch_op_handle.cc
+12
-7
paddle/fluid/framework/details/fetch_op_handle.h
paddle/fluid/framework/details/fetch_op_handle.h
+3
-1
paddle/fluid/framework/details/gather_op_handle.cc
paddle/fluid/framework/details/gather_op_handle.cc
+1
-12
paddle/fluid/framework/details/gather_op_handle.h
paddle/fluid/framework/details/gather_op_handle.h
+0
-1
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+1
-4
paddle/fluid/framework/details/op_handle_base.cc
paddle/fluid/framework/details/op_handle_base.cc
+25
-3
paddle/fluid/framework/details/op_handle_base.h
paddle/fluid/framework/details/op_handle_base.h
+9
-1
paddle/fluid/framework/details/reduce_op_handle.cc
paddle/fluid/framework/details/reduce_op_handle.cc
+1
-12
paddle/fluid/framework/details/reduce_op_handle.h
paddle/fluid/framework/details/reduce_op_handle.h
+0
-2
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
+1
-0
paddle/fluid/framework/details/send_op_handle.cc
paddle/fluid/framework/details/send_op_handle.cc
+1
-7
未找到文件。
paddle/fluid/framework/details/broadcast_op_handle.cc
浏览文件 @
9eec2c75
...
...
@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() {
out_var_handles
.
size
(),
places_
.
size
(),
"The number of output should equal to the number of places."
);
// Wait input done, this Wait is asynchronous operation platform::Place
// &in_place;
WaitInputVarGenerated
(
*
in_var_handle
);
WaitInputVarGenerated
();
std
::
vector
<
const
Scope
*>
var_scopes
;
for
(
auto
*
s
:
local_scopes_
)
{
...
...
@@ -147,14 +145,6 @@ void BroadcastOpHandle::RunImpl() {
}
}
void
BroadcastOpHandle
::
WaitInputVarGenerated
(
const
VarHandle
&
in_var
)
{
if
(
in_var
.
generated_op_
)
{
for
(
auto
&
pair
:
dev_ctxes_
)
{
in_var
.
generated_op_
->
Wait
(
pair
.
second
);
}
}
}
std
::
string
BroadcastOpHandle
::
Name
()
const
{
return
"broadcast"
;
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/broadcast_op_handle.h
浏览文件 @
9eec2c75
...
...
@@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
void
WaitInputVarGenerated
(
const
VarHandle
&
in_var
);
private:
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
...
...
paddle/fluid/framework/details/computation_op_handle.cc
浏览文件 @
9eec2c75
...
...
@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
place_
(
place
)
{}
void
ComputationOpHandle
::
RunImpl
()
{
auto
*
cur_ctx
=
dev_ctxes_
[
place_
];
for
(
auto
*
in
:
inputs_
)
{
bool
need_wait
=
in
->
generated_op_
&&
in
->
generated_op_
->
DeviceContext
(
place_
)
!=
cur_ctx
;
if
(
need_wait
)
{
in
->
generated_op_
->
Wait
(
cur_ctx
);
}
}
WaitInputVarGenerated
(
place_
);
this
->
RunAndRecordEvent
([
this
]
{
op_
->
Run
(
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
(),
place_
);
});
}
bool
ComputationOpHandle
::
NeedWait
(
VarHandleBase
*
in_var
)
{
bool
need_wait
=
dynamic_cast
<
VarHandle
*>
(
in_var
)
&&
in_var
->
generated_op_
&&
in_var
->
generated_op_
->
DeviceContext
(
place_
)
!=
dev_ctxes_
[
place_
];
return
need_wait
;
}
std
::
string
ComputationOpHandle
::
Name
()
const
{
return
op_
->
Type
();
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/computation_op_handle.h
浏览文件 @
9eec2c75
...
...
@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
virtual
bool
NeedWait
(
VarHandleBase
*
in_var
);
private:
std
::
unique_ptr
<
OperatorBase
>
op_
;
Scope
*
scope_
;
...
...
paddle/fluid/framework/details/fetch_op_handle.cc
浏览文件 @
9eec2c75
...
...
@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() {
}
}
void
FetchOpHandle
::
Wait
(
platform
::
DeviceContext
*
waited_dev
)
{
void
FetchOpHandle
::
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
)
{
PADDLE_THROW
(
"Nobody should wait FetchOp. Unexpceted Error"
);
}
...
...
@@ -45,12 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
}
void
FetchOpHandle
::
RunImpl
()
{
auto
cpu_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
());
for
(
auto
*
input
:
inputs_
)
{
auto
*
var
=
static_cast
<
VarHandle
*>
(
input
);
var
->
generated_op_
->
Wait
(
cpu_ctx
);
}
WaitInputVarGenerated
(
platform
::
CPUPlace
());
tensors_
.
resize
(
inputs_
.
size
());
auto
*
var_handle
=
static_cast
<
VarHandle
*>
(
inputs_
[
0
]);
auto
&
var_name
=
var_handle
->
name_
;
...
...
@@ -77,6 +73,15 @@ void FetchOpHandle::RunImpl() {
this
->
WaitAndMergeCPUTensors
();
}
void
FetchOpHandle
::
WaitInputVarGenerated
(
const
platform
::
Place
&
place
)
{
auto
cpu_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
for
(
auto
*
input
:
inputs_
)
{
if
(
input
->
generated_op_
)
{
input
->
generated_op_
->
RecordWaitEventOnCtx
(
cpu_ctx
);
}
}
}
std
::
string
FetchOpHandle
::
Name
()
const
{
return
"Fetch"
;
}
}
// namespace details
...
...
paddle/fluid/framework/details/fetch_op_handle.h
浏览文件 @
9eec2c75
...
...
@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase {
~
FetchOpHandle
();
void
Wait
(
platform
::
DeviceContext
*
waited_dev
)
override
;
void
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
)
override
;
void
WaitAndMergeCPUTensors
()
const
;
...
...
@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
virtual
void
WaitInputVarGenerated
(
const
platform
::
Place
&
place
);
private:
FeedFetchList
*
data_
;
size_t
offset_
;
...
...
paddle/fluid/framework/details/gather_op_handle.cc
浏览文件 @
9eec2c75
...
...
@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() {
"Currently, gather_op only can gather SelectedRows."
);
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated
(
in_var_handles
);
WaitInputVarGenerated
();
auto
&
pre_in_value
=
pre_in_var
->
Get
<
framework
::
SelectedRows
>
();
std
::
vector
<
int64_t
>
out_rows
;
...
...
@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() {
});
}
void
GatherOpHandle
::
WaitInputVarGenerated
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
)
{
for
(
auto
*
in
:
in_var_handles
)
{
if
(
in
->
generated_op_
)
{
for
(
auto
pair
:
dev_ctxes_
)
{
in
->
generated_op_
->
Wait
(
pair
.
second
);
}
}
}
}
std
::
string
GatherOpHandle
::
Name
()
const
{
return
"gather"
;
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/gather_op_handle.h
浏览文件 @
9eec2c75
...
...
@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
void
WaitInputVarGenerated
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
);
private:
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
...
...
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
浏览文件 @
9eec2c75
...
...
@@ -34,10 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
return
;
// No need to all reduce when GPU count = 1;
}
else
{
// Wait input done
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
in
->
generated_op_
->
Wait
(
dev_ctxes_
[
p
]);
}
WaitInputVarGenerated
();
auto
&
var_name
=
static_cast
<
VarHandle
*>
(
this
->
inputs_
[
0
])
->
name_
;
int
dtype
=
-
1
;
...
...
paddle/fluid/framework/details/op_handle_base.cc
浏览文件 @
9eec2c75
...
...
@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) {
RunImpl
();
}
void
OpHandleBase
::
Wait
(
platform
::
DeviceContext
*
waited_dev
)
{
void
OpHandleBase
::
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
)
{
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_cpu_place
(
waited_
dev
->
GetPlace
())
||
events_
.
empty
())
{
if
(
platform
::
is_cpu_place
(
waited_
ctx
->
GetPlace
())
||
events_
.
empty
())
{
for
(
auto
&
dev_ctx
:
dev_ctxes_
)
{
dev_ctx
.
second
->
Wait
();
}
}
else
{
auto
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
waited_
dev
)
->
stream
();
static_cast
<
platform
::
CUDADeviceContext
*>
(
waited_
ctx
)
->
stream
();
for
(
auto
&
ev
:
events_
)
{
PADDLE_ENFORCE
(
cudaStreamWaitEvent
(
stream
,
ev
.
second
,
0
));
}
...
...
@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
out
->
generated_op_
=
this
;
}
void
OpHandleBase
::
WaitInputVarGenerated
()
{
for
(
auto
in_var
:
inputs_
)
{
if
(
NeedWait
(
in_var
))
{
for
(
auto
&
pair
:
dev_ctxes_
)
{
in_var
->
generated_op_
->
RecordWaitEventOnCtx
(
pair
.
second
);
}
}
}
}
void
OpHandleBase
::
WaitInputVarGenerated
(
const
platform
::
Place
&
place
)
{
for
(
auto
*
in
:
inputs_
)
{
if
(
NeedWait
(
in
))
{
in
->
generated_op_
->
RecordWaitEventOnCtx
(
dev_ctxes_
[
place
]);
}
}
}
bool
OpHandleBase
::
NeedWait
(
VarHandleBase
*
in_var
)
{
return
dynamic_cast
<
VarHandle
*>
(
in_var
)
&&
in_var
->
generated_op_
;
}
void
OpHandleBase
::
RunAndRecordEvent
(
const
std
::
function
<
void
()
>
&
callback
)
{
#ifdef PADDLE_WITH_CUDA
if
(
!
events_
.
empty
())
{
// Use event
...
...
paddle/fluid/framework/details/op_handle_base.h
浏览文件 @
9eec2c75
...
...
@@ -38,12 +38,20 @@ class OpHandleBase {
void
Run
(
bool
use_event
);
virtual
void
Wait
(
platform
::
DeviceContext
*
waited_dev
);
virtual
void
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
);
void
AddInput
(
VarHandleBase
*
in
);
void
AddOutput
(
VarHandleBase
*
out
);
// Wait inputs are generated, this Wait is asynchronous operation.
virtual
void
WaitInputVarGenerated
();
// Wait inputs are generated, this Wait is asynchronous operation.
virtual
void
WaitInputVarGenerated
(
const
platform
::
Place
&
place
);
virtual
bool
NeedWait
(
VarHandleBase
*
in_var
);
// If the Op involves data transfer of multiple devices that
// will likely block other computations.
virtual
bool
IsMultiDeviceTransfer
()
{
return
false
;
}
...
...
paddle/fluid/framework/details/reduce_op_handle.cc
浏览文件 @
9eec2c75
...
...
@@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() {
PADDLE_ENFORCE_NOT_NULL
(
pre_in_var
);
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated
(
in_var_handles
);
WaitInputVarGenerated
();
// NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std
::
vector
<
platform
::
Place
>
in_places
;
// used to get dev_ctx
...
...
@@ -157,17 +157,6 @@ std::vector<const T *> ReduceOpHandle::GetInputValues(
return
in_selected_rows
;
}
void
ReduceOpHandle
::
WaitInputVarGenerated
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
)
{
for
(
auto
*
in
:
in_var_handles
)
{
if
(
in
->
generated_op_
)
{
for
(
auto
pair
:
dev_ctxes_
)
{
in
->
generated_op_
->
Wait
(
pair
.
second
);
}
}
}
}
std
::
string
ReduceOpHandle
::
Name
()
const
{
return
"reduce"
;
}
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/reduce_op_handle.h
浏览文件 @
9eec2c75
...
...
@@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase {
protected:
void
RunImpl
()
override
;
void
WaitInputVarGenerated
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
);
template
<
typename
T
>
std
::
vector
<
const
T
*>
GetInputValues
(
const
std
::
vector
<
VarHandle
*>
&
in_var_handles
,
...
...
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
浏览文件 @
9eec2c75
...
...
@@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
ScaleLossGradOpHandle
::~
ScaleLossGradOpHandle
()
{}
void
ScaleLossGradOpHandle
::
RunImpl
()
{
// Doesn't wait any event
std
::
string
var_name
=
static_cast
<
VarHandle
*>
(
this
->
outputs_
[
0
])
->
name_
;
auto
&
local_scope
=
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
...
...
paddle/fluid/framework/details/send_op_handle.cc
浏览文件 @
9eec2c75
...
...
@@ -27,13 +27,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
void
SendOpHandle
::
RunImpl
()
{
// Wait input done
for
(
auto
*
in
:
inputs_
)
{
auto
&
p
=
static_cast
<
VarHandle
*>
(
in
)
->
place_
;
if
(
in
->
DebugString
()
==
"dummy"
)
{
// HACK
continue
;
}
in
->
generated_op_
->
Wait
(
dev_ctxes_
[
p
]);
}
WaitInputVarGenerated
();
auto
&
tmp_scope
=
local_scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
// lock.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录