Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8528dd9f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8528dd9f
编写于
9月 14, 2021
作者:
A
Aurelius84
提交者:
GitHub
9月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor StreamAnalyzer and EventManager from InterpreterCore (#35711)
上级
85e4f45a
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
288 addition
and
171 deletion
+288
-171
paddle/fluid/framework/new_executor/CMakeLists.txt
paddle/fluid/framework/new_executor/CMakeLists.txt
+3
-1
paddle/fluid/framework/new_executor/event_manager.cc
paddle/fluid/framework/new_executor/event_manager.cc
+58
-0
paddle/fluid/framework/new_executor/event_manager.h
paddle/fluid/framework/new_executor/event_manager.h
+35
-0
paddle/fluid/framework/new_executor/interpretercore.cc
paddle/fluid/framework/new_executor/interpretercore.cc
+6
-151
paddle/fluid/framework/new_executor/interpretercore.h
paddle/fluid/framework/new_executor/interpretercore.h
+4
-16
paddle/fluid/framework/new_executor/interpretercore_util.h
paddle/fluid/framework/new_executor/interpretercore_util.h
+0
-3
paddle/fluid/framework/new_executor/new_executor_defs.h
paddle/fluid/framework/new_executor/new_executor_defs.h
+5
-0
paddle/fluid/framework/new_executor/stream_analyzer.cc
paddle/fluid/framework/new_executor/stream_analyzer.cc
+125
-0
paddle/fluid/framework/new_executor/stream_analyzer.h
paddle/fluid/framework/new_executor/stream_analyzer.h
+52
-0
未找到文件。
paddle/fluid/framework/new_executor/CMakeLists.txt
浏览文件 @
8528dd9f
...
...
@@ -5,7 +5,9 @@ graph_to_program_pass variable_helper timer monitor)
cc_library
(
workqueue SRCS workqueue.cc DEPS enforce
)
cc_library
(
interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue
${
DEVICE_EVENT_LIBS
}
)
cc_library
(
interpretercore_util SRCS interpretercore_util.cc DEPS
${
INTERPRETERCORE_DEPS
}
)
cc_library
(
interpretercore SRCS interpretercore.cc DEPS workqueue
${
DEVICE_EVENT_LIBS
}
interpretercore_util interpretercore_garbage_collector
)
cc_library
(
event_manager SRCS event_manager.cc DEPS
${
DEVICE_EVENT_LIBS
}
glog
)
cc_library
(
stream_analyzer SRCS stream_analyzer.cc DEPS
${
DEVICE_EVENT_LIBS
}
glog device_context
)
cc_library
(
interpretercore SRCS interpretercore.cc DEPS workqueue
${
DEVICE_EVENT_LIBS
}
interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager
)
cc_library
(
standalone_executor SRCS standalone_executor.cc DEPS interpretercore
)
cc_test
(
workqueue_test SRCS workqueue_test.cc DEPS workqueue
)
# cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
paddle/fluid/framework/new_executor/event_manager.cc
0 → 100644
浏览文件 @
8528dd9f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/event_manager.h"
namespace
paddle
{
namespace
framework
{
void
EventManager
::
WaitEvent
(
const
Instruction
&
instruction
,
const
platform
::
Place
&
place
)
{
// If InterpreterCore in on CPUPlace, do nothing.
if
(
platform
::
is_cpu_place
(
place
))
return
;
VLOG
(
3
)
<<
"Deal StreamWaitEventOrSync for "
<<
instruction
.
kernel_func_
.
operator_base_
->
Type
();
auto
*
dev_ctx
=
instruction
.
dev_ctx_
;
WaitOrSync
(
instruction
.
intput_events_
,
dev_ctx
);
}
void
EventManager
::
RecordEvent
(
const
Instruction
&
instruction
,
const
OpFuncNode
&
op_func_node
,
const
platform
::
Place
&
place
)
{
// If InterpreterCore in on CPUPlace, do nothing.
if
(
platform
::
is_cpu_place
(
place
))
return
;
for
(
auto
&
event
:
instruction
.
output_events_
)
{
VLOG
(
3
)
<<
"Record event in out_var_id: "
<<
event
.
var_id_
;
event
.
event_
->
Record
(
instruction
.
dev_ctx_
);
}
}
void
EventManager
::
WaitOrSync
(
const
std
::
vector
<
EventInter
>&
events
,
const
platform
::
DeviceContext
*
dev_ctx
)
{
for
(
auto
&
event_iter
:
events
)
{
if
(
event_iter
.
is_sync_
)
{
VLOG
(
3
)
<<
"host sync wait in_var_id "
<<
event_iter
.
var_id_
;
event_iter
.
event_
->
Wait
(
platform
::
kCPU
,
dev_ctx
);
}
else
{
VLOG
(
3
)
<<
"stream async wait in_var_id "
<<
event_iter
.
var_id_
;
event_iter
.
event_
->
Wait
(
platform
::
kCUDA
,
dev_ctx
);
}
}
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/new_executor/event_manager.h
0 → 100644
浏览文件 @
8528dd9f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
namespace
paddle
{
namespace
framework
{
class
EventManager
{
public:
void
RecordEvent
(
const
Instruction
&
instruction
,
const
OpFuncNode
&
op_func_node
,
const
platform
::
Place
&
place
);
void
WaitEvent
(
const
Instruction
&
instruction
,
const
platform
::
Place
&
place
);
private:
void
WaitOrSync
(
const
std
::
vector
<
EventInter
>&
events
,
const
platform
::
DeviceContext
*
dev_ctx
);
};
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/new_executor/interpretercore.cc
浏览文件 @
8528dd9f
...
...
@@ -20,101 +20,6 @@
namespace
paddle
{
namespace
framework
{
namespace
{
/*
* Parse the var_ids that need to be associated with an event.
* The caller should guarantee front_op and back_op satisfy the
* following conditions:
* 1. kQueueAsync -> kQueueAsync
* 2. kQueueAsync -> kQueueSync
*
* For example: matmul(gpu) -> out_var -> memcpy_d2h
* out_var should be associated with an event.
*/
std
::
vector
<
size_t
>
ParseEventVarIds
(
const
Instruction
&
cur_instr
,
const
Instruction
&
next_instr
)
{
std
::
unordered_set
<
size_t
>
unique_var_ids
;
for
(
auto
&
item
:
cur_instr
.
output_index_
)
{
unique_var_ids
.
insert
(
item
.
second
.
begin
(),
item
.
second
.
end
());
}
std
::
vector
<
size_t
>
new_event_var_ids
;
for
(
auto
&
item
:
next_instr
.
input_index_
)
{
for
(
auto
var_id
:
item
.
second
)
{
if
(
unique_var_ids
.
count
(
var_id
)
>
0
)
{
new_event_var_ids
.
push_back
(
var_id
);
}
}
}
return
new_event_var_ids
;
}
void
AssociateInputWithEvents
(
const
platform
::
Place
&
place
,
const
std
::
vector
<
size_t
>&
new_event_var_id
,
Instruction
*
next_instr
,
std
::
map
<
size_t
,
std
::
shared_ptr
<
platform
::
DeviceEvent
>>*
var_id2event
,
bool
is_sync
)
{
for
(
auto
var_id
:
new_event_var_id
)
{
if
(
var_id2event
->
count
(
var_id
)
==
0
)
{
auto
device_event
=
std
::
make_shared
<
platform
::
DeviceEvent
>
(
place
,
platform
::
GenerateDeviceEventFlag
());
var_id2event
->
emplace
(
var_id
,
std
::
move
(
device_event
));
}
// Add events for next_instr.inputs
next_instr
->
intput_events_
.
emplace_back
(
var_id
,
var_id2event
->
at
(
var_id
),
is_sync
);
}
}
void
ParseDirectAndEventRunOps
(
const
platform
::
Place
&
place
,
const
std
::
vector
<
OpFuncNode
>&
op_func_nodes
,
const
std
::
vector
<
size_t
>&
downstream_ops
,
size_t
op_index
,
std
::
map
<
size_t
,
std
::
shared_ptr
<
platform
::
DeviceEvent
>>*
var_id2event
,
std
::
vector
<
Instruction
>*
instructions
)
{
auto
&
op_func_type
=
op_func_nodes
[
op_index
].
type_
;
auto
&
cur_instr
=
instructions
->
at
(
op_index
);
auto
&
next_instruction
=
cur_instr
.
next_instruction_
;
if
(
op_func_type
==
OpFuncType
::
kQueueSync
)
{
// all downstream ops of kQueueSync can directly run, such as CPU -> Any
next_instruction
.
direct_run_
=
downstream_ops
;
}
else
{
// kQueueAsync
std
::
vector
<
size_t
>
event_var_ids
;
for
(
auto
next_op_id
:
downstream_ops
)
{
auto
&
next_instr
=
instructions
->
at
(
next_op_id
);
// case 1: GPU -> GPU(same stream)
if
(
cur_instr
.
dev_ctx_
==
next_instr
.
dev_ctx_
)
{
next_instruction
.
direct_run_
.
emplace_back
(
next_op_id
);
continue
;
}
// Always insert events between different stream
auto
new_event_var_ids
=
ParseEventVarIds
(
cur_instr
,
next_instr
);
event_var_ids
.
insert
(
event_var_ids
.
end
(),
new_event_var_ids
.
begin
(),
new_event_var_ids
.
end
());
bool
is_sync
=
(
op_func_nodes
[
next_op_id
].
type_
==
OpFuncType
::
kQueueSync
);
AssociateInputWithEvents
(
place
,
new_event_var_ids
,
&
next_instr
,
var_id2event
,
is_sync
);
if
(
is_sync
)
{
// GPU -> CPU
next_instruction
.
synchronize_run_
.
emplace_back
(
next_op_id
);
}
else
{
// GPU -> GPU(different stream)
next_instruction
.
event_wait_run_
.
emplace_back
(
next_op_id
);
}
}
// Create events for these cross-stream vars
VLOG
(
3
)
<<
cur_instr
.
kernel_func_
.
operator_base_
->
Type
()
<<
" event_var_ids.size: "
<<
event_var_ids
.
size
();
for
(
auto
var_id
:
event_var_ids
)
{
cur_instr
.
output_events_
.
emplace_back
(
var_id
,
var_id2event
->
at
(
var_id
),
false
/*not used*/
);
}
}
}
}
// namespace
InterpreterCore
::
InterpreterCore
(
const
platform
::
Place
&
place
,
const
ProgramDesc
&
main_prog
,
VariableScope
*
global_scope
,
...
...
@@ -123,8 +28,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
:
place_
(
place
),
main_program_
(
main_prog
),
global_scope_
(
global_scope
),
d2h_ctx_pool_
({
place
}),
h2d_ctx_pool_
({
place
})
{
stream_analyzer_
(
place
)
{
is_build_
=
false
;
feed_names_
=
feed_names
;
...
...
@@ -199,7 +103,7 @@ void InterpreterCore::Convert() {
Instruction
temp_inst
;
auto
*
op_base
=
op_list_
[
i
];
temp_inst
.
dev_ctx_
=
ParseDeviceContextForInstruction
(
vec_func_list_
[
i
],
*
op_base
);
stream_analyzer_
.
ParseDeviceContext
(
vec_func_list_
[
i
],
*
op_base
);
temp_inst
.
kernel_func_
.
compute_func_
=
vec_func_list_
[
i
].
kernel_func_
;
temp_inst
.
kernel_func_
.
operator_base_
=
op_base
;
temp_inst
.
input_index_
=
vec_func_list_
[
i
].
input_index
;
...
...
@@ -270,8 +174,8 @@ void InterpreterCore::Convert() {
}
}
ParseDirectAndEventRunOps
(
place_
,
vec_func_list_
,
filter_next
,
i
,
&
v
ar_id2event_
,
&
v
ec_instruction_
);
stream_analyzer_
.
Schedule
(
vec_func_list_
,
filter_next
,
i
,
&
vec_instruction_
);
for
(
auto
inst_id
:
filter_next
)
{
dependecy_count_
[
inst_id
]
++
;
...
...
@@ -361,7 +265,7 @@ void InterpreterCore::ExecuteInstructionList(
working_queue
.
pop
();
auto
&
instr_node
=
vec_instr
[
instr_id
];
// step1 : stream_wait (non-block host) or sync (block host)
StreamWaitEventOrSync
(
instr_node
);
event_manager_
.
WaitEvent
(
instr_node
,
place_
);
// step2: run instruction
RunInstruction
(
instr_node
);
++
run_op_number
;
...
...
@@ -371,7 +275,7 @@ void InterpreterCore::ExecuteInstructionList(
}
// step3: insert event for out_vars if needed
RecordEventInstruction
(
instr_node
,
vec_func_list_
[
instr_id
]
);
event_manager_
.
RecordEvent
(
instr_node
,
vec_func_list_
[
instr_id
],
place_
);
// step4: update working_queue
auto
&
next_instr
=
instr_node
.
next_instruction_
.
all_next_ops_
;
...
...
@@ -450,54 +354,5 @@ const CostInfo& InterpreterCore::DryRun(
return
dry_run_profiler_
.
GetCostInfo
();
}
platform
::
DeviceContext
*
InterpreterCore
::
ParseDeviceContextForInstruction
(
const
OpFuncNode
&
op_func_node
,
const
OperatorBase
&
op_base
)
{
auto
&
op_type
=
op_base
.
Type
();
auto
*
dev_ctx
=
op_func_node
.
dev_ctx_
;
if
(
op_type
==
interpretercore
::
kMemcpyH2D
)
{
VLOG
(
3
)
<<
"Get dev_ctx from d2h_context_pool_"
;
dev_ctx
=
d2h_ctx_pool_
.
Get
(
place_
);
}
else
if
(
op_type
==
interpretercore
::
kMemcpyD2H
)
{
VLOG
(
3
)
<<
"Get dev_ctx from h2d_context_pool_"
;
dev_ctx
=
h2d_ctx_pool_
.
Get
(
place_
);
}
return
dev_ctx
;
}
void
InterpreterCore
::
RecordEventInstruction
(
const
Instruction
&
instruction
,
const
OpFuncNode
&
op_func_node
)
{
// If InterpreterCore in on CPUPlace, do nothing.
if
(
platform
::
is_cpu_place
(
place_
))
return
;
for
(
auto
&
event
:
instruction
.
output_events_
)
{
VLOG
(
3
)
<<
"Record event in out_var_id: "
<<
event
.
var_id_
;
event
.
event_
->
Record
(
instruction
.
dev_ctx_
);
}
}
void
InterpreterCore
::
WaitOrSync
(
const
std
::
vector
<
EventInter
>&
events
,
const
platform
::
DeviceContext
*
dev_ctx
)
{
for
(
auto
&
event_iter
:
events
)
{
if
(
event_iter
.
is_sync_
)
{
VLOG
(
3
)
<<
"host sync wait in_var_id "
<<
event_iter
.
var_id_
;
event_iter
.
event_
->
Wait
(
platform
::
kCPU
,
dev_ctx
);
}
else
{
VLOG
(
3
)
<<
"stream async wait in_var_id "
<<
event_iter
.
var_id_
;
event_iter
.
event_
->
Wait
(
platform
::
kCUDA
,
dev_ctx
);
}
}
}
void
InterpreterCore
::
StreamWaitEventOrSync
(
const
Instruction
&
instruction
)
{
// If InterpreterCore in on CPUPlace, do nothing.
if
(
platform
::
is_cpu_place
(
place_
))
return
;
VLOG
(
3
)
<<
"Deal StreamWaitEventOrSync for "
<<
instruction
.
kernel_func_
.
operator_base_
->
Type
();
auto
*
dev_ctx
=
instruction
.
dev_ctx_
;
WaitOrSync
(
instruction
.
intput_events_
,
dev_ctx
);
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/new_executor/interpretercore.h
浏览文件 @
8528dd9f
...
...
@@ -19,10 +19,12 @@
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/new_executor/event_manager.h"
#include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h"
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/profiler.h"
#include "paddle/fluid/framework/new_executor/stream_analyzer.h"
#include "paddle/fluid/framework/new_executor/workqueue.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h"
...
...
@@ -64,17 +66,6 @@ class InterpreterCore {
const
VariableScope
&
var_scope
,
const
platform
::
Place
&
place
,
std
::
vector
<
VariableMetaInfo
>&
working_var_ref
);
// NOLINT
platform
::
DeviceContext
*
ParseDeviceContextForInstruction
(
const
OpFuncNode
&
op_func_node
,
const
OperatorBase
&
op_base
);
void
RecordEventInstruction
(
const
Instruction
&
instruction
,
const
OpFuncNode
&
op_func_node
);
void
WaitOrSync
(
const
std
::
vector
<
EventInter
>&
events
,
const
platform
::
DeviceContext
*
dev_ctx
);
void
StreamWaitEventOrSync
(
const
Instruction
&
instruction
);
void
AddFetch
(
const
std
::
vector
<
std
::
string
>&
fetch_names
);
bool
is_build_
;
...
...
@@ -83,9 +74,6 @@ class InterpreterCore {
ProgramDesc
main_program_
;
VariableScope
*
global_scope_
;
platform
::
DeviceContextPool
d2h_ctx_pool_
;
platform
::
DeviceContextPool
h2d_ctx_pool_
;
std
::
vector
<
Instruction
>
vec_instruction_
;
InstructionInfo
instruction_info_
;
std
::
vector
<
size_t
>
dependecy_count_
;
...
...
@@ -99,8 +87,8 @@ class InterpreterCore {
std
::
vector
<
std
::
string
>
feed_names_
;
InterpreterProfiler
dry_run_profiler_
;
std
::
map
<
size_t
,
std
::
shared_ptr
<
platform
::
DeviceEvent
>>
var_id2event
_
;
StreamAnalyzer
stream_analyzer_
;
EventManager
event_manager
_
;
InterpreterCoreGarbageCollector
gc_
;
std
::
vector
<
paddle
::
platform
::
DeviceEvent
>
gc_event_
;
...
...
paddle/fluid/framework/new_executor/interpretercore_util.h
浏览文件 @
8528dd9f
...
...
@@ -476,9 +476,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
namespace
interpretercore
{
static
constexpr
char
kMemcpyH2D
[]
=
"memcpy_h2d"
;
static
constexpr
char
kMemcpyD2H
[]
=
"memcpy_d2h"
;
std
::
string
get_memcpy_type
(
const
platform
::
Place
&
src_place
,
const
platform
::
Place
&
dst_place
);
...
...
paddle/fluid/framework/new_executor/new_executor_defs.h
浏览文件 @
8528dd9f
...
...
@@ -25,6 +25,11 @@
namespace
paddle
{
namespace
framework
{
namespace
interpretercore
{
static
constexpr
char
kMemcpyH2D
[]
=
"memcpy_h2d"
;
static
constexpr
char
kMemcpyD2H
[]
=
"memcpy_d2h"
;
}
// namespace interpretercore
using
OpKernelComputeFunc
=
std
::
function
<
void
(
const
ExecutionContext
&
)
>
;
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelType
,
OpKernelComputeFunc
,
OpKernelType
::
Hash
>
;
...
...
paddle/fluid/framework/new_executor/stream_analyzer.cc
0 → 100644
浏览文件 @
8528dd9f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/stream_analyzer.h"
#include <unordered_set>
namespace
paddle
{
namespace
framework
{
/*
* Parse the var_ids that need to be associated with an event.
* The caller should guarantee front_op and back_op satisfy the
* following conditions:
* 1. kQueueAsync -> kQueueAsync
* 2. kQueueAsync -> kQueueSync
*
* For example: matmul(gpu) -> out_var -> memcpy_d2h
* out_var should be associated with an event.
*/
std
::
vector
<
size_t
>
StreamAnalyzer
::
ParseEventVarIds
(
const
Instruction
&
cur_instr
,
const
Instruction
&
next_instr
)
{
std
::
unordered_set
<
size_t
>
unique_var_ids
;
for
(
auto
&
item
:
cur_instr
.
output_index_
)
{
unique_var_ids
.
insert
(
item
.
second
.
begin
(),
item
.
second
.
end
());
}
std
::
vector
<
size_t
>
new_event_var_ids
;
for
(
auto
&
item
:
next_instr
.
input_index_
)
{
for
(
auto
var_id
:
item
.
second
)
{
if
(
unique_var_ids
.
count
(
var_id
)
>
0
)
{
new_event_var_ids
.
push_back
(
var_id
);
}
}
}
return
new_event_var_ids
;
}
void
StreamAnalyzer
::
AssociateInputWithEvents
(
const
std
::
vector
<
size_t
>&
new_event_var_id
,
Instruction
*
next_instr
,
bool
is_sync
)
{
for
(
auto
var_id
:
new_event_var_id
)
{
if
(
var_id2event_
.
count
(
var_id
)
==
0
)
{
auto
device_event
=
std
::
make_shared
<
platform
::
DeviceEvent
>
(
place_
,
platform
::
GenerateDeviceEventFlag
());
var_id2event_
.
emplace
(
var_id
,
std
::
move
(
device_event
));
}
// Add events for next_instr.inputs
next_instr
->
intput_events_
.
emplace_back
(
var_id
,
var_id2event_
.
at
(
var_id
),
is_sync
);
}
}
void
StreamAnalyzer
::
Schedule
(
const
std
::
vector
<
OpFuncNode
>&
op_func_nodes
,
const
std
::
vector
<
size_t
>&
downstream_ops
,
size_t
op_index
,
std
::
vector
<
Instruction
>*
instructions
)
{
auto
&
op_func_type
=
op_func_nodes
[
op_index
].
type_
;
auto
&
cur_instr
=
instructions
->
at
(
op_index
);
auto
&
next_instruction
=
cur_instr
.
next_instruction_
;
if
(
op_func_type
==
OpFuncType
::
kQueueSync
)
{
// all downstream ops of kQueueSync can directly run, such as CPU -> Any
next_instruction
.
direct_run_
=
downstream_ops
;
}
else
{
// kQueueAsync
std
::
vector
<
size_t
>
event_var_ids
;
for
(
auto
next_op_id
:
downstream_ops
)
{
auto
&
next_instr
=
instructions
->
at
(
next_op_id
);
// case 1: GPU -> GPU(same stream)
if
(
cur_instr
.
dev_ctx_
==
next_instr
.
dev_ctx_
)
{
next_instruction
.
direct_run_
.
emplace_back
(
next_op_id
);
continue
;
}
// Always insert events between different stream
auto
new_event_var_ids
=
ParseEventVarIds
(
cur_instr
,
next_instr
);
event_var_ids
.
insert
(
event_var_ids
.
end
(),
new_event_var_ids
.
begin
(),
new_event_var_ids
.
end
());
bool
is_sync
=
(
op_func_nodes
[
next_op_id
].
type_
==
OpFuncType
::
kQueueSync
);
AssociateInputWithEvents
(
new_event_var_ids
,
&
next_instr
,
is_sync
);
if
(
is_sync
)
{
// GPU -> CPU
next_instruction
.
synchronize_run_
.
emplace_back
(
next_op_id
);
}
else
{
// GPU -> GPU(different stream)
next_instruction
.
event_wait_run_
.
emplace_back
(
next_op_id
);
}
}
// Create events for these cross-stream vars
VLOG
(
3
)
<<
cur_instr
.
kernel_func_
.
operator_base_
->
Type
()
<<
" event_var_ids.size: "
<<
event_var_ids
.
size
();
for
(
auto
var_id
:
event_var_ids
)
{
cur_instr
.
output_events_
.
emplace_back
(
var_id
,
var_id2event_
.
at
(
var_id
),
false
/*not used*/
);
}
}
}
platform
::
DeviceContext
*
StreamAnalyzer
::
ParseDeviceContext
(
const
OpFuncNode
&
op_func_node
,
const
OperatorBase
&
op_base
)
{
auto
&
op_type
=
op_base
.
Type
();
auto
*
dev_ctx
=
op_func_node
.
dev_ctx_
;
if
(
op_type
==
interpretercore
::
kMemcpyH2D
)
{
VLOG
(
3
)
<<
"Get dev_ctx from d2h_context_pool_"
;
dev_ctx
=
d2h_ctx_pool_
.
Get
(
place_
);
}
else
if
(
op_type
==
interpretercore
::
kMemcpyD2H
)
{
VLOG
(
3
)
<<
"Get dev_ctx from h2d_context_pool_"
;
dev_ctx
=
h2d_ctx_pool_
.
Get
(
place_
);
}
return
dev_ctx
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/new_executor/stream_analyzer.h
0 → 100644
浏览文件 @
8528dd9f
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_event.h"
namespace
paddle
{
namespace
framework
{
class
StreamAnalyzer
{
public:
explicit
StreamAnalyzer
(
const
platform
::
Place
&
place
)
:
place_
(
place
),
d2h_ctx_pool_
({
place
}),
h2d_ctx_pool_
({
place
})
{}
~
StreamAnalyzer
()
{}
void
Schedule
(
const
std
::
vector
<
OpFuncNode
>&
op_func_nodes
,
const
std
::
vector
<
size_t
>&
downstream_ops
,
size_t
op_index
,
std
::
vector
<
Instruction
>*
instructions
);
platform
::
DeviceContext
*
ParseDeviceContext
(
const
OpFuncNode
&
op_func_node
,
const
OperatorBase
&
op_base
);
private:
std
::
vector
<
size_t
>
ParseEventVarIds
(
const
Instruction
&
cur_instr
,
const
Instruction
&
next_instr
);
void
AssociateInputWithEvents
(
const
std
::
vector
<
size_t
>&
new_event_var_id
,
Instruction
*
next_instr
,
bool
is_sync
);
platform
::
Place
place_
;
platform
::
DeviceContextPool
d2h_ctx_pool_
;
platform
::
DeviceContextPool
h2d_ctx_pool_
;
std
::
map
<
size_t
,
std
::
shared_ptr
<
platform
::
DeviceEvent
>>
var_id2event_
;
};
}
// namespace framework
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录