Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
53c73c77
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
53c73c77
编写于
3月 15, 2023
作者:
P
pangyoki
提交者:
GitHub
3月 15, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix cuda graph (#51648)
上级
4283e19e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
71 addition
and
56 deletion
+71
-56
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+10
-2
paddle/fluid/platform/cuda_graph_with_memory_pool.cc
paddle/fluid/platform/cuda_graph_with_memory_pool.cc
+61
-44
paddle/phi/backends/gpu/cuda/cuda_graph.h
paddle/phi/backends/gpu/cuda/cuda_graph.h
+0
-10
未找到文件。
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
53c73c77
...
@@ -545,7 +545,7 @@ void InterpreterCore::PrepareForCUDAGraphCapture() {
...
@@ -545,7 +545,7 @@ void InterpreterCore::PrepareForCUDAGraphCapture() {
platform
::
IsCUDAGraphCapturing
(),
platform
::
IsCUDAGraphCapturing
(),
false
,
false
,
platform
::
errors
::
PermissionDenied
(
"CUDA Graph is not allowed to capture "
platform
::
errors
::
PermissionDenied
(
"CUDA Graph is not allowed to capture "
"
when running the first batch
."
));
"
before prepare
."
));
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
place_
),
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
place_
),
true
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -684,8 +684,16 @@ void InterpreterCore::Convert(
...
@@ -684,8 +684,16 @@ void InterpreterCore::Convert(
if
(
op_type
==
interpreter
::
kMemcpyD2H
||
if
(
op_type
==
interpreter
::
kMemcpyD2H
||
op_type
==
interpreter
::
kMemcpyH2D
)
{
op_type
==
interpreter
::
kMemcpyH2D
)
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
"
op_type can't be memcpy d2h or h2
d while using cuda graph."
));
"
Cuda memory copy d2h/h2d is not allowe
d while using cuda graph."
));
}
}
PADDLE_ENFORCE_EQ
(
typeid
(
*
dev_ctx_
)
==
typeid
(
phi
::
GPUContext
),
true
,
platform
::
errors
::
InvalidArgument
(
"Device context of op %s must be [%s] while using "
"cuda graph, but got [%s]."
,
op_type
,
typeid
(
phi
::
GPUContext
).
name
(),
typeid
(
*
dev_ctx_
).
name
()));
// cuda graph needs to record all stream
// cuda graph needs to record all stream
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
()
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
()
.
RecordCapturingDeviceContext
(
dev_ctx_
);
.
RecordCapturingDeviceContext
(
dev_ctx_
);
...
...
paddle/fluid/platform/cuda_graph_with_memory_pool.cc
浏览文件 @
53c73c77
...
@@ -40,32 +40,58 @@ void InitCUDNNRelatedHandle(phi::GPUContext* dev_ctx) {
...
@@ -40,32 +40,58 @@ void InitCUDNNRelatedHandle(phi::GPUContext* dev_ctx) {
dev_ctx
->
cusolver_dn_handle
();
dev_ctx
->
cusolver_dn_handle
();
}
}
phi
::
DeviceContext
*
SelectCUDAGraphDeviceContext
(
phi
::
GPUPlace
place
,
int64_t
*
pool_id
)
{
phi
::
DeviceContext
*
mutable_dev_ctx
;
auto
all_capturing_dev_ctxs
=
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
()
.
GetAllCapturingDeviceContexts
();
auto
num_stream
=
all_capturing_dev_ctxs
.
size
();
if
(
num_stream
>
0
)
{
// Capturing device contexts will only be recorded in new
// executor in temporary, that is,
// FLAGS_new_executor_use_cuda_graph needs to be set to True.
// This restriction can be removed if device context is
// recorded in other modes.
// Record method: RecordCapturingDeviceContext.
PADDLE_ENFORCE_EQ
(
FLAGS_new_executor_use_cuda_graph
,
true
,
platform
::
errors
::
InvalidArgument
(
"FLAGS_new_executor_use_cuda_graph must be True when "
"capturing stream is recorded."
));
if
(
num_stream
>
1
)
{
VLOG
(
4
)
<<
"Use a new stream to capture cuda graph. Used in multi-stream "
"scenarios with new executor."
;
if
(
*
pool_id
<=
CUDAGraph
::
kInvalidPoolID
)
{
*
pool_id
=
CUDAGraph
::
UniqueMemoryPoolID
();
}
mutable_dev_ctx
=
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
().
Get
(
*
pool_id
,
place
,
0
);
}
else
if
(
num_stream
==
1
)
{
VLOG
(
4
)
<<
"Use recorded stream to capture cuda graph. Used in "
"single-stream scenarios with new executor."
;
mutable_dev_ctx
=
*
(
all_capturing_dev_ctxs
.
begin
());
}
}
else
{
VLOG
(
4
)
<<
"Use default stream to capture cuda graph."
;
mutable_dev_ctx
=
phi
::
DeviceContextPool
::
Instance
().
Get
(
place
);
}
return
mutable_dev_ctx
;
}
void
BeginCUDAGraphCapture
(
phi
::
GPUPlace
place
,
void
BeginCUDAGraphCapture
(
phi
::
GPUPlace
place
,
cudaStreamCaptureMode
mode
,
cudaStreamCaptureMode
mode
,
int64_t
pool_id
)
{
int64_t
pool_id
)
{
auto
*
mutable_dev_ctx
=
phi
::
DeviceContextPool
::
Instance
().
Get
(
place
);
auto
*
mutable_dev_ctx
=
SelectCUDAGraphDeviceContext
(
place
,
&
pool_id
);
auto
*
dev_ctx
=
reinterpret_cast
<
phi
::
GPUContext
*>
(
mutable_dev_ctx
);
InitCUDNNRelatedHandle
(
dev_ctx
);
auto
all_capturing_dev_ctxs
=
auto
all_capturing_dev_ctxs
=
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
()
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
()
.
GetAllCapturingDeviceContexts
();
.
GetAllCapturingDeviceContexts
();
// create_cuda_graph_stream: Whether to create a new stream to
auto
num_stream
=
all_capturing_dev_ctxs
.
size
();
// capture cuda graph, usually used in multi-stream scenarios.
if
(
num_stream
>
1
)
{
// Can only be used for new executor in static mode, that is,
// FLAGS_new_executor_use_cuda_graph needs to be set to True.
bool
create_cuda_graph_stream
=
false
;
if
(
FLAGS_new_executor_use_cuda_graph
&&
(
all_capturing_dev_ctxs
.
size
()
>
1
||
(
all_capturing_dev_ctxs
.
size
()
==
1
&&
(
*
(
all_capturing_dev_ctxs
.
begin
())
!=
mutable_dev_ctx
))))
{
create_cuda_graph_stream
=
true
;
}
if
(
create_cuda_graph_stream
)
{
VLOG
(
4
)
<<
"create a new stream to capture cuda graph."
;
if
(
pool_id
<=
CUDAGraph
::
kInvalidPoolID
)
{
pool_id
=
CUDAGraph
::
UniqueMemoryPoolID
();
}
mutable_dev_ctx
=
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
().
Get
(
pool_id
,
place
,
0
);
for
(
auto
iter
=
all_capturing_dev_ctxs
.
begin
();
for
(
auto
iter
=
all_capturing_dev_ctxs
.
begin
();
iter
!=
all_capturing_dev_ctxs
.
end
();
iter
!=
all_capturing_dev_ctxs
.
end
();
++
iter
)
{
++
iter
)
{
...
@@ -73,12 +99,9 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
...
@@ -73,12 +99,9 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
InitCUDNNRelatedHandle
(
capturing_dev_ctx
);
InitCUDNNRelatedHandle
(
capturing_dev_ctx
);
}
}
}
}
auto
*
dev_ctx
=
reinterpret_cast
<
phi
::
GPUContext
*>
(
mutable_dev_ctx
);
InitCUDNNRelatedHandle
(
dev_ctx
);
auto
stream
=
dev_ctx
->
stream
();
auto
stream
=
dev_ctx
->
stream
();
CUDAGraph
::
BeginCapture
(
place
,
stream
,
mode
);
CUDAGraph
::
BeginCapture
(
place
,
stream
,
mode
);
CUDAGraph
::
SetIsCUDAGraphStreamCreated
(
create_cuda_graph_stream
);
// When using cuda graph in new executor, fast GC must be used.
// When using cuda graph in new executor, fast GC must be used.
// FLAGS_use_stream_safe_cuda_allocator should be true.
// FLAGS_use_stream_safe_cuda_allocator should be true.
...
@@ -96,7 +119,7 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
...
@@ -96,7 +119,7 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
if
(
old_value
)
{
if
(
old_value
)
{
FLAGS_use_stream_safe_cuda_allocator
=
true
;
FLAGS_use_stream_safe_cuda_allocator
=
true
;
}
}
if
(
create_cuda_graph_stream
)
{
if
(
num_stream
>
1
)
{
// Set cuda graph allocator for all streams.
// Set cuda graph allocator for all streams.
// Establish dependencies between cuda graph stream and all other streams
// Establish dependencies between cuda graph stream and all other streams
// using eventWait, so that all streams will be captured.
// using eventWait, so that all streams will be captured.
...
@@ -129,20 +152,17 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
...
@@ -129,20 +152,17 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
}
}
std
::
unique_ptr
<
CUDAGraph
>
EndCUDAGraphCapture
()
{
std
::
unique_ptr
<
CUDAGraph
>
EndCUDAGraphCapture
()
{
phi
::
DeviceContext
*
mutable_dev_ctx
;
auto
place
=
CUDAGraph
::
CapturingPlace
();
auto
place
=
CUDAGraph
::
CapturingPlace
();
bool
create_cuda_graph_stream
=
CUDAGraph
::
IsCUDAGraphStreamCreated
();
auto
pool_id
=
CUDAGraph
::
CapturingPoolID
();
if
(
create_cuda_graph_stream
)
{
auto
*
mutable_dev_ctx
=
SelectCUDAGraphDeviceContext
(
place
,
&
pool_id
);
auto
*
dev_ctx
=
reinterpret_cast
<
phi
::
GPUContext
*>
(
mutable_dev_ctx
);
auto
all_capturing_dev_ctxs
=
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
()
.
GetAllCapturingDeviceContexts
();
auto
num_stream
=
all_capturing_dev_ctxs
.
size
();
if
(
num_stream
>
1
)
{
// join all other streams back to origin cuda graph stream.
// join all other streams back to origin cuda graph stream.
int64_t
pool_id
=
CUDAGraph
::
CapturingPoolID
();
mutable_dev_ctx
=
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
().
Get
(
pool_id
,
place
,
0
);
auto
*
cuda_graph_dev_ctx
=
reinterpret_cast
<
phi
::
GPUContext
*>
(
mutable_dev_ctx
);
auto
all_capturing_dev_ctxs
=
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
()
.
GetAllCapturingDeviceContexts
();
for
(
auto
iter
=
all_capturing_dev_ctxs
.
begin
();
for
(
auto
iter
=
all_capturing_dev_ctxs
.
begin
();
iter
!=
all_capturing_dev_ctxs
.
end
();
iter
!=
all_capturing_dev_ctxs
.
end
();
++
iter
)
{
++
iter
)
{
...
@@ -152,19 +172,16 @@ std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
...
@@ -152,19 +172,16 @@ std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
capturing_dev_ctx
->
GetPlace
(),
capturing_dev_ctx
->
GetPlace
(),
platform
::
GenerateDeviceEventFlag
());
platform
::
GenerateDeviceEventFlag
());
capturing_event
->
Record
(
capturing_dev_ctx
);
capturing_event
->
Record
(
capturing_dev_ctx
);
capturing_event
->
Wait
(
platform
::
kCUDA
,
cuda_graph_dev_ctx
);
capturing_event
->
Wait
(
platform
::
kCUDA
,
dev_ctx
);
VLOG
(
4
)
<<
"CUDA Graph stream eventWait. cuda graph dev_ctx: "
VLOG
(
4
)
<<
"CUDA Graph stream eventWait. cuda graph dev_ctx: "
<<
dev_ctx
<<
cuda_graph_dev_ctx
<<
" wait for capturing dev_ctx: "
<<
capturing_dev_ctx
;
<<
" wait for capturing dev_ctx: "
<<
capturing_dev_ctx
;
capturing_dev_ctx
->
cudnn_workspace_handle
().
ResetWorkspace
();
capturing_dev_ctx
->
cudnn_workspace_handle
().
ResetWorkspace
();
capturing_dev_ctx
->
SetCUDAGraphAllocator
(
nullptr
);
capturing_dev_ctx
->
SetCUDAGraphAllocator
(
nullptr
);
}
}
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
()
.
ClearDeviceContextsRecords
();
}
else
{
mutable_dev_ctx
=
phi
::
DeviceContextPool
::
Instance
().
Get
(
place
);
}
}
auto
*
dev_ctx
=
reinterpret_cast
<
phi
::
GPUContext
*>
(
mutable_dev_ctx
);
phi
::
backends
::
gpu
::
CUDAGraphContextManager
::
Instance
()
.
ClearDeviceContextsRecords
();
dev_ctx
->
cudnn_workspace_handle
().
ResetWorkspace
();
dev_ctx
->
cudnn_workspace_handle
().
ResetWorkspace
();
dev_ctx
->
SetCUDAGraphAllocator
(
nullptr
);
dev_ctx
->
SetCUDAGraphAllocator
(
nullptr
);
return
CUDAGraph
::
EndCapture
();
return
CUDAGraph
::
EndCapture
();
...
...
paddle/phi/backends/gpu/cuda/cuda_graph.h
浏览文件 @
53c73c77
...
@@ -196,14 +196,6 @@ class CUDAGraph {
...
@@ -196,14 +196,6 @@ class CUDAGraph {
// supported during capturing CUDA Graph.
// supported during capturing CUDA Graph.
static
bool
IsValidCapturing
();
static
bool
IsValidCapturing
();
static
void
SetIsCUDAGraphStreamCreated
(
bool
create_cuda_graph_stream
)
{
capturing_graph_
->
is_cuda_graph_stream_created_
=
create_cuda_graph_stream
;
}
static
bool
IsCUDAGraphStreamCreated
()
{
return
capturing_graph_
->
is_cuda_graph_stream_created_
;
}
static
bool
IsThreadLocalCapturing
()
{
static
bool
IsThreadLocalCapturing
()
{
#if CUDA_VERSION >= 10010
#if CUDA_VERSION >= 10010
return
IsCapturing
()
&&
return
IsCapturing
()
&&
...
@@ -254,8 +246,6 @@ class CUDAGraph {
...
@@ -254,8 +246,6 @@ class CUDAGraph {
bool
is_first_run_
{
true
};
bool
is_first_run_
{
true
};
bool
is_cuda_graph_stream_created_
{
false
};
static
paddle
::
optional
<
std
::
thread
::
id
>
capturing_thread_id_
;
static
paddle
::
optional
<
std
::
thread
::
id
>
capturing_thread_id_
;
static
std
::
unique_ptr
<
CUDAGraph
>
capturing_graph_
;
static
std
::
unique_ptr
<
CUDAGraph
>
capturing_graph_
;
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录