Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
925c17ab
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
925c17ab
编写于
4月 13, 2018
作者:
Y
Yu Yang
提交者:
GitHub
4月 13, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9895 from reyoung/feature/fix_transformer_hang
Fix Transformer Hang Problem
上级
51c219c9
6b20b355
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
54 addition
and
25 deletion
+54
-25
paddle/fluid/framework/details/computation_op_handle.cc
paddle/fluid/framework/details/computation_op_handle.cc
+3
-1
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+6
-4
paddle/fluid/framework/details/op_handle_base.cc
paddle/fluid/framework/details/op_handle_base.cc
+21
-11
paddle/fluid/framework/details/op_handle_base.h
paddle/fluid/framework/details/op_handle_base.h
+2
-0
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
+9
-5
paddle/fluid/framework/details/send_op_handle.cc
paddle/fluid/framework/details/send_op_handle.cc
+1
-1
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+3
-1
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+1
-1
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+8
-1
未找到文件。
paddle/fluid/framework/details/computation_op_handle.cc
浏览文件 @
925c17ab
...
@@ -35,7 +35,9 @@ void ComputationOpHandle::RunImpl() {
...
@@ -35,7 +35,9 @@ void ComputationOpHandle::RunImpl() {
}
}
}
}
op_
->
Run
(
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
(),
place_
);
this
->
RunAndRecordEvent
([
this
]
{
op_
->
Run
(
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
(),
place_
);
});
}
}
std
::
string
ComputationOpHandle
::
Name
()
const
{
return
op_
->
Type
();
}
std
::
string
ComputationOpHandle
::
Name
()
const
{
return
op_
->
Type
();
}
...
...
paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
浏览文件 @
925c17ab
...
@@ -69,10 +69,12 @@ void NCCLAllReduceOpHandle::RunImpl() {
...
@@ -69,10 +69,12 @@ void NCCLAllReduceOpHandle::RunImpl() {
});
});
}
}
platform
::
NCCLGroupGuard
guard
;
this
->
RunAndRecordEvent
([
&
]
{
for
(
auto
&
call
:
all_reduce_calls
)
{
platform
::
NCCLGroupGuard
guard
;
call
();
for
(
auto
&
call
:
all_reduce_calls
)
{
}
call
();
}
});
}
}
}
}
...
...
paddle/fluid/framework/details/op_handle_base.cc
浏览文件 @
925c17ab
...
@@ -54,17 +54,6 @@ void OpHandleBase::Run(bool use_event) {
...
@@ -54,17 +54,6 @@ void OpHandleBase::Run(bool use_event) {
#endif
#endif
RunImpl
();
RunImpl
();
#ifdef PADDLE_WITH_CUDA
if
(
use_event
)
{
for
(
auto
&
p
:
dev_ctxes_
)
{
int
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
.
first
).
device
;
auto
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
p
.
second
)
->
stream
();
PADDLE_ENFORCE
(
cudaEventRecord
(
events_
.
at
(
dev_id
),
stream
));
}
}
#endif
}
}
void
OpHandleBase
::
Wait
(
platform
::
DeviceContext
*
waited_dev
)
{
void
OpHandleBase
::
Wait
(
platform
::
DeviceContext
*
waited_dev
)
{
...
@@ -97,6 +86,27 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
...
@@ -97,6 +86,27 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
out
->
generated_op_
=
this
;
out
->
generated_op_
=
this
;
}
}
void
OpHandleBase
::
RunAndRecordEvent
(
const
std
::
function
<
void
()
>
&
callback
)
{
#ifdef PADDLE_WITH_CUDA
if
(
!
events_
.
empty
())
{
// Use event
std
::
function
<
void
()
>
method
=
callback
;
for
(
auto
&
p
:
dev_ctxes_
)
{
method
=
[
method
,
p
,
this
]()
{
static_cast
<
platform
::
CUDADeviceContext
*>
(
p
.
second
)
->
RecordEvent
(
events_
.
at
(
boost
::
get
<
platform
::
CUDAPlace
>
(
p
.
first
).
device
),
method
);
};
}
method
();
}
else
{
#endif
callback
();
#ifdef PADDLE_WITH_CUDA
}
#endif
}
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/op_handle_base.h
浏览文件 @
925c17ab
...
@@ -62,6 +62,8 @@ class OpHandleBase {
...
@@ -62,6 +62,8 @@ class OpHandleBase {
virtual
bool
IsMultiDeviceTransfer
()
{
return
false
;
}
virtual
bool
IsMultiDeviceTransfer
()
{
return
false
;
}
protected:
protected:
void
RunAndRecordEvent
(
const
std
::
function
<
void
()
>
&
callback
);
virtual
void
RunImpl
()
=
0
;
virtual
void
RunImpl
()
=
0
;
};
};
...
...
paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
浏览文件 @
925c17ab
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include <string>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
...
@@ -37,11 +39,13 @@ void ScaleLossGradOpHandle::RunImpl() {
...
@@ -37,11 +39,13 @@ void ScaleLossGradOpHandle::RunImpl() {
*
tmp
=
coeff_
;
*
tmp
=
coeff_
;
}
else
{
}
else
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
auto
stream
=
this
->
RunAndRecordEvent
([
&
]
{
static_cast
<
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctxes_
[
place_
])
auto
stream
=
->
stream
();
static_cast
<
platform
::
CUDADeviceContext
*>
(
this
->
dev_ctxes_
[
place_
])
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
tmp
,
->
stream
();
platform
::
CPUPlace
(),
&
coeff_
,
sizeof
(
float
),
stream
);
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place_
),
tmp
,
platform
::
CPUPlace
(),
&
coeff_
,
sizeof
(
float
),
stream
);
});
#endif
#endif
}
}
}
}
...
...
paddle/fluid/framework/details/send_op_handle.cc
浏览文件 @
925c17ab
...
@@ -34,7 +34,7 @@ void SendOpHandle::RunImpl() {
...
@@ -34,7 +34,7 @@ void SendOpHandle::RunImpl() {
}
}
in
->
generated_op_
->
Wait
(
dev_ctxes_
[
p
]);
in
->
generated_op_
->
Wait
(
dev_ctxes_
[
p
]);
}
}
op_
->
Run
(
*
local_scope_
,
place_
);
this
->
RunAndRecordEvent
([
&
]
{
op_
->
Run
(
*
local_scope_
,
place_
);
}
);
}
}
std
::
string
SendOpHandle
::
Name
()
const
{
return
"send"
;
}
std
::
string
SendOpHandle
::
Name
()
const
{
return
"send"
;
}
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
925c17ab
...
@@ -196,10 +196,12 @@ void ThreadedSSAGraphExecutor::RunOp(
...
@@ -196,10 +196,12 @@ void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
BlockingQueue
<
VarHandleBase
*>
*
ready_var_q
,
details
::
OpHandleBase
*
op
)
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
auto
op_run
=
[
ready_var_q
,
op
,
this
]
{
try
{
try
{
VLOG
(
10
)
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" : "
<<
op
->
DebugString
();
op
->
Run
(
use_event_
);
op
->
Run
(
use_event_
);
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
" Done "
;
running_ops_
--
;
running_ops_
--
;
ready_var_q
->
Extend
(
op
->
outputs_
);
ready_var_q
->
Extend
(
op
->
outputs_
);
VLOG
(
10
)
<<
op
<<
" "
<<
op
->
Name
()
<<
"Signal posted"
;
}
catch
(
platform
::
EnforceNotMet
ex
)
{
}
catch
(
platform
::
EnforceNotMet
ex
)
{
exception_
.
reset
(
new
platform
::
EnforceNotMet
(
ex
));
exception_
.
reset
(
new
platform
::
EnforceNotMet
(
ex
));
}
catch
(...)
{
}
catch
(...)
{
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
925c17ab
...
@@ -175,7 +175,7 @@ CUDADeviceContext::~CUDADeviceContext() {
...
@@ -175,7 +175,7 @@ CUDADeviceContext::~CUDADeviceContext() {
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
void
CUDADeviceContext
::
Wait
()
const
{
void
CUDADeviceContext
::
Wait
()
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mutex_
);
std
::
lock_guard
<
std
::
recursive_
mutex
>
guard
(
mutex_
);
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaGetLastError
());
PADDLE_ENFORCE
(
cudaGetLastError
());
}
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
925c17ab
...
@@ -98,13 +98,20 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -98,13 +98,20 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cuda stream in the device context. */
/*! \brief Return cuda stream in the device context. */
cudaStream_t
stream
()
const
;
cudaStream_t
stream
()
const
;
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
std
::
lock_guard
<
std
::
recursive_mutex
>
guard
(
mutex_
);
callback
();
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
}
private:
private:
CUDAPlace
place_
;
CUDAPlace
place_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
mutable
std
::
mutex
mutex_
;
mutable
std
::
recursive_
mutex
mutex_
;
cudaStream_t
stream_
;
cudaStream_t
stream_
;
cudnnHandle_t
cudnn_handle_
;
cudnnHandle_t
cudnn_handle_
;
cublasHandle_t
cublas_handle_
;
cublasHandle_t
cublas_handle_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录