Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
fe35496b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fe35496b
编写于
9月 22, 2021
作者:
A
Aurelius84
提交者:
GitHub
9月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Modify H2D and D2H as kQueue::Sync and Polish Schedule logic (#35866)
* Modify H2D and D2H as kQueue::Sync * fix interface error
上级
ae65257d
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
79 addition
and
66 deletion
+79
-66
paddle/fluid/framework/new_executor/event_manager.cc
paddle/fluid/framework/new_executor/event_manager.cc
+5
-15
paddle/fluid/framework/new_executor/event_manager.h
paddle/fluid/framework/new_executor/event_manager.h
+0
-4
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+1
-2
paddle/fluid/framework/new_executor/interpretercore.h
paddle/fluid/framework/new_executor/interpretercore.h
+0
-1
paddle/fluid/framework/new_executor/interpretercore_util.cc
paddle/fluid/framework/new_executor/interpretercore_util.cc
+3
-1
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+16
-8
paddle/fluid/framework/new_executor/stream_analyzer.cc
paddle/fluid/framework/new_executor/stream_analyzer.cc
+44
-31
paddle/fluid/framework/new_executor/stream_analyzer.h
paddle/fluid/framework/new_executor/stream_analyzer.h
+10
-4
未找到文件。
paddle/fluid/framework/new_executor/event_manager.cc
浏览文件 @
fe35496b
...
@@ -24,9 +24,12 @@ void EventManager::WaitEvent(const Instruction& instruction,
...
@@ -24,9 +24,12 @@ void EventManager::WaitEvent(const Instruction& instruction,
VLOG
(
3
)
<<
"Deal StreamWaitEventOrSync for "
VLOG
(
3
)
<<
"Deal StreamWaitEventOrSync for "
<<
instruction
.
kernel_func_
.
operator_base_
->
Type
();
<<
instruction
.
kernel_func_
.
operator_base_
->
Type
();
auto
*
dev_ctx
=
instruction
.
dev_ctx_
;
WaitOrSync
(
instruction
.
intput_events_
,
dev_ctx
);
for
(
auto
&
event_iter
:
instruction
.
intput_events_
)
{
VLOG
(
3
)
<<
"wait var_id: "
<<
event_iter
.
var_id_
<<
" 's event with waiter_type: "
<<
event_iter
.
waiter_type_
;
event_iter
.
event_
->
Wait
(
event_iter
.
waiter_type_
,
instruction
.
dev_ctx_
);
}
}
}
void
EventManager
::
RecordEvent
(
const
Instruction
&
instruction
,
void
EventManager
::
RecordEvent
(
const
Instruction
&
instruction
,
...
@@ -40,18 +43,5 @@ void EventManager::RecordEvent(const Instruction& instruction,
...
@@ -40,18 +43,5 @@ void EventManager::RecordEvent(const Instruction& instruction,
}
}
}
}
void
EventManager
::
WaitOrSync
(
const
std
::
vector
<
EventInter
>&
events
,
const
platform
::
DeviceContext
*
dev_ctx
)
{
for
(
auto
&
event_iter
:
events
)
{
if
(
event_iter
.
is_sync_
)
{
VLOG
(
3
)
<<
"host sync wait in_var_id "
<<
event_iter
.
var_id_
;
event_iter
.
event_
->
Wait
(
platform
::
kCPU
,
dev_ctx
);
}
else
{
VLOG
(
3
)
<<
"stream async wait in_var_id "
<<
event_iter
.
var_id_
;
event_iter
.
event_
->
Wait
(
platform
::
kCUDA
,
dev_ctx
);
}
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/event_manager.h
浏览文件 @
fe35496b
...
@@ -24,10 +24,6 @@ class EventManager {
...
@@ -24,10 +24,6 @@ class EventManager {
const
platform
::
Place
&
place
);
const
platform
::
Place
&
place
);
void
WaitEvent
(
const
Instruction
&
instruction
,
const
platform
::
Place
&
place
);
void
WaitEvent
(
const
Instruction
&
instruction
,
const
platform
::
Place
&
place
);
private:
void
WaitOrSync
(
const
std
::
vector
<
EventInter
>&
events
,
const
platform
::
DeviceContext
*
dev_ctx
);
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
fe35496b
...
@@ -183,8 +183,7 @@ void InterpreterCore::Convert() {
...
@@ -183,8 +183,7 @@ void InterpreterCore::Convert() {
}
}
}
}
stream_analyzer_
.
Schedule
(
vec_func_list_
,
filter_next
,
i
,
stream_analyzer_
.
Schedule
(
filter_next
,
&
vec_instruction_
,
i
);
&
vec_instruction_
);
for
(
auto
inst_id
:
filter_next
)
{
for
(
auto
inst_id
:
filter_next
)
{
dependecy_count_
[
inst_id
]
++
;
dependecy_count_
[
inst_id
]
++
;
...
...
paddle/fluid/framework/new_executor/interpretercore.h
浏览文件 @
fe35496b
...
@@ -99,7 +99,6 @@ class InterpreterCore {
...
@@ -99,7 +99,6 @@ class InterpreterCore {
InterpreterCoreGarbageCollector
gc_
;
InterpreterCoreGarbageCollector
gc_
;
std
::
vector
<
paddle
::
platform
::
DeviceEvent
>
gc_event_
;
std
::
vector
<
paddle
::
platform
::
DeviceEvent
>
gc_event_
;
std
::
unique_ptr
<
WorkQueueGroup
>
group_thread_pool_
;
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/interpretercore_util.cc
浏览文件 @
fe35496b
...
@@ -365,7 +365,9 @@ void build_op_func_list(const platform::Place& place,
...
@@ -365,7 +365,9 @@ void build_op_func_list(const platform::Place& place,
OpKernelComputeFunc
(
kernel_iter
->
second
);
OpKernelComputeFunc
(
kernel_iter
->
second
);
copy_op_func_node
.
kernel_func_
(
copy_exec_ctx
);
copy_op_func_node
.
kernel_func_
(
copy_exec_ctx
);
VLOG
(
3
)
<<
"Run "
<<
memcpy_op_type
<<
" done."
;
VLOG
(
3
)
<<
"Run "
<<
memcpy_op_type
<<
" done."
;
copy_op_func_node
.
type_
=
OpFuncType
::
kQueueAsync
;
// NOTE(Aurelius84): memcpy_op is expensive operation, so we tag them
// as kQueueSync and execute them in thread pool.
copy_op_func_node
.
type_
=
OpFuncType
::
kQueueSync
;
copy_op_func_node
.
dev_ctx_
=
dev_ctx
;
copy_op_func_node
.
dev_ctx_
=
dev_ctx
;
op_list
->
push_back
(
copy_op
);
op_list
->
push_back
(
copy_op
);
vec_func_list
->
push_back
(
copy_op_func_node
);
vec_func_list
->
push_back
(
copy_op_func_node
);
...
...
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
fe35496b
...
@@ -25,11 +25,6 @@
...
@@ -25,11 +25,6 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
interpretercore
{
static
constexpr
char
kMemcpyH2D
[]
=
"memcpy_h2d"
;
static
constexpr
char
kMemcpyD2H
[]
=
"memcpy_d2h"
;
}
// namespace interpretercore
using
OpKernelComputeFunc
=
std
::
function
<
void
(
const
ExecutionContext
&
)
>
;
using
OpKernelComputeFunc
=
std
::
function
<
void
(
const
ExecutionContext
&
)
>
;
using
OpKernelMap
=
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelType
,
OpKernelComputeFunc
,
OpKernelType
::
Hash
>
;
std
::
unordered_map
<
OpKernelType
,
OpKernelComputeFunc
,
OpKernelType
::
Hash
>
;
...
@@ -496,11 +491,11 @@ struct NextInstruction {
...
@@ -496,11 +491,11 @@ struct NextInstruction {
struct
EventInter
{
struct
EventInter
{
explicit
EventInter
(
size_t
var_id
,
explicit
EventInter
(
size_t
var_id
,
std
::
shared_ptr
<
platform
::
DeviceEvent
>
event
,
std
::
shared_ptr
<
platform
::
DeviceEvent
>
event
,
bool
is_sync
)
platform
::
DeviceType
waiter_type
)
:
var_id_
(
var_id
),
event_
(
event
),
is_sync_
(
is_sync
)
{}
:
var_id_
(
var_id
),
event_
(
event
),
waiter_type_
(
waiter_type
)
{}
size_t
var_id_
;
size_t
var_id_
;
std
::
shared_ptr
<
platform
::
DeviceEvent
>
event_
;
std
::
shared_ptr
<
platform
::
DeviceEvent
>
event_
;
bool
is_sync
_
;
platform
::
DeviceType
waiter_type
_
;
};
};
struct
InstructionInfo
{
struct
InstructionInfo
{
...
@@ -543,5 +538,18 @@ struct OpFuncNode {
...
@@ -543,5 +538,18 @@ struct OpFuncNode {
OpFuncType
type_
;
OpFuncType
type_
;
};
};
namespace
interpretercore
{
static
constexpr
char
kMemcpyH2D
[]
=
"memcpy_h2d"
;
static
constexpr
char
kMemcpyD2H
[]
=
"memcpy_d2h"
;
static
bool
IsMemcpyH2D
(
const
Instruction
&
instr
)
{
return
instr
.
kernel_func_
.
operator_base_
->
Type
()
==
kMemcpyH2D
;
}
static
bool
IsMemcpyD2H
(
const
Instruction
&
instr
)
{
return
instr
.
kernel_func_
.
operator_base_
->
Type
()
==
kMemcpyD2H
;
}
}
// namespace interpretercore
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/stream_analyzer.cc
浏览文件 @
fe35496b
...
@@ -22,7 +22,7 @@ namespace framework {
...
@@ -22,7 +22,7 @@ namespace framework {
* Parse the var_ids that need to be associated with an event.
* Parse the var_ids that need to be associated with an event.
* The caller should guarantee front_op and back_op satisfy the
* The caller should guarantee front_op and back_op satisfy the
* following conditions:
* following conditions:
* 1. kQueue
As
ync -> kQueueAsync
* 1. kQueue
S
ync -> kQueueAsync
* 2. kQueueAsync -> kQueueSync
* 2. kQueueAsync -> kQueueSync
*
*
* For example: matmul(gpu) -> out_var -> memcpy_d2h
* For example: matmul(gpu) -> out_var -> memcpy_d2h
...
@@ -48,7 +48,7 @@ std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
...
@@ -48,7 +48,7 @@ std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
void
StreamAnalyzer
::
AssociateInputWithEvents
(
void
StreamAnalyzer
::
AssociateInputWithEvents
(
const
std
::
vector
<
size_t
>&
new_event_var_id
,
Instruction
*
next_instr
,
const
std
::
vector
<
size_t
>&
new_event_var_id
,
Instruction
*
next_instr
,
bool
is_sync
)
{
platform
::
DeviceType
waiter_type
)
{
for
(
auto
var_id
:
new_event_var_id
)
{
for
(
auto
var_id
:
new_event_var_id
)
{
if
(
var_id2event_
.
count
(
var_id
)
==
0
)
{
if
(
var_id2event_
.
count
(
var_id
)
==
0
)
{
auto
device_event
=
std
::
make_shared
<
platform
::
DeviceEvent
>
(
auto
device_event
=
std
::
make_shared
<
platform
::
DeviceEvent
>
(
...
@@ -57,52 +57,43 @@ void StreamAnalyzer::AssociateInputWithEvents(
...
@@ -57,52 +57,43 @@ void StreamAnalyzer::AssociateInputWithEvents(
}
}
// Add events for next_instr.inputs
// Add events for next_instr.inputs
next_instr
->
intput_events_
.
emplace_back
(
var_id
,
var_id2event_
.
at
(
var_id
),
next_instr
->
intput_events_
.
emplace_back
(
var_id
,
var_id2event_
.
at
(
var_id
),
is_sync
);
waiter_type
);
}
}
}
}
void
StreamAnalyzer
::
Schedule
(
const
std
::
vector
<
OpFuncNode
>&
op_func_nodes
,
void
StreamAnalyzer
::
Schedule
(
const
std
::
vector
<
size_t
>&
downstream_ops
,
const
std
::
vector
<
size_t
>&
downstream_ops
,
std
::
vector
<
Instruction
>*
instructions
,
size_t
op_index
,
size_t
op_index
)
{
std
::
vector
<
Instruction
>*
instructions
)
{
auto
&
op_func_type
=
op_func_nodes
[
op_index
].
type_
;
auto
&
cur_instr
=
instructions
->
at
(
op_index
);
auto
&
cur_instr
=
instructions
->
at
(
op_index
);
auto
&
next_instruction
=
cur_instr
.
next_instruction_
;
auto
&
next_instruction
=
cur_instr
.
next_instruction_
;
if
(
op_func_type
==
OpFuncType
::
kQueueSync
)
{
// all downstream ops of kQueueSync can directly run, such as CPU -> Any
next_instruction
.
direct_run_
=
downstream_ops
;
}
else
{
// kQueueAsync
std
::
vector
<
size_t
>
event_var_ids
;
std
::
vector
<
size_t
>
event_var_ids
;
for
(
auto
next_op_id
:
downstream_ops
)
{
for
(
auto
next_op_id
:
downstream_ops
)
{
auto
&
next_instr
=
instructions
->
at
(
next_op_id
);
auto
&
next_instr
=
instructions
->
at
(
next_op_id
);
// case 1: GPU -> GPU(same stream)
if
(
cur_instr
.
dev_ctx_
==
next_instr
.
dev_ctx_
)
{
if
(
IsDirectRun
(
cur_instr
,
next_instr
)
)
{
next_instruction
.
direct_run_
.
emplace_back
(
next_op_id
);
next_instruction
.
direct_run_
.
emplace_back
(
next_op_id
);
continue
;
}
else
{
}
// Always insert events between different stream
// Always insert events between different stream
auto
new_event_var_ids
=
ParseEventVarIds
(
cur_instr
,
next_instr
);
auto
new_event_var_ids
=
ParseEventVarIds
(
cur_instr
,
next_instr
);
event_var_ids
.
insert
(
event_var_ids
.
end
(),
new_event_var_ids
.
begin
(),
event_var_ids
.
insert
(
event_var_ids
.
end
(),
new_event_var_ids
.
begin
(),
new_event_var_ids
.
end
());
new_event_var_ids
.
end
());
bool
is_sync
=
auto
waiter_type
=
GetWaiterType
(
next_instr
);
(
op_func_nodes
[
next_op_id
].
type_
==
OpFuncType
::
kQueueSync
);
AssociateInputWithEvents
(
new_event_var_ids
,
&
next_instr
,
waiter_type
);
AssociateInputWithEvents
(
new_event_var_ids
,
&
next_instr
,
is_sync
);
if
(
is_sync
)
{
// GPU -> CPU
if
(
waiter_type
==
platform
::
kCPU
)
{
// GPU -> CPU
next_instruction
.
synchronize_run_
.
emplace_back
(
next_op_id
);
next_instruction
.
synchronize_run_
.
emplace_back
(
next_op_id
);
}
else
{
// GPU -> GPU(different stream)
}
else
{
// GPU -> GPU(different stream)
next_instruction
.
event_wait_run_
.
emplace_back
(
next_op_id
);
next_instruction
.
event_wait_run_
.
emplace_back
(
next_op_id
);
}
}
}
}
}
// Create events for these cross-stream vars
// Create events for these cross-stream vars
VLOG
(
3
)
<<
cur_instr
.
kernel_func_
.
operator_base_
->
Type
()
VLOG
(
3
)
<<
cur_instr
.
kernel_func_
.
operator_base_
->
Type
()
<<
" event_var_ids.size: "
<<
event_var_ids
.
size
();
<<
" event_var_ids.size: "
<<
event_var_ids
.
size
();
for
(
auto
var_id
:
event_var_ids
)
{
for
(
auto
var_id
:
event_var_ids
)
{
cur_instr
.
output_events_
.
emplace_back
(
var_id
,
var_id2event_
.
at
(
var_id
),
cur_instr
.
output_events_
.
emplace_back
(
var_id
,
var_id2event_
.
at
(
var_id
),
false
/*not used*/
);
platform
::
kCUDA
/*not used*/
);
}
}
}
}
}
...
@@ -121,5 +112,27 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
...
@@ -121,5 +112,27 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
return
dev_ctx
;
return
dev_ctx
;
}
}
/*
* NOTE(dev): The following cases are considered as directly run:
*
* 1. with same dev_ctx_, such as: CPU -> CPU, GPU -> GPU
* 2. D2H -> CPU
* 3. CPU -> H2D
*/
bool
StreamAnalyzer
::
IsDirectRun
(
Instruction
&
cur_instr
,
const
Instruction
&
next_instr
)
{
return
(
cur_instr
.
dev_ctx_
==
next_instr
.
dev_ctx_
||
interpretercore
::
IsMemcpyD2H
(
cur_instr
)
||
interpretercore
::
IsMemcpyH2D
(
next_instr
));
}
platform
::
DeviceType
StreamAnalyzer
::
GetWaiterType
(
const
Instruction
&
instr
)
{
if
(
instr
.
type_
==
OpFuncType
::
kQueueSync
)
{
return
platform
::
kCPU
;
}
else
{
return
platform
::
kCUDA
;
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/new_executor/stream_analyzer.h
浏览文件 @
fe35496b
...
@@ -29,9 +29,8 @@ class StreamAnalyzer {
...
@@ -29,9 +29,8 @@ class StreamAnalyzer {
~
StreamAnalyzer
()
{}
~
StreamAnalyzer
()
{}
void
Schedule
(
const
std
::
vector
<
OpFuncNode
>&
op_func_nodes
,
void
Schedule
(
const
std
::
vector
<
size_t
>&
downstream_ops
,
const
std
::
vector
<
size_t
>&
downstream_ops
,
size_t
op_index
,
std
::
vector
<
Instruction
>*
instructions
,
size_t
op_index
);
std
::
vector
<
Instruction
>*
instructions
);
platform
::
DeviceContext
*
ParseDeviceContext
(
const
OpFuncNode
&
op_func_node
,
platform
::
DeviceContext
*
ParseDeviceContext
(
const
OpFuncNode
&
op_func_node
,
const
OperatorBase
&
op_base
);
const
OperatorBase
&
op_base
);
...
@@ -41,7 +40,14 @@ class StreamAnalyzer {
...
@@ -41,7 +40,14 @@ class StreamAnalyzer {
const
Instruction
&
next_instr
);
const
Instruction
&
next_instr
);
void
AssociateInputWithEvents
(
const
std
::
vector
<
size_t
>&
new_event_var_id
,
void
AssociateInputWithEvents
(
const
std
::
vector
<
size_t
>&
new_event_var_id
,
Instruction
*
next_instr
,
bool
is_sync
);
Instruction
*
next_instr
,
platform
::
DeviceType
waiter_type
);
bool
IsDirectRun
(
Instruction
&
cur_instr
,
// NOLINT
const
Instruction
&
next_instr
);
platform
::
DeviceType
GetWaiterType
(
const
Instruction
&
instr
);
platform
::
Place
place_
;
platform
::
Place
place_
;
platform
::
DeviceContextPool
d2h_ctx_pool_
;
platform
::
DeviceContextPool
d2h_ctx_pool_
;
platform
::
DeviceContextPool
h2d_ctx_pool_
;
platform
::
DeviceContextPool
h2d_ctx_pool_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录