Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8fe09faf
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8fe09faf
编写于
2月 22, 2021
作者:
Q
Qi Li
提交者:
GitHub
2月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] update fluid framework for rocm (part1), test=develop (#31009)
上级
33429630
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
83 addition
and
64 deletion
+83
-64
paddle/fluid/framework/details/all_reduce_op_handle.cc
paddle/fluid/framework/details/all_reduce_op_handle.cc
+10
-5
paddle/fluid/framework/details/all_reduce_op_handle.h
paddle/fluid/framework/details/all_reduce_op_handle.h
+5
-4
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+8
-6
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+2
-2
paddle/fluid/framework/details/gather_op_handle_test.cc
paddle/fluid/framework/details/gather_op_handle_test.cc
+2
-2
paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.cc
...luid/framework/details/grad_merge_all_reduce_op_handle.cc
+3
-3
paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h
...fluid/framework/details/grad_merge_all_reduce_op_handle.h
+3
-3
paddle/fluid/framework/details/var_handle.h
paddle/fluid/framework/details/var_handle.h
+5
-5
paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h
...mework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h
+3
-3
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+4
-4
paddle/fluid/framework/ir/fuse_bn_act_pass.cc
paddle/fluid/framework/ir/fuse_bn_act_pass.cc
+5
-2
paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc
paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc
+5
-2
paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
+1
-1
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
.../fluid/framework/ir/fusion_group/code_generator_tester.cc
+5
-1
paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc
...optimize_pass/test_reference_count_pass_last_lived_ops.cc
+1
-1
paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc
...ework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc
+1
-1
paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
...rk/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
+8
-8
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
...k/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
+10
-9
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
...rk/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
+2
-2
未找到文件。
paddle/fluid/framework/details/all_reduce_op_handle.cc
浏览文件 @
8fe09faf
...
...
@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/platform/profiler.h"
#if
def PADDLE_WITH_NCCL
#if
defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
DECLARE_bool
(
sync_nccl_allreduce
);
#endif
...
...
@@ -25,7 +25,7 @@ namespace paddle {
namespace
framework
{
namespace
details
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
AllReduceOpHandle
::
AllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
...
...
@@ -182,7 +182,7 @@ void AllReduceOpHandle::AllReduceFunc(
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
std
::
string
>
&
out_var_names
)
{
if
(
is_gpu_place
(
places
[
0
]))
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
PADDLE_ENFORCE_NOT_NULL
(
nccl_ctxs_
,
platform
::
errors
::
InvalidArgument
(
"The nccl context should not be NULL."
));
...
...
@@ -198,7 +198,7 @@ void AllReduceOpHandle::AllReduceFunc(
NCCLAllReduceFunc
(
all_reduce_calls
);
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Not compiled with
CUDA
."
));
platform
::
errors
::
PreconditionNotMet
(
"Not compiled with
GPU
."
));
#endif
}
else
if
(
is_xpu_place
(
places
[
0
]))
{
#if defined(PADDLE_WITH_XPU_BKCL)
...
...
@@ -265,7 +265,7 @@ void AllReduceOpHandle::BKCLAllReduceFunc(
}
#endif
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
void
AllReduceOpHandle
::
NCCLAllReduceFunc
(
const
std
::
vector
<
std
::
function
<
void
()
>>
&
all_reduce_calls
)
{
this
->
RunAndRecordEvent
([
&
]
{
...
...
@@ -291,8 +291,13 @@ void AllReduceOpHandle::SyncNCCLAllReduce() {
nccl_ctxs_
->
GetRunEnvNCCLCtx
(
run_order_
,
use_hierarchical_allreduce_
);
auto
&
nccl_ctx
=
nccl_ctxs
->
at
(
dev_id
);
auto
stream
=
nccl_ctx
.
stream
();
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipGetLastError
());
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaGetLastError
());
#endif
}
}
}
...
...
paddle/fluid/framework/details/all_reduce_op_handle.h
浏览文件 @
8fe09faf
...
...
@@ -31,7 +31,7 @@ namespace platform {
class
NCCLCommunicator
;
}
// namespace platform
}
// namespace paddle
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/nccl_helper.h"
#elif defined(PADDLE_WITH_XPU_BKCL)
...
...
@@ -43,7 +43,7 @@ namespace paddle {
namespace
framework
{
namespace
details
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
class
AllReduceOpHandle
:
public
NCCLOpHandleBase
{
public:
AllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
...
...
@@ -74,13 +74,14 @@ class AllReduceOpHandle : public OpHandleBase {
std
::
vector
<
Scope
*>
local_scopes_
;
#if !(PADDLE_WITH_NCCL || PADDLE_WITH_XPU_BKCL)
#if !defined(PADDLE_WITH_NCCL) && !defined(PADDLE_WITH_RCCL) && \
!defined(PADDLE_WITH_XPU_BKCL)
// NCCLOpHandleBase and BKCLOpHandleBase already have these attributes.
// Will polish it by class inheritance framework.
std
::
vector
<
platform
::
Place
>
places_
;
#endif
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
void
NCCLAllReduceFunc
(
const
std
::
vector
<
std
::
function
<
void
()
>>
&
all_reduce_calls
);
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
8fe09faf
...
...
@@ -158,7 +158,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"fuse_relu_depthwise_conv_pass"
);
AppendPassWithCheck
(
strategy_
.
fuse_bn_act_ops_
,
"fuse_bn_act_pass"
);
AppendPassWithCheck
(
strategy_
.
fuse_bn_add_act_ops_
,
"fuse_bn_add_act_pass"
);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__)
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
!defined(_WIN32) && !defined(__APPLE__)
AppendPassWithCheck
(
strategy_
.
enable_auto_fusion_
,
"fusion_group_pass"
);
#else
LOG
(
WARNING
)
<<
"fusion_group is not enabled for Windows/MacOS now, and "
...
...
@@ -305,7 +306,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
const
std
::
string
&
loss_var_name
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
size_t
&
nranks
,
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
DeviceType
use_device
,
platform
::
NCCLCommunicator
*
nccl_ctxs
)
const
{
#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
...
...
@@ -331,7 +332,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass
->
Erase
(
kNRanks
);
pass
->
Set
<
size_t
>
(
kNRanks
,
new
size_t
(
nranks
));
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
platform
::
NCCLCommunicator
*
nctx
=
(
use_device
==
p
::
kCUDA
)
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
kNCCLCtxs
);
...
...
@@ -351,7 +352,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
pass
->
Erase
(
kLocalScopes
);
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
,
&
local_scopes
);
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
platform
::
NCCLCommunicator
*
nctx
=
(
use_device
==
p
::
kCUDA
)
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
kNCCLCtxs
);
...
...
@@ -378,7 +379,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
LOG
(
INFO
)
<<
"set enable_sequential_execution:"
<<
enable_sequential_execution_
;
}
else
if
(
pass
->
Type
()
==
"all_reduce_deps_pass"
)
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
platform
::
NCCLCommunicator
*
nctx
=
(
use_device
==
p
::
kCUDA
)
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
kNCCLCtxs
);
...
...
@@ -474,6 +475,7 @@ USE_PASS(add_reader_dependency_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS
(
mkldnn_placement_pass
);
#endif
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__)
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
!defined(_WIN32) && !defined(__APPLE__)
USE_PASS
(
fusion_group_pass
);
#endif
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
8fe09faf
...
...
@@ -39,7 +39,7 @@ class NCCLCommunicator;
}
// namespace platform
}
// namespace paddle
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/nccl_helper.h"
#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/bkcl_helper.h"
...
...
@@ -185,7 +185,7 @@ struct BuildStrategy {
const
std
::
string
&
loss_var_name
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
size_t
&
nranks
,
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
DeviceType
use_device
,
platform
::
NCCLCommunicator
*
nccl_ctxs
)
const
;
#elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
...
...
paddle/fluid/framework/details/gather_op_handle_test.cc
浏览文件 @
8fe09faf
...
...
@@ -47,7 +47,7 @@ struct TestGatherOpHandle {
void
InitCtxOnGpu
(
bool
use_gpu
)
{
if
(
use_gpu
)
{
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
int
count
=
p
::
GetCUDADeviceCount
();
if
(
count
<=
1
)
{
LOG
(
WARNING
)
<<
"Cannot test multi-gpu Broadcast, because the CUDA "
...
...
@@ -214,7 +214,7 @@ TEST(GatherTester, TestCPUGatherTestSelectedRows) {
test_op
.
TestGatherSelectedRows
(
input_scope_idx
);
}
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST
(
GatherTester
,
TestGPUGatherTestSelectedRows
)
{
TestGatherOpHandle
test_op
;
...
...
paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.cc
浏览文件 @
8fe09faf
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h"
#if
def PADDLE_WITH_NCCL
#if
defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
DECLARE_bool
(
sync_nccl_allreduce
);
#endif
...
...
@@ -21,7 +21,7 @@ namespace paddle {
namespace
framework
{
namespace
details
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
GradMergeAllReduceOpHandle
::
GradMergeAllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
...
...
@@ -68,7 +68,7 @@ std::string GradMergeAllReduceOpHandle::Name() const {
return
"grad_merge_all_reduce"
;
}
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
FusedGradMergeAllReduceOpHandle
::
FusedGradMergeAllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
size_t
num_of_all_reduce
,
...
...
paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h
浏览文件 @
8fe09faf
...
...
@@ -33,7 +33,7 @@ namespace platform {
class
NCCLCommunicator
;
}
// namespace platform
}
// namespace paddle
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
...
...
@@ -44,7 +44,7 @@ namespace details {
class
GradMergeAllReduceOpHandle
:
public
AllReduceOpHandle
{
public:
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
GradMergeAllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
...
...
@@ -75,7 +75,7 @@ class GradMergeAllReduceOpHandle : public AllReduceOpHandle {
class
FusedGradMergeAllReduceOpHandle
:
public
FusedAllReduceOpHandle
{
public:
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
FusedGradMergeAllReduceOpHandle
(
ir
::
Node
*
node
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
...
...
paddle/fluid/framework/details/var_handle.h
浏览文件 @
8fe09faf
...
...
@@ -126,10 +126,10 @@ struct VarHandle : public VarHandleBase {
name_
(
std
::
move
(
name
)),
place_
(
std
::
move
(
place
))
{}
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
bool
HasEvent
()
{
return
has_event_
;
}
const
cuda
Event_t
&
GetEvent
()
{
const
gpu
Event_t
&
GetEvent
()
{
PADDLE_ENFORCE_EQ
(
HasEvent
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
...
...
@@ -137,7 +137,7 @@ struct VarHandle : public VarHandleBase {
return
event_
;
}
void
SetGenerateEvent
(
const
cuda
Event_t
&
event
)
{
void
SetGenerateEvent
(
const
gpu
Event_t
&
event
)
{
has_event_
=
true
;
event_
=
event
;
}
...
...
@@ -150,9 +150,9 @@ struct VarHandle : public VarHandleBase {
size_t
scope_idx_
;
std
::
string
name_
;
platform
::
Place
place_
;
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Only when this event is triggered, var is generated.
cuda
Event_t
event_
;
gpu
Event_t
event_
;
bool
has_event_
{
false
};
#endif
...
...
paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h
浏览文件 @
8fe09faf
...
...
@@ -737,7 +737,7 @@ x.second );
}
int
assign_async
(
const
concurrent_unordered_map
&
other
,
cuda
Stream_t
stream
=
0
)
{
gpu
Stream_t
stream
=
0
)
{
m_collisions
=
other
.
m_collisions
;
if
(
other
.
m_hashtbl_size
<=
m_hashtbl_capacity
)
{
m_hashtbl_size
=
other
.
m_hashtbl_size
;
...
...
@@ -754,7 +754,7 @@ x.second );
return
0
;
}
void
clear_async
(
cuda
Stream_t
stream
=
0
)
{
void
clear_async
(
gpu
Stream_t
stream
=
0
)
{
constexpr
int
block_size
=
128
;
init_hashtbl
<<<
((
m_hashtbl_size
-
1
)
/
block_size
)
+
1
,
block_size
,
0
,
stream
>>>
(
m_hashtbl_values
,
m_hashtbl_size
,
unused_key
,
...
...
@@ -771,7 +771,7 @@ x.second );
}
}
int
prefetch
(
const
int
dev_id
,
cuda
Stream_t
stream
=
0
)
{
int
prefetch
(
const
int
dev_id
,
gpu
Stream_t
stream
=
0
)
{
cudaPointerAttributes
hashtbl_values_ptr_attributes
;
cudaError_t
status
=
cudaPointerGetAttributes
(
&
hashtbl_values_ptr_attributes
,
m_hashtbl_values
);
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
8fe09faf
...
...
@@ -9,7 +9,7 @@ copy_if_different(${pass_file} ${pass_file_final})
add_subdirectory
(
fuse_optimizer_ops_pass
)
add_subdirectory
(
memory_optimize_pass
)
add_subdirectory
(
multi_devices_graph_pass
)
if
(
NOT APPLE AND NOT WIN32 AND
WITH_GPU
)
if
(
NOT APPLE AND NOT WIN32
AND
(
WITH_GPU OR WITH_ROCM
)
)
add_subdirectory
(
fusion_group
)
endif
()
...
...
@@ -93,7 +93,7 @@ pass_library(multihead_matmul_fuse_pass inference)
pass_library
(
adaptive_pool2d_convert_global_pass inference
)
pass_library
(
unsqueeze2_eltwise_fuse_pass inference
)
pass_library
(
layer_norm_fuse_pass inference
)
if
(
WITH_GPU
)
if
(
WITH_GPU
OR WITH_ROCM
)
pass_library
(
cudnn_placement_pass base DEPS placement_pass_base
)
pass_library
(
embedding_eltwise_layernorm_fuse_pass inference
)
endif
()
...
...
@@ -153,7 +153,7 @@ cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_
cc_test
(
test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass
)
cc_test
(
test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass
)
cc_test
(
test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor
)
if
(
WITH_GPU
)
if
(
WITH_GPU
OR WITH_ROCM
)
cc_test
(
test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass
)
cc_test
(
test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass
)
endif
()
...
...
@@ -169,7 +169,7 @@ if (WITH_MKLDNN)
cc_test
(
test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util
)
cc_test
(
test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util
)
set
(
TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context
)
if
(
WITH_GPU
)
if
(
WITH_GPU
OR WITH_ROCM
)
set
(
TEST_CONV_BN_PASS_DEPS
${
TEST_CONV_BN_PASS_DEPS
}
depthwise_conv
)
endif
()
cc_test
(
test_conv_batch_norm_mkldnn_fuse_pass SRCS mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc DEPS
${
TEST_CONV_BN_PASS_DEPS
}
)
...
...
paddle/fluid/framework/ir/fuse_bn_act_pass.cc
浏览文件 @
8fe09faf
...
...
@@ -27,14 +27,17 @@ class Node;
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
FuseBatchNormActPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
#if
def PADDLE_WITH_CUDA
#if CUDNN_VERSION_MIN(7, 4, 1)
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if
defined(PADDLE_WITH_HIP) ||
CUDNN_VERSION_MIN(7, 4, 1)
// forward
std
::
unordered_set
<
std
::
string
>
act_types
=
{
"relu"
};
graph
=
FuseBatchNormAct
(
graph
,
act_types
);
...
...
paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc
浏览文件 @
8fe09faf
...
...
@@ -19,14 +19,17 @@
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
FuseBatchNormAddActPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
#if
def PADDLE_WITH_CUDA
#if CUDNN_VERSION_MIN(7, 4, 1)
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if
defined(PADDLE_WITH_HIP) ||
CUDNN_VERSION_MIN(7, 4, 1)
// forward
std
::
unordered_set
<
std
::
string
>
act_types
=
{
"relu"
};
graph
=
FuseBatchNormAddAct
(
graph
,
act_types
);
...
...
paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
浏览文件 @
8fe09faf
cc_library
(
code_generator
SRCS operation.cc code_generator.cc code_generator_helper.cc
DEPS graph subgraph_detector
)
if
(
WITH_GPU
)
if
(
WITH_GPU
OR WITH_ROCM
)
cc_test
(
test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass
)
endif
()
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc
浏览文件 @
8fe09faf
...
...
@@ -28,7 +28,7 @@ class LoDTensor;
}
// namespace framework
}
// namespace paddle
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
namespace
paddle
{
namespace
framework
{
...
...
@@ -180,7 +180,11 @@ void TestMainImpl(std::string func_name, std::string code_str,
paddle
::
platform
::
CUDAPlace
place
=
paddle
::
platform
::
CUDAPlace
(
0
);
paddle
::
platform
::
CUDADeviceCode
device_code
(
place
,
func_name
,
code_str
);
#ifdef PADDLE_WITH_HIP
device_code
.
Compile
(
true
);
#else
device_code
.
Compile
(
is_float16
);
#endif
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
gpu_tensors
(
cpu_tensors
.
size
());
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
tmp_cpu_tensors
(
cpu_tensors
.
size
());
...
...
paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc
浏览文件 @
8fe09faf
...
...
@@ -180,7 +180,7 @@ TEST(test_reference_count_pass, test_no_need_buffer_var_shrink) {
{{
"Out"
,
{
x7
}}},
{});
std
::
vector
<
bool
>
use_cuda_list
{
false
};
#if
def PADDLE_WITH_CUDA
#if
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
use_cuda_list
.
push_back
(
true
);
#endif
for
(
auto
use_cuda
:
use_cuda_list
)
{
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc
浏览文件 @
8fe09faf
...
...
@@ -30,7 +30,7 @@ class AllReduceDepsPass : public ir::Pass {
std
::
vector
<
details
::
OpHandleBase
*>
all_reduce_op_handles
=
GetSortedAllReduceOps
(
*
graph
);
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
use_hierarchical_allreduce
=
Get
<
bool
>
(
details
::
kUseHierarchicalAllReduce
);
for
(
size_t
i
=
0
;
i
<
all_reduce_op_handles
.
size
();
++
i
)
{
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
浏览文件 @
8fe09faf
...
...
@@ -36,7 +36,7 @@ class FuseAllReduceOpPass : public ir::Pass {
auto
&
places
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
details
::
kPlaces
);
auto
&
local_scopes
=
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
);
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
*
multi_nccl_ctxs
=
&
Get
<
platform
::
NCCLCommunicator
>
(
details
::
kNCCLCtxs
);
#elif defined(PADDLE_WITH_XPU_BKCL)
...
...
@@ -90,7 +90,7 @@ class FuseAllReduceOpPass : public ir::Pass {
for
(
auto
&
p_g
:
group_p_g
)
{
group_all_reduce_ops
.
emplace_back
(
all_reduce_ops
.
at
(
p_g
.
second
));
}
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
InsertFusedAllReduce
(
places
,
local_scopes
,
group_size
,
group_all_reduce_ops
,
multi_nccl_ctxs
,
&
result
);
#elif defined(PADDLE_WITH_XPU_BKCL)
...
...
@@ -156,7 +156,7 @@ class FuseAllReduceOpPass : public ir::Pass {
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
size_t
num_of_all_reduce
,
const
std
::
vector
<
ir
::
Node
*>
&
all_reduce_ops
,
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
const
platform
::
NCCLCommunicator
*
multi_nccl_ctxs
,
#elif defined(PADDLE_WITH_XPU_BKCL)
const
platform
::
BKCLCommunicator
*
multi_bkcl_ctxs
,
...
...
@@ -217,7 +217,7 @@ class FuseAllReduceOpPass : public ir::Pass {
result
->
RemoveNode
(
op_handle
.
Node
());
}
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
CreateFusedAllReduceOp
(
inputs
,
outputs
,
num_of_all_reduce
,
places
,
local_scopes
,
is_grad_merge
,
grad_merge_cond_name
,
multi_nccl_ctxs
,
result
);
...
...
@@ -240,7 +240,7 @@ class FuseAllReduceOpPass : public ir::Pass {
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
bool
is_grad_merge
,
const
std
::
string
&
grad_merge_cond_name
,
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
const
platform
::
NCCLCommunicator
*
multi_nccl_ctxs
,
#elif defined(PADDLE_WITH_XPU_BKCL)
const
platform
::
BKCLCommunicator
*
multi_bkcl_ctxs
,
...
...
@@ -248,7 +248,7 @@ class FuseAllReduceOpPass : public ir::Pass {
ir
::
Graph
*
result
)
const
{
details
::
FusedAllReduceOpHandle
*
op_handle
=
NULL
;
if
(
is_grad_merge
)
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
op_handle
=
new
details
::
FusedGradMergeAllReduceOpHandle
(
result
->
CreateEmptyNode
(
"fused_all_reduce"
,
ir
::
Node
::
Type
::
kOperation
),
...
...
@@ -267,7 +267,7 @@ class FuseAllReduceOpPass : public ir::Pass {
local_scopes
,
places
,
num_of_all_reduce
,
grad_merge_cond_name
);
#endif
}
else
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
op_handle
=
new
details
::
FusedAllReduceOpHandle
(
result
->
CreateEmptyNode
(
"fused_all_reduce"
,
ir
::
Node
::
Type
::
kOperation
),
...
...
@@ -293,7 +293,7 @@ class FuseAllReduceOpPass : public ir::Pass {
op_handle
->
AddOutput
(
out
);
}
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
if
(
!
multi_nccl_ctxs
)
{
SetCommunicationContext
(
places
,
op_handle
);
}
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
浏览文件 @
8fe09faf
...
...
@@ -157,7 +157,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
details
::
kPlaces
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
);
strategy_
=
Get
<
const
details
::
BuildStrategy
>
(
kStrategy
);
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
multi_nccl_ctxs_
=
&
Get
<
platform
::
NCCLCommunicator
>
(
details
::
kNCCLCtxs
);
nccl_ctxs_
=
nullptr
;
if
(
multi_nccl_ctxs_
)
{
...
...
@@ -323,7 +323,7 @@ std::vector<ir::Node *> MultiDevSSAGraphBuilderBase::SortOperations(
bool
MultiDevSSAGraphBuilderBase
::
UseGPU
()
const
{
bool
use_gpu
=
false
;
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
use_gpu
=
nccl_ctxs_
!=
nullptr
;
#endif
return
use_gpu
;
...
...
@@ -373,7 +373,7 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
void
MultiDevSSAGraphBuilderBase
::
SetCommunicationContext
(
details
::
OpHandleBase
*
op_handle
,
const
platform
::
Place
&
p
)
const
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
if
(
nccl_ctxs_
==
nullptr
)
{
op_handle
->
SetDeviceContext
(
p
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
));
...
...
@@ -392,7 +392,7 @@ void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
void
MultiDevSSAGraphBuilderBase
::
CreateBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
*
op_handle
=
new
details
::
BroadcastOpHandle
(
result
->
CreateEmptyNode
(
"broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
);
...
...
@@ -429,7 +429,7 @@ void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
void
MultiDevSSAGraphBuilderBase
::
CreateFusedBroadcastOp
(
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
bcast_varnames
)
const
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
auto
*
op_handle
=
new
details
::
FusedBroadcastOpHandle
(
result
->
CreateEmptyNode
(
"fused_broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
);
...
...
@@ -499,7 +499,8 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
const
std
::
vector
<
Scope
*>
&
scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
)
->
details
::
OpHandleBase
*
{
if
(
is_encoded
)
{
#if defined(PADDLE_WITH_DGC) && defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_DGC) && \
(defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
details
::
SparseAllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
...
...
@@ -515,7 +516,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
grad_merge_cond_name
=
BOOST_GET_CONST
(
std
::
string
,
node
->
Op
()
->
GetAttr
(
GRAD_MERGE_COND_NAME
));
VLOG
(
10
)
<<
"og="
<<
og
<<
" use grad_merge_allreduce"
;
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
details
::
GradMergeAllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
...
...
@@ -532,7 +533,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
scopes
,
places
,
grad_merge_cond_name
));
#endif
}
else
{
#if
def PADDLE_WITH_NCCL
#if
defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
details
::
AllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
...
...
@@ -648,7 +649,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
details
::
VarHandle
*
MultiDevSSAGraphBuilderBase
::
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
size_t
dst_dev_id
)
const
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
details
::
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
浏览文件 @
8fe09faf
...
...
@@ -39,7 +39,7 @@ class Graph;
namespace
paddle
{
namespace
platform
{
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
class
NCCLCommunicator
;
class
NCCLContextMap
;
#elif defined(PADDLE_WITH_XPU_BKCL)
...
...
@@ -117,7 +117,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
void
CreateIsolatedVarNode
(
ir
::
Graph
*
result
,
ir
::
Node
*
var_node
)
const
;
#if defined(PADDLE_WITH_NCCL)
#if defined(PADDLE_WITH_NCCL)
|| defined(PADDLE_WITH_RCCL)
mutable
platform
::
NCCLContextMap
*
nccl_ctxs_
{
nullptr
};
mutable
platform
::
NCCLCommunicator
*
multi_nccl_ctxs_
{
nullptr
};
#elif defined(PADDLE_WITH_XPU_BKCL)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录