Unverified commit 8fe09faf, authored by Qi Li, committed by GitHub

[ROCM] update fluid framework for rocm (part1), test=develop (#31009)

Parent 33429630
@@ -17,7 +17,7 @@
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/platform/profiler.h"
-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 DECLARE_bool(sync_nccl_allreduce);
 #endif
@@ -25,7 +25,7 @@ namespace paddle {
 namespace framework {
 namespace details {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 const std::vector<Scope *> &local_scopes,
 const std::vector<platform::Place> &places,
@@ -182,7 +182,7 @@ void AllReduceOpHandle::AllReduceFunc(
 const std::vector<platform::Place> &places,
 const std::vector<std::string> &out_var_names) {
 if (is_gpu_place(places[0])) {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 PADDLE_ENFORCE_NOT_NULL(nccl_ctxs_,
 platform::errors::InvalidArgument(
 "The nccl context should not be NULL."));
@@ -198,7 +198,7 @@ void AllReduceOpHandle::AllReduceFunc(
 NCCLAllReduceFunc(all_reduce_calls);
 #else
 PADDLE_THROW(
-platform::errors::PreconditionNotMet("Not compiled with CUDA."));
+platform::errors::PreconditionNotMet("Not compiled with GPU."));
 #endif
 } else if (is_xpu_place(places[0])) {
 #if defined(PADDLE_WITH_XPU_BKCL)
@@ -265,7 +265,7 @@ void AllReduceOpHandle::BKCLAllReduceFunc(
 }
 #endif
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 void AllReduceOpHandle::NCCLAllReduceFunc(
 const std::vector<std::function<void()>> &all_reduce_calls) {
 this->RunAndRecordEvent([&] {
@@ -291,8 +291,13 @@ void AllReduceOpHandle::SyncNCCLAllReduce() {
 nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, use_hierarchical_allreduce_);
 auto &nccl_ctx = nccl_ctxs->at(dev_id);
 auto stream = nccl_ctx.stream();
+#ifdef PADDLE_WITH_HIP
+PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
+PADDLE_ENFORCE_CUDA_SUCCESS(hipGetLastError());
+#else
 PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
 PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError());
+#endif
 }
 }
 }
......
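Note on the SyncNCCLAllReduce hunk above: HIP builds pair hipStreamSynchronize/hipGetLastError with the CUDA calls they replace. A minimal sketch of the same synchronize-then-check pattern written without Paddle's PADDLE_ENFORCE_CUDA_SUCCESS wrapper (the helper name is made up; only the standard CUDA/HIP runtime APIs are assumed):

```cpp
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif
#include <cstdio>
#include <cstdlib>

// Block until all work queued on `stream` is done, then surface any pending
// asynchronous launch error from the matching runtime.
template <typename StreamT>
void SyncStreamAndCheck(StreamT stream) {
#ifdef PADDLE_WITH_HIP
  hipError_t err = hipStreamSynchronize(stream);
  if (err == hipSuccess) err = hipGetLastError();
  if (err != hipSuccess) {
    std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(err));
    std::abort();
  }
#else
  cudaError_t err = cudaStreamSynchronize(stream);
  if (err == cudaSuccess) err = cudaGetLastError();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
    std::abort();
  }
#endif
}
```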
@@ -31,7 +31,7 @@ namespace platform {
 class NCCLCommunicator;
 }  // namespace platform
 }  // namespace paddle
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/framework/details/nccl_op_handle.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #elif defined(PADDLE_WITH_XPU_BKCL)
@@ -43,7 +43,7 @@ namespace paddle {
 namespace framework {
 namespace details {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 class AllReduceOpHandle : public NCCLOpHandleBase {
 public:
 AllReduceOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
@@ -74,13 +74,14 @@ class AllReduceOpHandle : public OpHandleBase {
 std::vector<Scope *> local_scopes_;
-#if !(PADDLE_WITH_NCCL || PADDLE_WITH_XPU_BKCL)
+#if !defined(PADDLE_WITH_NCCL) && !defined(PADDLE_WITH_RCCL) && \
+    !defined(PADDLE_WITH_XPU_BKCL)
 // NCCLOpHandleBase and BKCLOpHandleBase already have these attributes.
 // Will polish it by class inheritance framework.
 std::vector<platform::Place> places_;
 #endif
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 void NCCLAllReduceFunc(
 const std::vector<std::function<void()>> &all_reduce_calls);
......
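The header hunk also rewrites the negative guard from `#if !(PADDLE_WITH_NCCL || PADDLE_WITH_XPU_BKCL)` to explicit `!defined(...)` tests. The difference matters: inside `#if`, an undefined identifier is replaced by 0, but a macro that is defined with an empty expansion breaks the bare form, while `defined()` only asks whether the name exists. A small illustration with placeholder macros (not Paddle's):

```cpp
#define FEATURE_A  // defined, but expands to nothing

// #if !(FEATURE_A || FEATURE_B)      // would not compile: expands to "!( || 0)"
#if !defined(FEATURE_A) && !defined(FEATURE_B)
// Reached only when neither macro is defined, regardless of what they expand to.
#endif
```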
@@ -158,7 +158,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
 "fuse_relu_depthwise_conv_pass");
 AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
 AppendPassWithCheck(strategy_.fuse_bn_add_act_ops_, "fuse_bn_add_act_pass");
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__)
+#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
+    !defined(_WIN32) && !defined(__APPLE__)
 AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
 #else
 LOG(WARNING) << "fusion_group is not enabled for Windows/MacOS now, and "
@@ -305,7 +306,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
 const std::string &loss_var_name,
 const std::vector<Scope *> &local_scopes,
 const size_t &nranks,
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 DeviceType use_device,
 platform::NCCLCommunicator *nccl_ctxs) const {
 #elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
@@ -331,7 +332,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
 pass->Erase(kNRanks);
 pass->Set<size_t>(kNRanks, new size_t(nranks));
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 platform::NCCLCommunicator *nctx =
 (use_device == p::kCUDA) ? nccl_ctxs : nullptr;
 pass->Erase(kNCCLCtxs);
@@ -351,7 +352,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
 pass->Erase(kLocalScopes);
 pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
 &local_scopes);
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 platform::NCCLCommunicator *nctx =
 (use_device == p::kCUDA) ? nccl_ctxs : nullptr;
 pass->Erase(kNCCLCtxs);
@@ -378,7 +379,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
 LOG(INFO) << "set enable_sequential_execution:"
 << enable_sequential_execution_;
 } else if (pass->Type() == "all_reduce_deps_pass") {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 platform::NCCLCommunicator *nctx =
 (use_device == p::kCUDA) ? nccl_ctxs : nullptr;
 pass->Erase(kNCCLCtxs);
@@ -474,6 +475,7 @@ USE_PASS(add_reader_dependency_pass);
 #ifdef PADDLE_WITH_MKLDNN
 USE_PASS(mkldnn_placement_pass);
 #endif
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__)
+#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
+    !defined(_WIN32) && !defined(__APPLE__)
 USE_PASS(fusion_group_pass);
 #endif
@@ -39,7 +39,7 @@ class NCCLCommunicator;
 }  // namespace platform
 }  // namespace paddle
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/nccl_helper.h"
 #elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/platform/bkcl_helper.h"
@@ -185,7 +185,7 @@ struct BuildStrategy {
 const std::string &loss_var_name,
 const std::vector<Scope *> &local_scopes,
 const size_t &nranks,
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 DeviceType use_device,
 platform::NCCLCommunicator *nccl_ctxs) const;
 #elif defined(PADDLE_WITH_XPU) && defined(PADDLE_WITH_XPU_BKCL)
......
@@ -47,7 +47,7 @@ struct TestGatherOpHandle {
 void InitCtxOnGpu(bool use_gpu) {
 if (use_gpu) {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 int count = p::GetCUDADeviceCount();
 if (count <= 1) {
 LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA "
@@ -214,7 +214,7 @@ TEST(GatherTester, TestCPUGatherTestSelectedRows) {
 test_op.TestGatherSelectedRows(input_scope_idx);
 }
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(GatherTester, TestGPUGatherTestSelectedRows) {
 TestGatherOpHandle test_op;
......
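The test guards switch from CUDA-only to CUDA-or-HIP while still calling p::GetCUDADeviceCount(), so that helper is presumably implemented for both runtimes. A sketch of what such a dual-runtime device-count helper can look like (an assumption for illustration, not Paddle's actual implementation):

```cpp
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

// Returns the number of visible GPUs, or 0 if the runtime reports an error.
inline int GetGpuDeviceCount() {
  int count = 0;
#ifdef PADDLE_WITH_HIP
  if (hipGetDeviceCount(&count) != hipSuccess) count = 0;
#else
  if (cudaGetDeviceCount(&count) != cudaSuccess) count = 0;
#endif
  return count;
}
```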
@@ -13,7 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h"
-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 DECLARE_bool(sync_nccl_allreduce);
 #endif
@@ -21,7 +21,7 @@ namespace paddle {
 namespace framework {
 namespace details {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 GradMergeAllReduceOpHandle::GradMergeAllReduceOpHandle(
 ir::Node *node, const std::vector<Scope *> &local_scopes,
 const std::vector<platform::Place> &places,
@@ -68,7 +68,7 @@ std::string GradMergeAllReduceOpHandle::Name() const {
 return "grad_merge_all_reduce";
 }
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 FusedGradMergeAllReduceOpHandle::FusedGradMergeAllReduceOpHandle(
 ir::Node *node, const std::vector<Scope *> &local_scopes,
 const std::vector<platform::Place> &places, const size_t num_of_all_reduce,
......
@@ -33,7 +33,7 @@ namespace platform {
 class NCCLCommunicator;
 }  // namespace platform
 }  // namespace paddle
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/framework/details/nccl_op_handle.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -44,7 +44,7 @@ namespace details {
 class GradMergeAllReduceOpHandle : public AllReduceOpHandle {
 public:
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 GradMergeAllReduceOpHandle(ir::Node *node,
 const std::vector<Scope *> &local_scopes,
 const std::vector<platform::Place> &places,
@@ -75,7 +75,7 @@ class GradMergeAllReduceOpHandle : public AllReduceOpHandle {
 class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle {
 public:
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 FusedGradMergeAllReduceOpHandle(ir::Node *node,
 const std::vector<Scope *> &local_scopes,
 const std::vector<platform::Place> &places,
......
@@ -126,10 +126,10 @@ struct VarHandle : public VarHandleBase {
 name_(std::move(name)),
 place_(std::move(place)) {}
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 bool HasEvent() { return has_event_; }
-const cudaEvent_t& GetEvent() {
+const gpuEvent_t& GetEvent() {
 PADDLE_ENFORCE_EQ(
 HasEvent(), true,
 platform::errors::PreconditionNotMet(
@@ -137,7 +137,7 @@ struct VarHandle : public VarHandleBase {
 return event_;
 }
-void SetGenerateEvent(const cudaEvent_t& event) {
+void SetGenerateEvent(const gpuEvent_t& event) {
 has_event_ = true;
 event_ = event;
 }
@@ -150,9 +150,9 @@ struct VarHandle : public VarHandleBase {
 size_t scope_idx_;
 std::string name_;
 platform::Place place_;
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 // Only when this event is triggered, var is generated.
-cudaEvent_t event_;
+gpuEvent_t event_;
 bool has_event_{false};
 #endif
......
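The cudaEvent_t members become gpuEvent_t, so the same handle type compiles under either runtime; the alias itself presumably lives in one of Paddle's platform headers. A minimal sketch of what such an alias amounts to (an assumption, not the actual Paddle header):

```cpp
// Runtime-neutral event/stream aliases: pick the CUDA or HIP type per build.
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuEvent_t = hipEvent_t;
using gpuStream_t = hipStream_t;
#else
#include <cuda_runtime.h>
using gpuEvent_t = cudaEvent_t;
using gpuStream_t = cudaStream_t;
#endif
```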
@@ -737,7 +737,7 @@ x.second );
 }
 int assign_async(const concurrent_unordered_map& other,
-cudaStream_t stream = 0) {
+gpuStream_t stream = 0) {
 m_collisions = other.m_collisions;
 if (other.m_hashtbl_size <= m_hashtbl_capacity) {
 m_hashtbl_size = other.m_hashtbl_size;
@@ -754,7 +754,7 @@ x.second );
 return 0;
 }
-void clear_async(cudaStream_t stream = 0) {
+void clear_async(gpuStream_t stream = 0) {
 constexpr int block_size = 128;
 init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size, 0,
 stream>>>(m_hashtbl_values, m_hashtbl_size, unused_key,
@@ -771,7 +771,7 @@ x.second );
 }
 }
-int prefetch(const int dev_id, cudaStream_t stream = 0) {
+int prefetch(const int dev_id, gpuStream_t stream = 0) {
 cudaPointerAttributes hashtbl_values_ptr_attributes;
 cudaError_t status = cudaPointerGetAttributes(
 &hashtbl_values_ptr_attributes, m_hashtbl_values);
......
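In the hash-map header only the default stream parameters switch to gpuStream_t; the init_hashtbl<<<...>>> launch and the cudaPointerGetAttributes call in prefetch are left untouched. The launch can stay as-is because hipcc accepts the same triple-chevron syntax as nvcc, so abstracting the stream type is enough. A hedged sketch of that idea, reusing the gpuStream_t alias from the previous snippet (kernel and names are illustrative):

```cpp
// A __global__ kernel plus <<<grid, block, shared_mem, stream>>> launch
// compiles under both nvcc and hipcc; only the stream type is abstracted.
__global__ void fill_kernel(int* data, int n, int value) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = value;
}

void fill_async(int* data, int n, int value, gpuStream_t stream) {
  constexpr int block_size = 128;
  const int grid_size = (n + block_size - 1) / block_size;
  fill_kernel<<<grid_size, block_size, 0, stream>>>(data, n, value);
}
```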
@@ -9,7 +9,7 @@ copy_if_different(${pass_file} ${pass_file_final})
 add_subdirectory(fuse_optimizer_ops_pass)
 add_subdirectory(memory_optimize_pass)
 add_subdirectory(multi_devices_graph_pass)
-if(NOT APPLE AND NOT WIN32 AND WITH_GPU)
+if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM))
 add_subdirectory(fusion_group)
 endif()
@@ -93,7 +93,7 @@ pass_library(multihead_matmul_fuse_pass inference)
 pass_library(adaptive_pool2d_convert_global_pass inference)
 pass_library(unsqueeze2_eltwise_fuse_pass inference)
 pass_library(layer_norm_fuse_pass inference)
-if(WITH_GPU)
+if(WITH_GPU OR WITH_ROCM)
 pass_library(cudnn_placement_pass base DEPS placement_pass_base)
 pass_library(embedding_eltwise_layernorm_fuse_pass inference)
 endif()
@@ -153,7 +153,7 @@ cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_
 cc_test(test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass)
 cc_test(test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass)
 cc_test(test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor)
-if(WITH_GPU)
+if(WITH_GPU OR WITH_ROCM)
 cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass)
 cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
 endif()
@@ -169,7 +169,7 @@ if (WITH_MKLDNN)
 cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util)
 cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util)
 set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context)
-if (WITH_GPU)
+if (WITH_GPU OR WITH_ROCM)
 set(TEST_CONV_BN_PASS_DEPS ${TEST_CONV_BN_PASS_DEPS} depthwise_conv)
 endif()
 cc_test(test_conv_batch_norm_mkldnn_fuse_pass SRCS mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc DEPS ${TEST_CONV_BN_PASS_DEPS})
......
@@ -27,14 +27,17 @@ class Node;
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif
 namespace paddle {
 namespace framework {
 namespace ir {
 void FuseBatchNormActPass::ApplyImpl(ir::Graph *graph) const {
-#ifdef PADDLE_WITH_CUDA
-#if CUDNN_VERSION_MIN(7, 4, 1)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 4, 1)
 // forward
 std::unordered_set<std::string> act_types = {"relu"};
 graph = FuseBatchNormAct(graph, act_types);
......
@@ -19,14 +19,17 @@
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#endif
 namespace paddle {
 namespace framework {
 namespace ir {
 void FuseBatchNormAddActPass::ApplyImpl(ir::Graph *graph) const {
-#ifdef PADDLE_WITH_CUDA
-#if CUDNN_VERSION_MIN(7, 4, 1)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#if defined(PADDLE_WITH_HIP) || CUDNN_VERSION_MIN(7, 4, 1)
 // forward
 std::unordered_set<std::string> act_types = {"relu"};
 graph = FuseBatchNormAddAct(graph, act_types);
......
 cc_library(code_generator
 SRCS operation.cc code_generator.cc code_generator_helper.cc
 DEPS graph subgraph_detector)
-if(WITH_GPU)
+if(WITH_GPU OR WITH_ROCM)
 cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
 endif()
......
@@ -28,7 +28,7 @@ class LoDTensor;
 }  // namespace framework
 }  // namespace paddle
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 namespace paddle {
 namespace framework {
@@ -180,7 +180,11 @@ void TestMainImpl(std::string func_name, std::string code_str,
 paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
 paddle::platform::CUDADeviceCode device_code(place, func_name, code_str);
+#ifdef PADDLE_WITH_HIP
+device_code.Compile(true);
+#else
 device_code.Compile(is_float16);
+#endif
 std::vector<paddle::framework::LoDTensor> gpu_tensors(cpu_tensors.size());
 std::vector<paddle::framework::LoDTensor> tmp_cpu_tensors(cpu_tensors.size());
......
@@ -180,7 +180,7 @@ TEST(test_reference_count_pass, test_no_need_buffer_var_shrink) {
 {{"Out", {x7}}}, {});
 std::vector<bool> use_cuda_list{false};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 use_cuda_list.push_back(true);
 #endif
 for (auto use_cuda : use_cuda_list) {
......
@@ -30,7 +30,7 @@ class AllReduceDepsPass : public ir::Pass {
 std::vector<details::OpHandleBase*> all_reduce_op_handles =
 GetSortedAllReduceOps(*graph);
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 auto use_hierarchical_allreduce =
 Get<bool>(details::kUseHierarchicalAllReduce);
 for (size_t i = 0; i < all_reduce_op_handles.size(); ++i) {
......
@@ -36,7 +36,7 @@ class FuseAllReduceOpPass : public ir::Pass {
 auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
 auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 auto *multi_nccl_ctxs =
 &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
 #elif defined(PADDLE_WITH_XPU_BKCL)
@@ -90,7 +90,7 @@ class FuseAllReduceOpPass : public ir::Pass {
 for (auto &p_g : group_p_g) {
 group_all_reduce_ops.emplace_back(all_reduce_ops.at(p_g.second));
 }
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 InsertFusedAllReduce(places, local_scopes, group_size,
 group_all_reduce_ops, multi_nccl_ctxs, &result);
 #elif defined(PADDLE_WITH_XPU_BKCL)
@@ -156,7 +156,7 @@ class FuseAllReduceOpPass : public ir::Pass {
 const std::vector<Scope *> &local_scopes,
 const size_t num_of_all_reduce,
 const std::vector<ir::Node *> &all_reduce_ops,
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 const platform::NCCLCommunicator *multi_nccl_ctxs,
 #elif defined(PADDLE_WITH_XPU_BKCL)
 const platform::BKCLCommunicator *multi_bkcl_ctxs,
@@ -217,7 +217,7 @@ class FuseAllReduceOpPass : public ir::Pass {
 result->RemoveNode(op_handle.Node());
 }
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places,
 local_scopes, is_grad_merge, grad_merge_cond_name,
 multi_nccl_ctxs, result);
@@ -240,7 +240,7 @@ class FuseAllReduceOpPass : public ir::Pass {
 const std::vector<platform::Place> &places,
 const std::vector<Scope *> &local_scopes, bool is_grad_merge,
 const std::string &grad_merge_cond_name,
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 const platform::NCCLCommunicator *multi_nccl_ctxs,
 #elif defined(PADDLE_WITH_XPU_BKCL)
 const platform::BKCLCommunicator *multi_bkcl_ctxs,
@@ -248,7 +248,7 @@ class FuseAllReduceOpPass : public ir::Pass {
 ir::Graph *result) const {
 details::FusedAllReduceOpHandle *op_handle = NULL;
 if (is_grad_merge) {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 op_handle = new details::FusedGradMergeAllReduceOpHandle(
 result->CreateEmptyNode("fused_all_reduce",
 ir::Node::Type::kOperation),
@@ -267,7 +267,7 @@ class FuseAllReduceOpPass : public ir::Pass {
 local_scopes, places, num_of_all_reduce, grad_merge_cond_name);
 #endif
 } else {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 op_handle = new details::FusedAllReduceOpHandle(
 result->CreateEmptyNode("fused_all_reduce",
 ir::Node::Type::kOperation),
@@ -293,7 +293,7 @@ class FuseAllReduceOpPass : public ir::Pass {
 op_handle->AddOutput(out);
 }
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 if (!multi_nccl_ctxs) {
 SetCommunicationContext(places, op_handle);
 }
......
@@ -157,7 +157,7 @@ void MultiDevSSAGraphBuilderBase::Init() const {
 places_ = Get<const std::vector<platform::Place>>(details::kPlaces);
 local_scopes_ = Get<const std::vector<Scope *>>(details::kLocalScopes);
 strategy_ = Get<const details::BuildStrategy>(kStrategy);
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 multi_nccl_ctxs_ = &Get<platform::NCCLCommunicator>(details::kNCCLCtxs);
 nccl_ctxs_ = nullptr;
 if (multi_nccl_ctxs_) {
@@ -323,7 +323,7 @@ std::vector<ir::Node *> MultiDevSSAGraphBuilderBase::SortOperations(
 bool MultiDevSSAGraphBuilderBase::UseGPU() const {
 bool use_gpu = false;
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 use_gpu = nccl_ctxs_ != nullptr;
 #endif
 return use_gpu;
@@ -373,7 +373,7 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
 void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
 details::OpHandleBase *op_handle, const platform::Place &p) const {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 if (nccl_ctxs_ == nullptr) {
 op_handle->SetDeviceContext(p,
 platform::DeviceContextPool::Instance().Get(p));
@@ -392,7 +392,7 @@ void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
 void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
 const std::string &p_name,
 size_t src_dev_id) const {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 auto *op_handle = new details::BroadcastOpHandle(
 result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
 local_scopes_, places_, nccl_ctxs_);
@@ -429,7 +429,7 @@ void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
 void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
 ir::Graph *result,
 const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 auto *op_handle = new details::FusedBroadcastOpHandle(
 result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
 local_scopes_, places_, nccl_ctxs_);
@@ -499,7 +499,8 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
 const std::vector<Scope *> &scopes,
 const std::vector<platform::Place> &places) -> details::OpHandleBase * {
 if (is_encoded) {
-#if defined(PADDLE_WITH_DGC) && defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_DGC) && \
+    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
 result->Get<GraphOps>(kGraphOps).emplace_back(
 new details::SparseAllReduceOpHandle(
 result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
@@ -515,7 +516,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
 grad_merge_cond_name = BOOST_GET_CONST(
 std::string, node->Op()->GetAttr(GRAD_MERGE_COND_NAME));
 VLOG(10) << "og=" << og << " use grad_merge_allreduce";
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 result->Get<GraphOps>(kGraphOps).emplace_back(
 new details::GradMergeAllReduceOpHandle(
 result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
@@ -532,7 +533,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
 scopes, places, grad_merge_cond_name));
 #endif
 } else {
-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 result->Get<GraphOps>(kGraphOps).emplace_back(
 new details::AllReduceOpHandle(
 result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
@@ -648,7 +649,7 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
 details::VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
 ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 result->Get<GraphOps>(kGraphOps).emplace_back(new details::ReduceOpHandle(
 result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
 local_scopes_, places_, nccl_ctxs_));
......
@@ -39,7 +39,7 @@ class Graph;
 namespace paddle {
 namespace platform {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 class NCCLCommunicator;
 class NCCLContextMap;
 #elif defined(PADDLE_WITH_XPU_BKCL)
@@ -117,7 +117,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
 void CreateIsolatedVarNode(ir::Graph *result, ir::Node *var_node) const;
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 mutable platform::NCCLContextMap *nccl_ctxs_{nullptr};
 mutable platform::NCCLCommunicator *multi_nccl_ctxs_{nullptr};
 #elif defined(PADDLE_WITH_XPU_BKCL)
......
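Taken together, the pattern throughout this commit is that every PADDLE_WITH_NCCL guard gains a PADDLE_WITH_RCCL alternative while the guarded code (NCCLCommunicator, NCCLContextMap, the nccl* collective calls) is left unchanged. That works because RCCL exposes an NCCL-compatible API, so the same call sites compile against it on ROCm builds. A hedged sketch of such a shared call site (the RCCL header path and the gpuStream_t alias from the earlier snippet are assumptions):

```cpp
#if defined(PADDLE_WITH_RCCL)
#include <rccl/rccl.h>   // header location may differ across ROCm versions
#else
#include <nccl.h>
#endif

// Sum `count` floats across all ranks of `comm`, asynchronously on `stream`.
// The nccl* names are provided by both NCCL and RCCL.
inline ncclResult_t AllReduceSumFloat(const float* send, float* recv,
                                      size_t count, ncclComm_t comm,
                                      gpuStream_t stream) {
  return ncclAllReduce(send, recv, count, ncclFloat, ncclSum, comm, stream);
}
```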