Commit 8528dd9f (unverified)

Refactor StreamAnalyzer and EventManager from InterpreterCore (#35711)

Authored on Sep 14, 2021 by Aurelius84; committed via GitHub on Sep 14, 2021.
Parent commit: 85e4f45a
Showing 9 changed files with 288 additions and 171 deletions (+288, -171).
paddle/fluid/framework/new_executor/CMakeLists.txt            +3    -1
paddle/fluid/framework/new_executor/event_manager.cc         +58    -0
paddle/fluid/framework/new_executor/event_manager.h          +35    -0
paddle/fluid/framework/new_executor/interpretercore.cc        +6  -151
paddle/fluid/framework/new_executor/interpretercore.h         +4   -16
paddle/fluid/framework/new_executor/interpretercore_util.h    +0    -3
paddle/fluid/framework/new_executor/new_executor_defs.h       +5    -0
paddle/fluid/framework/new_executor/stream_analyzer.cc      +125    -0
paddle/fluid/framework/new_executor/stream_analyzer.h        +52    -0
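At a glance, the commit moves stream and event bookkeeping out of InterpreterCore into two helpers: StreamAnalyzer decides at Convert() time which device context an op should use and which downstream ops need cross-stream events, while EventManager waits on an instruction's input events before it runs and records its output events afterwards. The minimal C++ sketch below only illustrates that delegation shape; the types are simplified stand-ins written for this summary, not the real Paddle classes, and only the call pattern mirrors the diffs that follow.

#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-ins for the real Paddle types (illustration only).
struct OpFuncNode {};
struct Instruction {};

struct StreamAnalyzer {
  // Mirrors the role of StreamAnalyzer::ParseDeviceContext / Schedule below.
  void ParseDeviceContext(const OpFuncNode&) {
    std::cout << "pick d2h/h2d/default device context for this op\n";
  }
  void Schedule(const std::vector<OpFuncNode>&, const std::vector<size_t>&,
                size_t, std::vector<Instruction>*) {
    std::cout << "classify downstream ops: direct_run / event_wait / sync\n";
  }
};

struct EventManager {
  // Mirrors the role of EventManager::WaitEvent / RecordEvent below.
  void WaitEvent(const Instruction&) { std::cout << "wait on input events\n"; }
  void RecordEvent(const Instruction&) { std::cout << "record output events\n"; }
};

// After the refactor, InterpreterCore only owns and delegates to the helpers.
struct InterpreterCore {
  StreamAnalyzer stream_analyzer_;
  EventManager event_manager_;

  void Convert(const std::vector<OpFuncNode>& nodes,
               std::vector<Instruction>* instructions) {
    for (size_t i = 0; i < nodes.size(); ++i) {
      stream_analyzer_.ParseDeviceContext(nodes[i]);
      stream_analyzer_.Schedule(nodes, /*downstream_ops=*/{}, i, instructions);
    }
  }

  void Execute(const std::vector<Instruction>& instructions) {
    for (const auto& instr : instructions) {
      event_manager_.WaitEvent(instr);     // step1: wait or sync
      /* RunInstruction(instr); */         // step2: run the kernel
      event_manager_.RecordEvent(instr);   // step3: record output events
    }
  }
};

int main() {
  InterpreterCore core;
  std::vector<OpFuncNode> nodes(2);
  std::vector<Instruction> instructions(2);
  core.Convert(nodes, &instructions);
  core.Execute(instructions);
  return 0;
}

The real classes operate on Paddle's Instruction/OpFuncNode structures and DeviceEvent objects, as the diffs below show.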
paddle/fluid/framework/new_executor/CMakeLists.txt

@@ -5,7 +5,9 @@ graph_to_program_pass variable_helper timer monitor)
 cc_library(workqueue SRCS workqueue.cc DEPS enforce)
 cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS})
 cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS})
-cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector)
+cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog)
+cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context)
+cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager)
 cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
 cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
 # cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
paddle/fluid/framework/new_executor/event_manager.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/event_manager.h"
namespace paddle {
namespace framework {

void EventManager::WaitEvent(const Instruction& instruction,
                             const platform::Place& place) {
  // If InterpreterCore in on CPUPlace, do nothing.
  if (platform::is_cpu_place(place)) return;

  VLOG(3) << "Deal StreamWaitEventOrSync for "
          << instruction.kernel_func_.operator_base_->Type();
  auto* dev_ctx = instruction.dev_ctx_;

  WaitOrSync(instruction.intput_events_, dev_ctx);
}

void EventManager::RecordEvent(const Instruction& instruction,
                               const OpFuncNode& op_func_node,
                               const platform::Place& place) {
  // If InterpreterCore in on CPUPlace, do nothing.
  if (platform::is_cpu_place(place)) return;

  for (auto& event : instruction.output_events_) {
    VLOG(3) << "Record event in out_var_id: " << event.var_id_;
    event.event_->Record(instruction.dev_ctx_);
  }
}

void EventManager::WaitOrSync(const std::vector<EventInter>& events,
                              const platform::DeviceContext* dev_ctx) {
  for (auto& event_iter : events) {
    if (event_iter.is_sync_) {
      VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
      event_iter.event_->Wait(platform::kCPU, dev_ctx);
    } else {
      VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
      event_iter.event_->Wait(platform::kCUDA, dev_ctx);
    }
  }
}

}  // namespace framework
}  // namespace paddle
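The branch in WaitOrSync above is the whole synchronization policy: an EventInter flagged is_sync_ blocks the host (a kCPU wait, used when the consumer is a kQueueSync op such as memcpy_d2h feeding CPU code), otherwise only the consumer's CUDA stream is made to wait. A tiny self-contained sketch of that decision, using hypothetical stand-in types rather than Paddle's DeviceEvent:

#include <iostream>
#include <vector>

// Hypothetical stand-ins for platform::DeviceEvent / EventInter (illustration only).
enum class DeviceType { kCPU, kCUDA };

struct FakeEvent {
  int var_id;
  void Wait(DeviceType waiter) const {
    if (waiter == DeviceType::kCPU)
      std::cout << "host blocks until event for var " << var_id << " completes\n";
    else
      std::cout << "consumer stream waits on event for var " << var_id
                << " (host does not block)\n";
  }
};

struct FakeEventInter {
  FakeEvent event;
  bool is_sync;  // true: downstream op is kQueueSync (e.g. memcpy_d2h feeding CPU)
};

// Mirrors the branch in EventManager::WaitOrSync above.
void WaitOrSync(const std::vector<FakeEventInter>& events) {
  for (const auto& e : events) {
    e.event.Wait(e.is_sync ? DeviceType::kCPU : DeviceType::kCUDA);
  }
}

int main() {
  std::vector<FakeEventInter> events;
  events.push_back({FakeEvent{7}, /*is_sync=*/true});   // consumer runs on CPU
  events.push_back({FakeEvent{9}, /*is_sync=*/false});  // consumer on another GPU stream
  WaitOrSync(events);
  return 0;
}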
paddle/fluid/framework/new_executor/event_manager.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
namespace paddle {
namespace framework {

class EventManager {
 public:
  void RecordEvent(const Instruction& instruction,
                   const OpFuncNode& op_func_node,
                   const platform::Place& place);

  void WaitEvent(const Instruction& instruction, const platform::Place& place);

 private:
  void WaitOrSync(const std::vector<EventInter>& events,
                  const platform::DeviceContext* dev_ctx);
};

}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/new_executor/interpretercore.cc
@@ -20,101 +20,6 @@
 namespace paddle {
 namespace framework {
 
-namespace {
-
-/*
- * Parse the var_ids that need to be associated with an event.
- * The caller should guarantee front_op and back_op satisfy the
- * following conditions:
- * 1. kQueueAsync -> kQueueAsync
- * 2. kQueueAsync -> kQueueSync
- *
- * For example: matmul(gpu) -> out_var -> memcpy_d2h
- * out_var should be associated with an event.
- */
-std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
-                                     const Instruction& next_instr) {
-  std::unordered_set<size_t> unique_var_ids;
-  for (auto& item : cur_instr.output_index_) {
-    unique_var_ids.insert(item.second.begin(), item.second.end());
-  }
-
-  std::vector<size_t> new_event_var_ids;
-  for (auto& item : next_instr.input_index_) {
-    for (auto var_id : item.second) {
-      if (unique_var_ids.count(var_id) > 0) {
-        new_event_var_ids.push_back(var_id);
-      }
-    }
-  }
-  return new_event_var_ids;
-}
-
-void AssociateInputWithEvents(
-    const platform::Place& place, const std::vector<size_t>& new_event_var_id,
-    Instruction* next_instr,
-    std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
-    bool is_sync) {
-  for (auto var_id : new_event_var_id) {
-    if (var_id2event->count(var_id) == 0) {
-      auto device_event = std::make_shared<platform::DeviceEvent>(
-          place, platform::GenerateDeviceEventFlag());
-      var_id2event->emplace(var_id, std::move(device_event));
-    }
-    // Add events for next_instr.inputs
-    next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id),
-                                            is_sync);
-  }
-}
-
-void ParseDirectAndEventRunOps(
-    const platform::Place& place, const std::vector<OpFuncNode>& op_func_nodes,
-    const std::vector<size_t>& downstream_ops, size_t op_index,
-    std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
-    std::vector<Instruction>* instructions) {
-  auto& op_func_type = op_func_nodes[op_index].type_;
-  auto& cur_instr = instructions->at(op_index);
-  auto& next_instruction = cur_instr.next_instruction_;
-
-  if (op_func_type == OpFuncType::kQueueSync) {
-    // all downstream ops of kQueueSync can directly run, such as CPU -> Any
-    next_instruction.direct_run_ = downstream_ops;
-  } else {
-    // kQueueAsync
-    std::vector<size_t> event_var_ids;
-    for (auto next_op_id : downstream_ops) {
-      auto& next_instr = instructions->at(next_op_id);
-      // case 1: GPU -> GPU(same stream)
-      if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) {
-        next_instruction.direct_run_.emplace_back(next_op_id);
-        continue;
-      }
-
-      // Always insert events between different stream
-      auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
-      event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
-                           new_event_var_ids.end());
-
-      bool is_sync =
-          (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
-      AssociateInputWithEvents(place, new_event_var_ids, &next_instr,
-                               var_id2event, is_sync);
-
-      if (is_sync) {  // GPU -> CPU
-        next_instruction.synchronize_run_.emplace_back(next_op_id);
-      } else {  // GPU -> GPU(different stream)
-        next_instruction.event_wait_run_.emplace_back(next_op_id);
-      }
-    }
-    // Create events for these cross-stream vars
-    VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
-            << " event_var_ids.size: " << event_var_ids.size();
-    for (auto var_id : event_var_ids) {
-      cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id),
-                                            false /*not used*/);
-    }
-  }
-}
-}  // namespace
-
 InterpreterCore::InterpreterCore(const platform::Place& place,
                                  const ProgramDesc& main_prog,
                                  VariableScope* global_scope,
@@ -123,8 +28,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
     : place_(place),
       main_program_(main_prog),
       global_scope_(global_scope),
-      d2h_ctx_pool_({place}),
-      h2d_ctx_pool_({place}) {
+      stream_analyzer_(place) {
   is_build_ = false;
   feed_names_ = feed_names;
@@ -199,7 +103,7 @@ void InterpreterCore::Convert() {
     Instruction temp_inst;
     auto* op_base = op_list_[i];
-    temp_inst.dev_ctx_ = ParseDeviceContextForInstruction(vec_func_list_[i], *op_base);
+    temp_inst.dev_ctx_ = stream_analyzer_.ParseDeviceContext(vec_func_list_[i], *op_base);
     temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_;
     temp_inst.kernel_func_.operator_base_ = op_base;
     temp_inst.input_index_ = vec_func_list_[i].input_index;
@@ -270,8 +174,8 @@ void InterpreterCore::Convert() {
       }
     }
-    ParseDirectAndEventRunOps(place_, vec_func_list_, filter_next, i,
-                              &var_id2event_, &vec_instruction_);
+    stream_analyzer_.Schedule(vec_func_list_, filter_next, i,
+                              &vec_instruction_);
     for (auto inst_id : filter_next) {
       dependecy_count_[inst_id]++;
@@ -361,7 +265,7 @@ void InterpreterCore::ExecuteInstructionList(
     working_queue.pop();
     auto& instr_node = vec_instr[instr_id];
     // step1 : stream_wait (non-block host) or sync (block host)
-    StreamWaitEventOrSync(instr_node);
+    event_manager_.WaitEvent(instr_node, place_);
     // step2: run instruction
     RunInstruction(instr_node);
     ++run_op_number;
@@ -371,7 +275,7 @@ void InterpreterCore::ExecuteInstructionList(
     }
     // step3: insert event for out_vars if needed
-    RecordEventInstruction(instr_node, vec_func_list_[instr_id]);
+    event_manager_.RecordEvent(instr_node, vec_func_list_[instr_id], place_);
     // step4: update working_queue
     auto& next_instr = instr_node.next_instruction_.all_next_ops_;
@@ -450,54 +354,5 @@ const CostInfo& InterpreterCore::DryRun(
   return dry_run_profiler_.GetCostInfo();
 }
 
-platform::DeviceContext* InterpreterCore::ParseDeviceContextForInstruction(
-    const OpFuncNode& op_func_node, const OperatorBase& op_base) {
-  auto& op_type = op_base.Type();
-  auto* dev_ctx = op_func_node.dev_ctx_;
-  if (op_type == interpretercore::kMemcpyH2D) {
-    VLOG(3) << "Get dev_ctx from d2h_context_pool_";
-    dev_ctx = d2h_ctx_pool_.Get(place_);
-  } else if (op_type == interpretercore::kMemcpyD2H) {
-    VLOG(3) << "Get dev_ctx from h2d_context_pool_";
-    dev_ctx = h2d_ctx_pool_.Get(place_);
-  }
-
-  return dev_ctx;
-}
-
-void InterpreterCore::RecordEventInstruction(const Instruction& instruction,
-                                             const OpFuncNode& op_func_node) {
-  // If InterpreterCore in on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place_)) return;
-
-  for (auto& event : instruction.output_events_) {
-    VLOG(3) << "Record event in out_var_id: " << event.var_id_;
-    event.event_->Record(instruction.dev_ctx_);
-  }
-}
-
-void InterpreterCore::WaitOrSync(const std::vector<EventInter>& events,
-                                 const platform::DeviceContext* dev_ctx) {
-  for (auto& event_iter : events) {
-    if (event_iter.is_sync_) {
-      VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
-      event_iter.event_->Wait(platform::kCPU, dev_ctx);
-    } else {
-      VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
-      event_iter.event_->Wait(platform::kCUDA, dev_ctx);
-    }
-  }
-}
-
-void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) {
-  // If InterpreterCore in on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place_)) return;
-
-  VLOG(3) << "Deal StreamWaitEventOrSync for "
-          << instruction.kernel_func_.operator_base_->Type();
-  auto* dev_ctx = instruction.dev_ctx_;
-
-  WaitOrSync(instruction.intput_events_, dev_ctx);
-}
-
 }  // namespace framework
 }  // namespace paddle
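After these changes, the per-instruction execution loop in InterpreterCore reduces to the step1-step4 pattern visible in the hunks above: wait on input events, run the kernel, record output events, then release downstream instructions. The sketch below replays that loop with hypothetical stand-in types (not Paddle's), purely to make the control flow explicit.

#include <cstddef>
#include <iostream>
#include <queue>
#include <string>
#include <vector>

// Hypothetical stand-ins that mirror the step1-step4 structure above.
struct Instruction {
  std::string op_type;
  std::vector<size_t> next_ops;  // plays the role of next_instruction_.all_next_ops_
};

struct EventManager {
  void WaitEvent(const Instruction& instr) {
    std::cout << "step1: wait/sync input events of " << instr.op_type << "\n";
  }
  void RecordEvent(const Instruction& instr) {
    std::cout << "step3: record output events of " << instr.op_type << "\n";
  }
};

void RunInstruction(const Instruction& instr) {
  std::cout << "step2: run " << instr.op_type << "\n";
}

int main() {
  std::vector<Instruction> instructions = {
      {"matmul", {1}}, {"memcpy_d2h", {2}}, {"fetch", {}}};
  std::vector<int> dependency_count = {0, 1, 1};  // like dependecy_count_
  EventManager event_manager;

  std::queue<size_t> working_queue;
  working_queue.push(0);
  while (!working_queue.empty()) {
    size_t instr_id = working_queue.front();
    working_queue.pop();
    const auto& instr_node = instructions[instr_id];

    event_manager.WaitEvent(instr_node);    // step1
    RunInstruction(instr_node);             // step2
    event_manager.RecordEvent(instr_node);  // step3

    // step4: update working_queue once all dependencies are satisfied
    for (size_t next_id : instr_node.next_ops) {
      if (--dependency_count[next_id] == 0) working_queue.push(next_id);
    }
  }
  return 0;
}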
paddle/fluid/framework/new_executor/interpretercore.h
@@ -19,10 +19,12 @@
 #include <unordered_map>
 #include <vector>
 
+#include "paddle/fluid/framework/new_executor/event_manager.h"
 #include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h"
 #include "paddle/fluid/framework/new_executor/interpretercore_util.h"
 #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
 #include "paddle/fluid/framework/new_executor/profiler.h"
+#include "paddle/fluid/framework/new_executor/stream_analyzer.h"
 #include "paddle/fluid/framework/new_executor/workqueue.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -64,17 +66,6 @@ class InterpreterCore {
       const VariableScope& var_scope, const platform::Place& place,
       std::vector<VariableMetaInfo>& working_var_ref);  // NOLINT
 
-  platform::DeviceContext* ParseDeviceContextForInstruction(
-      const OpFuncNode& op_func_node, const OperatorBase& op_base);
-
-  void RecordEventInstruction(const Instruction& instruction,
-                              const OpFuncNode& op_func_node);
-
-  void WaitOrSync(const std::vector<EventInter>& events,
-                  const platform::DeviceContext* dev_ctx);
-
-  void StreamWaitEventOrSync(const Instruction& instruction);
-
   void AddFetch(const std::vector<std::string>& fetch_names);
 
   bool is_build_;
@@ -83,9 +74,6 @@ class InterpreterCore {
   ProgramDesc main_program_;
   VariableScope* global_scope_;
 
-  platform::DeviceContextPool d2h_ctx_pool_;
-  platform::DeviceContextPool h2d_ctx_pool_;
-
   std::vector<Instruction> vec_instruction_;
   InstructionInfo instruction_info_;
   std::vector<size_t> dependecy_count_;
@@ -99,8 +87,8 @@ class InterpreterCore {
   std::vector<std::string> feed_names_;
 
   InterpreterProfiler dry_run_profiler_;
-  std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
-
+  StreamAnalyzer stream_analyzer_;
+  EventManager event_manager_;
   InterpreterCoreGarbageCollector gc_;
   std::vector<paddle::platform::DeviceEvent> gc_event_;
paddle/fluid/framework/new_executor/interpretercore_util.h
@@ -476,9 +476,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
 
 namespace interpretercore {
 
-static constexpr char kMemcpyH2D[] = "memcpy_h2d";
-static constexpr char kMemcpyD2H[] = "memcpy_d2h";
-
 std::string get_memcpy_type(const platform::Place& src_place,
                             const platform::Place& dst_place);
paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -25,6 +25,11 @@
 namespace paddle {
 namespace framework {
 
+namespace interpretercore {
+static constexpr char kMemcpyH2D[] = "memcpy_h2d";
+static constexpr char kMemcpyD2H[] = "memcpy_d2h";
+}  // namespace interpretercore
+
 using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
 
 using OpKernelMap =
     std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;
paddle/fluid/framework/new_executor/stream_analyzer.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/stream_analyzer.h"
#include <unordered_set>
namespace paddle {
namespace framework {

/*
 * Parse the var_ids that need to be associated with an event.
 * The caller should guarantee front_op and back_op satisfy the
 * following conditions:
 * 1. kQueueAsync -> kQueueAsync
 * 2. kQueueAsync -> kQueueSync
 *
 * For example: matmul(gpu) -> out_var -> memcpy_d2h
 * out_var should be associated with an event.
 */
std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
    const Instruction& cur_instr, const Instruction& next_instr) {
  std::unordered_set<size_t> unique_var_ids;
  for (auto& item : cur_instr.output_index_) {
    unique_var_ids.insert(item.second.begin(), item.second.end());
  }

  std::vector<size_t> new_event_var_ids;
  for (auto& item : next_instr.input_index_) {
    for (auto var_id : item.second) {
      if (unique_var_ids.count(var_id) > 0) {
        new_event_var_ids.push_back(var_id);
      }
    }
  }
  return new_event_var_ids;
}

void StreamAnalyzer::AssociateInputWithEvents(
    const std::vector<size_t>& new_event_var_id, Instruction* next_instr,
    bool is_sync) {
  for (auto var_id : new_event_var_id) {
    if (var_id2event_.count(var_id) == 0) {
      auto device_event = std::make_shared<platform::DeviceEvent>(
          place_, platform::GenerateDeviceEventFlag());
      var_id2event_.emplace(var_id, std::move(device_event));
    }
    // Add events for next_instr.inputs
    next_instr->intput_events_.emplace_back(var_id, var_id2event_.at(var_id),
                                            is_sync);
  }
}

void StreamAnalyzer::Schedule(const std::vector<OpFuncNode>& op_func_nodes,
                              const std::vector<size_t>& downstream_ops,
                              size_t op_index,
                              std::vector<Instruction>* instructions) {
  auto& op_func_type = op_func_nodes[op_index].type_;
  auto& cur_instr = instructions->at(op_index);
  auto& next_instruction = cur_instr.next_instruction_;

  if (op_func_type == OpFuncType::kQueueSync) {
    // all downstream ops of kQueueSync can directly run, such as CPU -> Any
    next_instruction.direct_run_ = downstream_ops;
  } else {
    // kQueueAsync
    std::vector<size_t> event_var_ids;
    for (auto next_op_id : downstream_ops) {
      auto& next_instr = instructions->at(next_op_id);
      // case 1: GPU -> GPU(same stream)
      if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) {
        next_instruction.direct_run_.emplace_back(next_op_id);
        continue;
      }

      // Always insert events between different stream
      auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
      event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
                           new_event_var_ids.end());

      bool is_sync =
          (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
      AssociateInputWithEvents(new_event_var_ids, &next_instr, is_sync);

      if (is_sync) {  // GPU -> CPU
        next_instruction.synchronize_run_.emplace_back(next_op_id);
      } else {  // GPU -> GPU(different stream)
        next_instruction.event_wait_run_.emplace_back(next_op_id);
      }
    }
    // Create events for these cross-stream vars
    VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
            << " event_var_ids.size: " << event_var_ids.size();
    for (auto var_id : event_var_ids) {
      cur_instr.output_events_.emplace_back(var_id, var_id2event_.at(var_id),
                                            false /*not used*/);
    }
  }
}

platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
    const OpFuncNode& op_func_node, const OperatorBase& op_base) {
  auto& op_type = op_base.Type();
  auto* dev_ctx = op_func_node.dev_ctx_;
  if (op_type == interpretercore::kMemcpyH2D) {
    VLOG(3) << "Get dev_ctx from d2h_context_pool_";
    dev_ctx = d2h_ctx_pool_.Get(place_);
  } else if (op_type == interpretercore::kMemcpyD2H) {
    VLOG(3) << "Get dev_ctx from h2d_context_pool_";
    dev_ctx = h2d_ctx_pool_.Get(place_);
  }

  return dev_ctx;
}

}  // namespace framework
}  // namespace paddle
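To make the comment at the top of this file concrete: in the matmul(gpu) -> out_var -> memcpy_d2h case, ParseEventVarIds intersects the producer's output var ids with the consumer's input var ids, and only the shared out_var gets a DeviceEvent. The following self-contained sketch reproduces just that intersection with hypothetical miniature types; the real function works on Paddle's Instruction.

#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical miniature of the Instruction fields ParseEventVarIds reads.
struct MiniInstruction {
  std::map<std::string, std::vector<size_t>> output_index_;
  std::map<std::string, std::vector<size_t>> input_index_;
};

// Same intersection idea as StreamAnalyzer::ParseEventVarIds above: collect the
// producer's output var_ids, then keep the consumer's input var_ids that were
// produced upstream. Those are the vars that need a DeviceEvent.
std::vector<size_t> ParseEventVarIds(const MiniInstruction& cur,
                                     const MiniInstruction& next) {
  std::unordered_set<size_t> produced;
  for (const auto& item : cur.output_index_)
    produced.insert(item.second.begin(), item.second.end());

  std::vector<size_t> event_var_ids;
  for (const auto& item : next.input_index_)
    for (size_t var_id : item.second)
      if (produced.count(var_id) > 0) event_var_ids.push_back(var_id);
  return event_var_ids;
}

int main() {
  // matmul(gpu) writes var 5 (out_var); memcpy_d2h reads vars 5 and 3.
  MiniInstruction matmul{{{"Out", {5}}}, {{"X", {1}}, {"Y", {2}}}};
  MiniInstruction memcpy_d2h{{{"Out", {6}}}, {{"X", {5, 3}}}};

  for (size_t id : ParseEventVarIds(matmul, memcpy_d2h))
    std::cout << "var " << id << " needs an event\n";  // prints: var 5
  return 0;
}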
paddle/fluid/framework/new_executor/stream_analyzer.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_event.h"
namespace paddle {
namespace framework {

class StreamAnalyzer {
 public:
  explicit StreamAnalyzer(const platform::Place& place)
      : place_(place), d2h_ctx_pool_({place}), h2d_ctx_pool_({place}) {}

  ~StreamAnalyzer() {}

  void Schedule(const std::vector<OpFuncNode>& op_func_nodes,
                const std::vector<size_t>& downstream_ops, size_t op_index,
                std::vector<Instruction>* instructions);

  platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node,
                                              const OperatorBase& op_base);

 private:
  std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
                                       const Instruction& next_instr);

  void AssociateInputWithEvents(const std::vector<size_t>& new_event_var_id,
                                Instruction* next_instr, bool is_sync);

  platform::Place place_;
  platform::DeviceContextPool d2h_ctx_pool_;
  platform::DeviceContextPool h2d_ctx_pool_;
  std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
};

}  // namespace framework
}  // namespace paddle