From 8fe09faf14ee43dac6e7fc2a13620210319a4c59 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Mon, 22 Feb 2021 20:04:05 +0800 Subject: [PATCH] [ROCM] update fluid framework for rocm (part1), test=develop (#31009) --- .../framework/details/all_reduce_op_handle.cc | 15 ++++++++++----- .../framework/details/all_reduce_op_handle.h | 9 +++++---- .../fluid/framework/details/build_strategy.cc | 14 ++++++++------ .../fluid/framework/details/build_strategy.h | 4 ++-- .../details/gather_op_handle_test.cc | 4 ++-- .../grad_merge_all_reduce_op_handle.cc | 6 +++--- .../details/grad_merge_all_reduce_op_handle.h | 6 +++--- paddle/fluid/framework/details/var_handle.h | 10 +++++----- .../cudf/concurrent_unordered_map.cuh.h | 6 +++--- paddle/fluid/framework/ir/CMakeLists.txt | 8 ++++---- paddle/fluid/framework/ir/fuse_bn_act_pass.cc | 7 +++++-- .../framework/ir/fuse_bn_add_act_pass.cc | 7 +++++-- .../framework/ir/fusion_group/CMakeLists.txt | 2 +- .../ir/fusion_group/code_generator_tester.cc | 6 +++++- ...est_reference_count_pass_last_lived_ops.cc | 2 +- .../all_reduce_deps_pass.cc | 2 +- .../fuse_all_reduce_op_pass.cc | 16 ++++++++-------- .../multi_devices_graph_pass.cc | 19 ++++++++++--------- .../multi_devices_graph_pass.h | 4 ++-- 19 files changed, 83 insertions(+), 64 deletions(-) diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index 42797975f80..3429677a240 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/details/reduce_and_gather.h" #include "paddle/fluid/platform/profiler.h" -#ifdef PADDLE_WITH_NCCL +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) DECLARE_bool(sync_nccl_allreduce); #endif @@ -25,7 +25,7 @@ namespace paddle { namespace framework { namespace details { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, @@ -182,7 +182,7 @@ void AllReduceOpHandle::AllReduceFunc( const std::vector &places, const std::vector &out_var_names) { if (is_gpu_place(places[0])) { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) PADDLE_ENFORCE_NOT_NULL(nccl_ctxs_, platform::errors::InvalidArgument( "The nccl context should not be NULL.")); @@ -198,7 +198,7 @@ void AllReduceOpHandle::AllReduceFunc( NCCLAllReduceFunc(all_reduce_calls); #else PADDLE_THROW( - platform::errors::PreconditionNotMet("Not compiled with CUDA.")); + platform::errors::PreconditionNotMet("Not compiled with GPU.")); #endif } else if (is_xpu_place(places[0])) { #if defined(PADDLE_WITH_XPU_BKCL) @@ -265,7 +265,7 @@ void AllReduceOpHandle::BKCLAllReduceFunc( } #endif -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) void AllReduceOpHandle::NCCLAllReduceFunc( const std::vector> &all_reduce_calls) { this->RunAndRecordEvent([&] { @@ -291,8 +291,13 @@ void AllReduceOpHandle::SyncNCCLAllReduce() { nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, use_hierarchical_allreduce_); auto &nccl_ctx = nccl_ctxs->at(dev_id); auto stream = nccl_ctx.stream(); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError()); +#else PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError()); +#endif } } } diff --git 
a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h index fa260dea09e..39b923be9df 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/all_reduce_op_handle.h @@ -31,7 +31,7 @@ namespace platform { class NCCLCommunicator; } // namespace platform } // namespace paddle -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/framework/details/nccl_op_handle.h" #include "paddle/fluid/platform/nccl_helper.h" #elif defined(PADDLE_WITH_XPU_BKCL) @@ -43,7 +43,7 @@ namespace paddle { namespace framework { namespace details { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) class AllReduceOpHandle : public NCCLOpHandleBase { public: AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, @@ -74,13 +74,14 @@ class AllReduceOpHandle : public OpHandleBase { std::vector local_scopes_; -#if !(PADDLE_WITH_NCCL || PADDLE_WITH_XPU_BKCL) +#if !defined(PADDLE_WITH_NCCL) && !defined(PADDLE_WITH_RCCL) && \ + !defined(PADDLE_WITH_XPU_BKCL) // NCCLOpHandleBase and BKCLOpHandleBase already have these attributes. // Will polish it by class inheritance framework. std::vector places_; #endif -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) void NCCLAllReduceFunc( const std::vector> &all_reduce_calls); diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 4ee11f55a67..34c87b83889 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -158,7 +158,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { "fuse_relu_depthwise_conv_pass"); AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass"); AppendPassWithCheck(strategy_.fuse_bn_add_act_ops_, "fuse_bn_add_act_pass"); -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) +#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \ + !defined(_WIN32) && !defined(__APPLE__) AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass"); #else LOG(WARNING) << "fusion_group is not enabled for Windows/MacOS now, and " @@ -305,7 +306,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, const std::string &loss_var_name, const std::vector &local_scopes, const size_t &nranks, -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) DeviceType use_device, platform::NCCLCommunicator *nccl_ctxs) const { #elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL) @@ -331,7 +332,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, pass->Erase(kNRanks); pass->Set(kNRanks, new size_t(nranks)); -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) platform::NCCLCommunicator *nctx = (use_device == p::kCUDA) ? nccl_ctxs : nullptr; pass->Erase(kNCCLCtxs); @@ -351,7 +352,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, pass->Erase(kLocalScopes); pass->SetNotOwned>(kLocalScopes, &local_scopes); -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) platform::NCCLCommunicator *nctx = (use_device == p::kCUDA) ? 
nccl_ctxs : nullptr; pass->Erase(kNCCLCtxs); @@ -378,7 +379,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph, LOG(INFO) << "set enable_sequential_execution:" << enable_sequential_execution_; } else if (pass->Type() == "all_reduce_deps_pass") { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) platform::NCCLCommunicator *nctx = (use_device == p::kCUDA) ? nccl_ctxs : nullptr; pass->Erase(kNCCLCtxs); @@ -474,6 +475,7 @@ USE_PASS(add_reader_dependency_pass); #ifdef PADDLE_WITH_MKLDNN USE_PASS(mkldnn_placement_pass); #endif -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) +#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \ + !defined(_WIN32) && !defined(__APPLE__) USE_PASS(fusion_group_pass); #endif diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 13ee0a1b4f5..81d2d5e6dae 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -39,7 +39,7 @@ class NCCLCommunicator; } // namespace platform } // namespace paddle -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/nccl_helper.h" #elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL) #include "paddle/fluid/platform/bkcl_helper.h" @@ -185,7 +185,7 @@ struct BuildStrategy { const std::string &loss_var_name, const std::vector &local_scopes, const size_t &nranks, -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) DeviceType use_device, platform::NCCLCommunicator *nccl_ctxs) const; #elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL) diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index ae4779194f3..98c37ca3c40 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -47,7 +47,7 @@ struct TestGatherOpHandle { void InitCtxOnGpu(bool use_gpu) { if (use_gpu) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) int count = p::GetCUDADeviceCount(); if (count <= 1) { LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA " @@ -214,7 +214,7 @@ TEST(GatherTester, TestCPUGatherTestSelectedRows) { test_op.TestGatherSelectedRows(input_scope_idx); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) TEST(GatherTester, TestGPUGatherTestSelectedRows) { TestGatherOpHandle test_op; diff --git a/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.cc b/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.cc index c424efee057..a6232667193 100644 --- a/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.cc @@ -13,7 +13,7 @@ // limitations under the License. 
#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h" -#ifdef PADDLE_WITH_NCCL +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) DECLARE_bool(sync_nccl_allreduce); #endif @@ -21,7 +21,7 @@ namespace paddle { namespace framework { namespace details { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) GradMergeAllReduceOpHandle::GradMergeAllReduceOpHandle( ir::Node *node, const std::vector &local_scopes, const std::vector &places, @@ -68,7 +68,7 @@ std::string GradMergeAllReduceOpHandle::Name() const { return "grad_merge_all_reduce"; } -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) FusedGradMergeAllReduceOpHandle::FusedGradMergeAllReduceOpHandle( ir::Node *node, const std::vector &local_scopes, const std::vector &places, const size_t num_of_all_reduce, diff --git a/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h b/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h index 5c18f8fef11..c59f6134730 100644 --- a/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h @@ -33,7 +33,7 @@ namespace platform { class NCCLCommunicator; } // namespace platform } // namespace paddle -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/framework/details/nccl_op_handle.h" #include "paddle/fluid/platform/nccl_helper.h" #endif @@ -44,7 +44,7 @@ namespace details { class GradMergeAllReduceOpHandle : public AllReduceOpHandle { public: -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) GradMergeAllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, @@ -75,7 +75,7 @@ class GradMergeAllReduceOpHandle : public AllReduceOpHandle { class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle { public: -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) FusedGradMergeAllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index a35ac0bd732..6f7e6a90f76 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -126,10 +126,10 @@ struct VarHandle : public VarHandleBase { name_(std::move(name)), place_(std::move(place)) {} -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) bool HasEvent() { return has_event_; } - const cudaEvent_t& GetEvent() { + const gpuEvent_t& GetEvent() { PADDLE_ENFORCE_EQ( HasEvent(), true, platform::errors::PreconditionNotMet( @@ -137,7 +137,7 @@ struct VarHandle : public VarHandleBase { return event_; } - void SetGenerateEvent(const cudaEvent_t& event) { + void SetGenerateEvent(const gpuEvent_t& event) { has_event_ = true; event_ = event; } @@ -150,9 +150,9 @@ struct VarHandle : public VarHandleBase { size_t scope_idx_; std::string name_; platform::Place place_; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // Only when this event is triggered, var is generated. 
- cudaEvent_t event_; + gpuEvent_t event_; bool has_event_{false}; #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h b/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h index c5647f2cdcf..d14abd218c2 100644 --- a/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h +++ b/paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h @@ -737,7 +737,7 @@ x.second ); } int assign_async(const concurrent_unordered_map& other, - cudaStream_t stream = 0) { + gpuStream_t stream = 0) { m_collisions = other.m_collisions; if (other.m_hashtbl_size <= m_hashtbl_capacity) { m_hashtbl_size = other.m_hashtbl_size; @@ -754,7 +754,7 @@ x.second ); return 0; } - void clear_async(cudaStream_t stream = 0) { + void clear_async(gpuStream_t stream = 0) { constexpr int block_size = 128; init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size, 0, stream>>>(m_hashtbl_values, m_hashtbl_size, unused_key, @@ -771,7 +771,7 @@ x.second ); } } - int prefetch(const int dev_id, cudaStream_t stream = 0) { + int prefetch(const int dev_id, gpuStream_t stream = 0) { cudaPointerAttributes hashtbl_values_ptr_attributes; cudaError_t status = cudaPointerGetAttributes( &hashtbl_values_ptr_attributes, m_hashtbl_values); diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 089737bb7c4..0ca78c679ae 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -9,7 +9,7 @@ copy_if_different(${pass_file} ${pass_file_final}) add_subdirectory(fuse_optimizer_ops_pass) add_subdirectory(memory_optimize_pass) add_subdirectory(multi_devices_graph_pass) -if(NOT APPLE AND NOT WIN32 AND WITH_GPU) +if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM)) add_subdirectory(fusion_group) endif() @@ -93,7 +93,7 @@ pass_library(multihead_matmul_fuse_pass inference) pass_library(adaptive_pool2d_convert_global_pass inference) pass_library(unsqueeze2_eltwise_fuse_pass inference) pass_library(layer_norm_fuse_pass inference) -if(WITH_GPU) +if(WITH_GPU OR WITH_ROCM) pass_library(cudnn_placement_pass base DEPS placement_pass_base) pass_library(embedding_eltwise_layernorm_fuse_pass inference) endif() @@ -153,7 +153,7 @@ cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_ cc_test(test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass) cc_test(test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass) cc_test(test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor) -if(WITH_GPU) +if(WITH_GPU OR WITH_ROCM) cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass) cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass) endif() @@ -169,7 +169,7 @@ if (WITH_MKLDNN) cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util) cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util) set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split 
naive_executor device_context) -if (WITH_GPU) +if (WITH_GPU OR WITH_ROCM) set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv) endif() cc_test(test_conv_batch_norm_mkldnn_fuse_pass SRCS mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc DEPS ${TEST_CONV_BN_PASS_DEPS}) diff --git a/paddle/fluid/framework/ir/fuse_bn_act_pass.cc b/paddle/fluid/framework/ir/fuse_bn_act_pass.cc index d8b5e3712d9..ae662c64af3 100644 --- a/paddle/fluid/framework/ir/fuse_bn_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_bn_act_pass.cc @@ -27,14 +27,17 @@ class Node; #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" #endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif namespace paddle { namespace framework { namespace ir { void FuseBatchNormActPass::ApplyImpl(ir::Graph *graph) const { -#ifdef PADDLE_WITH_CUDA -#if CUDNN_VERSION_MIN(7, 4, 1) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 4, 1) // forward std::unordered_set act_types = {"relu"}; graph = FuseBatchNormAct(graph, act_types); diff --git a/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc index 12b92837468..ec014d331fa 100644 --- a/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_bn_add_act_pass.cc @@ -19,14 +19,17 @@ #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" #endif +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/platform/miopen_helper.h" +#endif namespace paddle { namespace framework { namespace ir { void FuseBatchNormAddActPass::ApplyImpl(ir::Graph *graph) const { -#ifdef PADDLE_WITH_CUDA -#if CUDNN_VERSION_MIN(7, 4, 1) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 4, 1) // forward std::unordered_set act_types = {"relu"}; graph = FuseBatchNormAddAct(graph, act_types); diff --git a/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt b/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt index 8586069cdf7..78b15398cc7 100644 --- a/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt +++ b/paddle/fluid/framework/ir/fusion_group/CMakeLists.txt @@ -1,7 +1,7 @@ cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph subgraph_detector) -if(WITH_GPU) +if(WITH_GPU OR WITH_ROCM) cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass) endif() diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc index 03d88c00707..0d490d4e669 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc @@ -28,7 +28,7 @@ class LoDTensor; } // namespace framework } // namespace paddle -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) namespace paddle { namespace framework { @@ -180,7 +180,11 @@ void TestMainImpl(std::string func_name, std::string code_str, paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0); paddle::platform::CUDADeviceCode device_code(place, func_name, code_str); +#ifdef PADDLE_WITH_HIP + device_code.Compile(true); +#else device_code.Compile(is_float16); +#endif std::vector gpu_tensors(cpu_tensors.size()); std::vector tmp_cpu_tensors(cpu_tensors.size()); diff --git 
a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc index a29b07fbe90..f410171f998 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc @@ -180,7 +180,7 @@ TEST(test_reference_count_pass, test_no_need_buffer_var_shrink) { {{"Out", {x7}}}, {}); std::vector use_cuda_list{false}; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) use_cuda_list.push_back(true); #endif for (auto use_cuda : use_cuda_list) { diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc index 80480d4123e..cfbb6303ef1 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc @@ -30,7 +30,7 @@ class AllReduceDepsPass : public ir::Pass { std::vector all_reduce_op_handles = GetSortedAllReduceOps(*graph); -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) auto use_hierarchical_allreduce = Get(details::kUseHierarchicalAllReduce); for (size_t i = 0; i < all_reduce_op_handles.size(); ++i) { diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc index 6d927d61707..484d09fd444 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc @@ -36,7 +36,7 @@ class FuseAllReduceOpPass : public ir::Pass { auto &places = Get>(details::kPlaces); auto &local_scopes = Get>(details::kLocalScopes); -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) auto *multi_nccl_ctxs = &Get(details::kNCCLCtxs); #elif defined(PADDLE_WITH_XPU_BKCL) @@ -90,7 +90,7 @@ class FuseAllReduceOpPass : public ir::Pass { for (auto &p_g : group_p_g) { group_all_reduce_ops.emplace_back(all_reduce_ops.at(p_g.second)); } -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) InsertFusedAllReduce(places, local_scopes, group_size, group_all_reduce_ops, multi_nccl_ctxs, &result); #elif defined(PADDLE_WITH_XPU_BKCL) @@ -156,7 +156,7 @@ class FuseAllReduceOpPass : public ir::Pass { const std::vector &local_scopes, const size_t num_of_all_reduce, const std::vector &all_reduce_ops, -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) const platform::NCCLCommunicator *multi_nccl_ctxs, #elif defined(PADDLE_WITH_XPU_BKCL) const platform::BKCLCommunicator *multi_bkcl_ctxs, @@ -217,7 +217,7 @@ class FuseAllReduceOpPass : public ir::Pass { result->RemoveNode(op_handle.Node()); } -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places, local_scopes, is_grad_merge, grad_merge_cond_name, multi_nccl_ctxs, result); @@ -240,7 +240,7 @@ class FuseAllReduceOpPass : public ir::Pass { const std::vector &places, const std::vector &local_scopes, bool is_grad_merge, const std::string &grad_merge_cond_name, -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) 
const platform::NCCLCommunicator *multi_nccl_ctxs, #elif defined(PADDLE_WITH_XPU_BKCL) const platform::BKCLCommunicator *multi_bkcl_ctxs, @@ -248,7 +248,7 @@ class FuseAllReduceOpPass : public ir::Pass { ir::Graph *result) const { details::FusedAllReduceOpHandle *op_handle = NULL; if (is_grad_merge) { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) op_handle = new details::FusedGradMergeAllReduceOpHandle( result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation), @@ -267,7 +267,7 @@ class FuseAllReduceOpPass : public ir::Pass { local_scopes, places, num_of_all_reduce, grad_merge_cond_name); #endif } else { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) op_handle = new details::FusedAllReduceOpHandle( result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation), @@ -293,7 +293,7 @@ class FuseAllReduceOpPass : public ir::Pass { op_handle->AddOutput(out); } -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) if (!multi_nccl_ctxs) { SetCommunicationContext(places, op_handle); } diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc index 0c03531aa88..c50e00f9995 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc @@ -157,7 +157,7 @@ void MultiDevSSAGraphBuilderBase::Init() const { places_ = Get>(details::kPlaces); local_scopes_ = Get>(details::kLocalScopes); strategy_ = Get(kStrategy); -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) multi_nccl_ctxs_ = &Get(details::kNCCLCtxs); nccl_ctxs_ = nullptr; if (multi_nccl_ctxs_) { @@ -323,7 +323,7 @@ std::vector MultiDevSSAGraphBuilderBase::SortOperations( bool MultiDevSSAGraphBuilderBase::UseGPU() const { bool use_gpu = false; -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) use_gpu = nccl_ctxs_ != nullptr; #endif return use_gpu; @@ -373,7 +373,7 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result, void MultiDevSSAGraphBuilderBase::SetCommunicationContext( details::OpHandleBase *op_handle, const platform::Place &p) const { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) if (nccl_ctxs_ == nullptr) { op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); @@ -392,7 +392,7 @@ void MultiDevSSAGraphBuilderBase::SetCommunicationContext( void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result, const std::string &p_name, size_t src_dev_id) const { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) auto *op_handle = new details::BroadcastOpHandle( result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_); @@ -429,7 +429,7 @@ void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result, void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp( ir::Graph *result, const std::vector> &bcast_varnames) const { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) auto *op_handle = new details::FusedBroadcastOpHandle( result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_); @@ -499,7 +499,8 @@ void 
MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result, const std::vector &scopes, const std::vector &places) -> details::OpHandleBase * { if (is_encoded) { -#if defined(PADDLE_WITH_DGC) && defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_DGC) && \ + (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)) result->Get(kGraphOps).emplace_back( new details::SparseAllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), @@ -515,7 +516,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result, grad_merge_cond_name = BOOST_GET_CONST( std::string, node->Op()->GetAttr(GRAD_MERGE_COND_NAME)); VLOG(10) << "og=" << og << " use grad_merge_allreduce"; -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) result->Get(kGraphOps).emplace_back( new details::GradMergeAllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), @@ -532,7 +533,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result, scopes, places, grad_merge_cond_name)); #endif } else { -#ifdef PADDLE_WITH_NCCL +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) result->Get(kGraphOps).emplace_back( new details::AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), @@ -648,7 +649,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps( details::VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp( ir::Graph *result, const std::string &og, size_t dst_dev_id) const { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) result->Get(kGraphOps).emplace_back(new details::ReduceOpHandle( result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h index 95c93479a50..27eda22828e 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h @@ -39,7 +39,7 @@ class Graph; namespace paddle { namespace platform { -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) class NCCLCommunicator; class NCCLContextMap; #elif defined(PADDLE_WITH_XPU_BKCL) @@ -117,7 +117,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { void CreateIsolatedVarNode(ir::Graph *result, ir::Node *var_node) const; -#if defined(PADDLE_WITH_NCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) mutable platform::NCCLContextMap *nccl_ctxs_{nullptr}; mutable platform::NCCLCommunicator *multi_nccl_ctxs_{nullptr}; #elif defined(PADDLE_WITH_XPU_BKCL) -- GitLab
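
Note on the recurring pattern: throughout the diff above, every "#if defined(PADDLE_WITH_NCCL)" guard is widened to also cover PADDLE_WITH_RCCL, CUDA-only runtime types such as cudaStream_t and cudaEvent_t are replaced by the backend-neutral gpuStream_t/gpuEvent_t aliases, and an inner PADDLE_WITH_HIP branch is added wherever the HIP runtime call differs from its CUDA counterpart (e.g. hipStreamSynchronize vs. cudaStreamSynchronize). The sketch below illustrates that pattern in isolation; it is not part of the change set, and the local alias definitions and the SyncAllReduceStream helper are stand-ins for Paddle's own platform headers and the PADDLE_ENFORCE_CUDA_SUCCESS macro.

    // Minimal, self-contained sketch of the CUDA/ROCm portability pattern
    // applied by this patch (assumes the PADDLE_WITH_* macros are set by the
    // build system; the helper below is illustrative, not Paddle code).
    #include <cstdlib>

    #ifdef PADDLE_WITH_HIP
    #include <hip/hip_runtime.h>
    using gpuStream_t = hipStream_t;   // stand-ins for Paddle's aliases,
    using gpuEvent_t = hipEvent_t;     // cf. the VarHandle change above
    #else
    #include <cuda_runtime.h>
    using gpuStream_t = cudaStream_t;
    using gpuEvent_t = cudaEvent_t;
    #endif

    // NCCL (CUDA) and RCCL (ROCm) builds share one code path, so the guard
    // names both macros instead of PADDLE_WITH_NCCL alone.
    #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    void SyncAllReduceStream(gpuStream_t stream) {
    #ifdef PADDLE_WITH_HIP
      // ROCm builds call the HIP runtime directly.
      if (hipStreamSynchronize(stream) != hipSuccess) std::abort();
      if (hipGetLastError() != hipSuccess) std::abort();
    #else
      // CUDA builds keep the original calls.
      if (cudaStreamSynchronize(stream) != cudaSuccess) std::abort();
      if (cudaGetLastError() != cudaSuccess) std::abort();
    #endif
    }
    #endif

On the build side the same duality appears as "if(WITH_GPU OR WITH_ROCM)" in the CMake files touched above, so the fusion_group and cuDNN/MIOpen-dependent targets are built for either backend.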