diff --git a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt
index 4d0cfb629763f72cc5059e37149fd1e676811d42..61f3c026f1facc6afb2b9b45316b1205cf676904 100644
--- a/paddle/fluid/framework/fleet/CMakeLists.txt
+++ b/paddle/fluid/framework/fleet/CMakeLists.txt
@@ -4,6 +4,10 @@ if(WITH_PSLIB)
     nv_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
       DEPS heter_ps)
     add_subdirectory(heter_ps)
+  elseif(WITH_RCCL)
+    hip_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cu ps_gpu_wrapper.cc
+      DEPS heter_ps)
+    add_subdirectory(heter_ps)
   else()
     cc_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cc)
   endif(WITH_NCCL)
@@ -12,11 +16,16 @@ else()
   cc_library(ps_gpu_wrapper SRCS ps_gpu_wrapper.cc)
 endif(WITH_PSLIB)

-if(WITH_NCCL)
+if(WITH_NCCL OR WITH_RCCL)
   cc_library(nccl_wrapper SRCS nccl_wrapper.cc DEPS framework_proto variable_helper scope)
 endif()
 if(WITH_BOX_PS)
-  nv_library(box_wrapper SRCS box_wrapper.cc box_wrapper.cu DEPS framework_proto lod_tensor box_ps)
+  if(WITH_GPU)
+    nv_library(box_wrapper SRCS box_wrapper.cc box_wrapper.cu DEPS framework_proto lod_tensor box_ps)
+  endif()
+  if(WITH_ROCM)
+    hip_library(box_wrapper SRCS box_wrapper.cc box_wrapper.cu DEPS framework_proto lod_tensor box_ps)
+  endif()
 else()
   cc_library(box_wrapper SRCS box_wrapper.cc DEPS framework_proto lod_tensor)
 endif(WITH_BOX_PS)
diff --git a/paddle/fluid/framework/fleet/box_wrapper.cc b/paddle/fluid/framework/fleet/box_wrapper.cc
index 2d3e6943822f823d40a21e4e60ec87abf7bfbaef..37fbf47f854ade5f854d206c29eca6c11a89ee85 100644
--- a/paddle/fluid/framework/fleet/box_wrapper.cc
+++ b/paddle/fluid/framework/fleet/box_wrapper.cc
@@ -25,7 +25,7 @@ namespace paddle {
 namespace framework {

 std::shared_ptr<BoxWrapper> BoxWrapper::s_instance_ = nullptr;
-cudaStream_t BoxWrapper::stream_list_[8];
+gpuStream_t BoxWrapper::stream_list_[8];
 std::shared_ptr<boxps::BoxPSBase> BoxWrapper::boxps_ptr_ = nullptr;
 AfsManager* BoxWrapper::afs_manager = nullptr;
 int BoxWrapper::embedx_dim_ = 8;
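The cudaStream_t member becomes gpuStream_t, an alias that Paddle's platform headers resolve to the CUDA or HIP stream type depending on the build, so one declaration serves both runtimes. A minimal sketch of the idea (illustrative only; Paddle's actual header layout differs):

    // Illustrative sketch of the portability aliases this patch relies on;
    // Paddle defines the real ones in its platform headers.
    #ifdef PADDLE_WITH_HIP
    #include <hip/hip_runtime.h>
    using gpuStream_t = hipStream_t;
    using gpuEvent_t = hipEvent_t;
    #else  // CUDA build
    #include <cuda_runtime.h>
    using gpuStream_t = cudaStream_t;
    using gpuEvent_t = cudaEvent_t;
    #endif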
diff --git a/paddle/fluid/framework/fleet/box_wrapper.cu b/paddle/fluid/framework/fleet/box_wrapper.cu
index 31809532a69760c7398e19572694c03b8a1ae67e..c9b5abf7a9befc01a2defe64e346e8f192cad70a 100644
--- a/paddle/fluid/framework/fleet/box_wrapper.cu
+++ b/paddle/fluid/framework/fleet/box_wrapper.cu
@@ -142,8 +142,13 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
           ->stream();
   auto buf_value = memory::AllocShared(place, values.size() * sizeof(float*));
   float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
+#ifdef PADDLE_WITH_HIP
+  hipMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
+            hipMemcpyHostToDevice);
+#else
   cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
              cudaMemcpyHostToDevice);
+#endif
 #define EMBEDX_CASE(i, ...)                                               \
   case i: {                                                               \
     constexpr size_t EmbedxDim = i;                                       \
@@ -155,6 +160,19 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
     }                                                                     \
   } break

+#ifdef PADDLE_WITH_HIP
+#define EXPAND_EMBED_PULL_CASE(i, ...)                                       \
+  case i: {                                                                  \
+    constexpr size_t ExpandDim = i;                                          \
+    hipLaunchKernelGGL(                                                      \
+        PullCopy<EmbedxDim, ExpandDim>, dim3((total_length + 512 - 1) / 512),\
+        dim3(512), 0, stream, gpu_values,                                    \
+        reinterpret_cast<boxps::FeatureValueGpu<EmbedxDim, ExpandDim>*>(     \
+            total_values_gpu),                                               \
+        gpu_len, hidden_size, expand_embed_dim, slot_num, total_length,      \
+        gpu_keys);                                                           \
+  } break
+#else
 #define EXPAND_EMBED_PULL_CASE(i, ...)                                       \
   case i: {                                                                  \
     constexpr size_t ExpandDim = i;                                          \
@@ -166,6 +184,7 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
         gpu_len, hidden_size, expand_embed_dim, slot_num, total_length,      \
         gpu_keys);                                                           \
   } break
+#endif

   switch (hidden_size - 3) {
     EMBEDX_CASE(8, EXPAND_EMBED_PULL_CASE(0); EXPAND_EMBED_PULL_CASE(8);
@@ -187,9 +206,16 @@ void BoxWrapper::CopyKeys(const paddle::platform::Place& place,
                     platform::DeviceContextPool::Instance().Get(
                         BOOST_GET_CONST(platform::CUDAPlace, place)))
                     ->stream();
+#ifdef PADDLE_WITH_HIP
+  hipLaunchKernelGGL(CopyKeysKernel, dim3((total_len + 512 - 1) / 512),
+                     dim3(512), 0, stream, origin_keys, total_keys, gpu_len,
+                     slot_num, total_len);
+  hipStreamSynchronize(stream);
+#else
   CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>(
       origin_keys, total_keys, gpu_len, slot_num, total_len);
   cudaStreamSynchronize(stream);
+#endif
 }

 void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
@@ -217,12 +243,21 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
   int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
   int* d_slot_vector = reinterpret_cast<int*>(buf_slot_vector->ptr());

+#ifdef PADDLE_WITH_HIP
+  hipMemcpy(gpu_values, grad_values.data(),
+            grad_values.size() * sizeof(float*), hipMemcpyHostToDevice);
+  hipMemcpy(gpu_len, slot_lengths_lod.data(),
+            slot_lengths.size() * sizeof(int64_t), hipMemcpyHostToDevice);
+  hipMemcpy(d_slot_vector, slot_vector_.data(),
+            slot_lengths_lod.size() * sizeof(int), hipMemcpyHostToDevice);
+#else
   cudaMemcpy(gpu_values, grad_values.data(),
              grad_values.size() * sizeof(float*), cudaMemcpyHostToDevice);
   cudaMemcpy(gpu_len, slot_lengths_lod.data(),
              slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
   cudaMemcpy(d_slot_vector, slot_vector_.data(),
              slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
+#endif

 #define EMBEDX_CASE(i, ...)                                               \
   case i: {                                                               \
     constexpr size_t EmbedxDim = i;                                       \
@@ -235,6 +270,18 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
     }                                                                     \
   } break

+#ifdef PADDLE_WITH_HIP
+#define EXPAND_EMBED_PUSH_CASE(i, ...)                                       \
+  case i: {                                                                  \
+    constexpr size_t ExpandDim = i;                                          \
+    hipLaunchKernelGGL(PushCopy<EmbedxDim, ExpandDim>,                       \
+        dim3((total_length + 512 - 1) / 512), dim3(512), 0, stream,          \
+        reinterpret_cast<boxps::FeaturePushValueGpu<EmbedxDim, ExpandDim>*>( \
+            total_grad_values_gpu),                                          \
+        gpu_values, gpu_len, hidden_size, expand_embed_dim,                  \
+        slot_lengths.size(), total_length, batch_size, d_slot_vector);       \
+  } break
+#else
 #define EXPAND_EMBED_PUSH_CASE(i, ...)                                       \
   case i: {                                                                  \
     constexpr size_t ExpandDim = i;                                          \
@@ -245,6 +292,7 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
         gpu_values, gpu_len, hidden_size, expand_embed_dim,                  \
         slot_lengths.size(), total_length, batch_size, d_slot_vector);       \
   } break
+#endif

   switch (hidden_size - 3) {
     EMBEDX_CASE(8, EXPAND_EMBED_PUSH_CASE(0); EXPAND_EMBED_PUSH_CASE(8);
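HIP has no triple-chevron launch syntax, so each kernel launch above is rewritten with hipLaunchKernelGGL, which takes the kernel, the grid and block dimensions as dim3, the dynamic shared-memory size, and the stream, followed by the kernel's own arguments. A minimal sketch of the correspondence (kernel and launch sizes are illustrative, using the gpuStream_t alias described earlier):

    // Illustrative correspondence between the two launch syntaxes.
    __global__ void scale(float* data, float alpha, size_t n) {
      size_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) data[i] *= alpha;
    }

    void launch_scale(float* d_data, float alpha, size_t n, gpuStream_t stream) {
      const int threads = 512;
      const int blocks = (n + threads - 1) / threads;
    #ifdef PADDLE_WITH_HIP
      // HIP: launch configuration is passed as explicit arguments.
      hipLaunchKernelGGL(scale, dim3(blocks), dim3(threads), 0, stream,
                         d_data, alpha, n);
    #else
      // CUDA: launch configuration goes inside <<<...>>>.
      scale<<<blocks, threads, 0, stream>>>(d_data, alpha, n);
    #endif
    }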
diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h
index 399ee744ea9ab4abe17199398719d04957790a76..645d725871a061165a1138d96bcff8261bca3056 100644
--- a/paddle/fluid/framework/fleet/box_wrapper.h
+++ b/paddle/fluid/framework/fleet/box_wrapper.h
@@ -396,7 +396,7 @@ class BoxWrapper {
                      const std::string& model_path) {
     if (nullptr != s_instance_) {
       VLOG(3) << "Begin InitializeGPU";
-      std::vector<cudaStream_t*> stream_list;
+      std::vector<gpuStream_t*> stream_list;
       for (int i = 0; i < platform::GetCUDADeviceCount(); ++i) {
         VLOG(3) << "before get context i[" << i << "]";
         platform::CUDADeviceContext* context =
@@ -542,8 +542,12 @@ class BoxWrapper {
     auto* gpu_data = gpu_tensor.data<T>();
     auto len = gpu_tensor.numel();
     data->resize(len);
+#ifdef PADDLE_WITH_HIP
+    hipMemcpy(data->data(), gpu_data, sizeof(T) * len, hipMemcpyDeviceToHost);
+#else
     cudaMemcpy(data->data(), gpu_data, sizeof(T) * len,
                cudaMemcpyDeviceToHost);
+#endif
   }
   static inline std::pair<int, int> parse_cmatch_rank(uint64_t x) {
     // first 32 bit store cmatch and second 32 bit store rank
@@ -819,7 +823,7 @@ class BoxWrapper {
   }

  private:
-  static cudaStream_t stream_list_[8];
+  static gpuStream_t stream_list_[8];
   static std::shared_ptr<boxps::BoxPSBase> boxps_ptr_;
   boxps::PSAgentBase* p_agent_ = nullptr;
   // TODO(hutuxian): magic number, will add a config to specify
diff --git a/paddle/fluid/framework/fleet/box_wrapper_impl.h b/paddle/fluid/framework/fleet/box_wrapper_impl.h
index b4e414dc83ef1000f2e1e09525699b5bb47d2441..8832f0a20e376cb05985c874c863aa89ea3df14d 100644
--- a/paddle/fluid/framework/fleet/box_wrapper_impl.h
+++ b/paddle/fluid/framework/fleet/box_wrapper_impl.h
@@ -43,7 +43,7 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place,
     PADDLE_THROW(platform::errors::Unimplemented(
         "Warning:: CPUPlace is not supported in PaddleBox now."));
   } else if (platform::is_gpu_place(place)) {
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
     VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
     int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId();
     LoDTensor& total_keys_tensor = keys_tensor[device_id];
@@ -60,11 +60,17 @@ void BoxWrapper::PullSparseCase(const paddle::platform::Place& place,
         memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
     uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
     int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
+#ifdef PADDLE_WITH_HIP
+    hipMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
+              hipMemcpyHostToDevice);
+    hipMemcpy(gpu_len, slot_lengths_lod.data(),
+              slot_lengths.size() * sizeof(int64_t), hipMemcpyHostToDevice);
+#else
     cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
                cudaMemcpyHostToDevice);
     cudaMemcpy(gpu_len, slot_lengths_lod.data(),
                slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
-
+#endif
     this->CopyKeys(place, gpu_keys, total_keys, gpu_len,
                    static_cast<int>(slot_lengths.size()),
                    static_cast<int>(total_length));
@@ -124,7 +130,7 @@ void BoxWrapper::PushSparseGradCase(
     PADDLE_THROW(platform::errors::Unimplemented(
         "Warning:: CPUPlace is not supported in PaddleBox now."));
   } else if (platform::is_gpu_place(place)) {
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
     int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId();
     LoDTensor& cached_total_keys_tensor = keys_tensor[device_id];
     uint64_t* total_keys =
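These hunks repeat one pattern: the same host-to-device copy issued as hipMemcpy or cudaMemcpy depending on the build. A hedged sketch of how that pattern could be folded into a single call site (memcpy_h2d is a hypothetical helper for illustration, not a Paddle API):

    #include <cstddef>

    // Hypothetical helper, not part of Paddle: one call site for the
    // host-to-device copies that this patch guards with #ifdef blocks.
    inline void memcpy_h2d(void* dst, const void* src, size_t bytes) {
    #ifdef PADDLE_WITH_HIP
      hipMemcpy(dst, src, bytes, hipMemcpyHostToDevice);
    #else
      cudaMemcpy(dst, src, bytes, cudaMemcpyHostToDevice);
    #endif
    }

    // Usage mirroring the PullSparseCase hunk above:
    //   memcpy_h2d(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*));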
diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc
index 425c8a9f2a72a9f3d103392c1281c284c12d2073..7ad20aa6e18c802eead4f39a336f64d0011545ea 100644
--- a/paddle/fluid/framework/fleet/fleet_wrapper.cc
+++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc
@@ -698,13 +698,14 @@ void FleetWrapper::PushDenseVarsSync(
     Scope* scope, const uint64_t table_id,
     const std::vector<std::string>& var_names) {}

-#if (defined PADDLE_WITH_CUDA) && (defined PADDLE_WITH_PSLIB)
+#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
+    (defined PADDLE_WITH_PSLIB)
 void FleetWrapper::PushDenseVarsAsync(
     const Scope& scope, const uint64_t table_id,
     const std::vector<std::string>& var_names,
     std::vector<::std::future<int32_t>>* push_sparse_status,
     float scale_datanorm, int batch_size, const paddle::platform::Place& place,
-    cudaStream_t stream, cudaEvent_t event) {
+    gpuStream_t stream, gpuEvent_t event) {
   std::vector<paddle::ps::Region> regions;
   for (auto& t : var_names) {
     Variable* var = scope.FindVar(t);
@@ -719,8 +720,13 @@ void FleetWrapper::PushDenseVarsAsync(
       memory::Copy(platform::CUDAPinnedPlace(), pin_g,
                    BOOST_GET_CONST(platform::CUDAPlace, place), g_data,
                    sizeof(float) * count, stream);
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, stream));
+      hipEventSynchronize(event);
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, stream));
       cudaEventSynchronize(event);
+#endif

       float* g = pin_g;
       if (scale_datanorm >= 0) {
diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h
index aa0da8286269fea1818eef0de770256a016d1e56..e584fb5e2b9ca77923161c8c89c2e7784c5d164b 100644
--- a/paddle/fluid/framework/fleet/fleet_wrapper.h
+++ b/paddle/fluid/framework/fleet/fleet_wrapper.h
@@ -152,14 +152,14 @@ class FleetWrapper {
   // Push dense variables to server in async mode
   // Param<in>: scope, table_id, var_names, scale_datanorm, batch_size
   // Param<out>: push_sparse_status
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   void PushDenseVarsAsync(
       const Scope& scope, const uint64_t table_id,
       const std::vector<std::string>& var_names,
       std::vector<::std::future<int32_t>>* push_sparse_status,
       float scale_datanorm, int batch_size,
-      const paddle::platform::Place& place, cudaStream_t stream,
-      cudaEvent_t event);
+      const paddle::platform::Place& place, gpuStream_t stream,
+      gpuEvent_t event);
 #endif
 #ifdef PADDLE_WITH_XPU
   void PushDenseVarsAsync(
diff --git a/paddle/fluid/framework/fleet/heter_context.h b/paddle/fluid/framework/fleet/heter_context.h
index 2ea3c10fd87beceb3b1a9ea95effa2d4f46480bd..fc987b523d559a2559050602c4b8e98692804c1c 100644
--- a/paddle/fluid/framework/fleet/heter_context.h
+++ b/paddle/fluid/framework/fleet/heter_context.h
@@ -14,7 +14,8 @@ limitations under the License. */

 #pragma once

-#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB)
+#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
+    (defined PADDLE_WITH_PSLIB)

 #include
 #include
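PushDenseVarsAsync copies gradients into pinned memory asynchronously and then waits on an event recorded right after the copy, so the host blocks only until that transfer completes instead of draining the whole stream. A minimal sketch of the record-and-wait pattern under the same build guards (function and buffer names are illustrative):

    // Record an event after an async copy and block only until that copy is done.
    void copy_then_wait(float* pin_dst, const float* dev_src, size_t count,
                        gpuStream_t stream, gpuEvent_t event) {
    #ifdef PADDLE_WITH_HIP
      hipMemcpyAsync(pin_dst, dev_src, count * sizeof(float),
                     hipMemcpyDeviceToHost, stream);
      hipEventRecord(event, stream);
      hipEventSynchronize(event);  // waits for the copy, not the whole stream
    #else
      cudaMemcpyAsync(pin_dst, dev_src, count * sizeof(float),
                      cudaMemcpyDeviceToHost, stream);
      cudaEventRecord(event, stream);
      cudaEventSynchronize(event);
    #endif
    }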
diff --git a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt
index 2eed13c530d919e712af0abebbb7e870f1ac1d24..6df2cd52bb401d3cc378c2776073471070f1e411 100644
--- a/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt
+++ b/paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt
@@ -1,6 +1,10 @@
-nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc
-heter_resource.h hashtable.h DEPS cub device_context)
-nv_test(test_heter_comm SRCS test_heter_comm.cu feature_value.h DEPS
-heter_comm)
-
-nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
+IF(WITH_GPU)
+  nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
+  nv_test(test_heter_comm SRCS test_heter_comm.cu feature_value.h DEPS heter_comm)
+  nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
+ENDIF()
+IF(WITH_ROCM)
+  hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
+  hip_test(test_heter_comm SRCS test_heter_comm.cu feature_value.h DEPS heter_comm)
+  hip_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
+ENDIF()
diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable.h b/paddle/fluid/framework/fleet/heter_ps/hashtable.h
index 11bd6e7aa69c3b720609c4f1bd4e90f952ebe866..2aa00e84e1599bfa09b013dfb00bbda1299fe9e6 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable.h
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable.h
@@ -45,15 +45,15 @@ class HashTable {
   HashTable(const HashTable&) = delete;
   HashTable& operator=(const HashTable&) = delete;
   void insert(const KeyType* d_keys, const ValType* d_vals, size_t len,
-              cudaStream_t stream);
+              gpuStream_t stream);
   void get(const KeyType* d_keys, ValType* d_vals, size_t len,
-           cudaStream_t stream);
+           gpuStream_t stream);
   void show();
   void dump_to_cpu(int devid, cudaStream_t stream);

   template <typename GradType, typename Sgd>
   void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
-              Sgd sgd, cudaStream_t stream);
+              Sgd sgd, gpuStream_t stream);

  private:
   TableContainer<KeyType, ValType>* container_;
diff --git a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
index ef37ed64c2a5f785b9ec79c4b971df214581c5df..871f9c7857af46d8aad7cfbfafcdc80f0f52f259 100644
--- a/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
@@ -87,7 +87,7 @@ void HashTable<KeyType, ValType>::show() {

 template <typename KeyType, typename ValType>
 void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
-                                      size_t len, cudaStream_t stream) {
+                                      size_t len, gpuStream_t stream) {
   if (len == 0) {
     return;
   }
@@ -99,7 +99,7 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,

 template <typename KeyType, typename ValType>
 void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                          const ValType* d_vals, size_t len,
-                                         cudaStream_t stream) {
+                                         gpuStream_t stream) {
   if (len == 0) {
     return;
 }
@@ -147,7 +147,7 @@ template <typename KeyType, typename ValType>
 template <typename GradType, typename Sgd>
 void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                          const GradType* d_grads, size_t len,
-                                         Sgd sgd, cudaStream_t stream) {
+                                         Sgd sgd, gpuStream_t stream) {
   if (len == 0) {
     return;
   }
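After these signature changes a call site compiles unmodified for both backends, since only the stream typedef differs per build. A hedged usage sketch (the key and value types here are illustrative; Paddle instantiates the table with its own feature-value structs):

    // Illustrative call site: identical source for CUDA and ROCm builds,
    // because gpuStream_t resolves to the native stream type.
    void build_table(HashTable<uint64_t, float>* table, const uint64_t* d_keys,
                     const float* d_vals, size_t len, gpuStream_t stream) {
      table->insert(d_keys, d_vals, len, stream);  // async insert on `stream`
    }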
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
index f95d4d3948b1924497f62689b97b13e7917aaf2b..e42a3a324f1cda43deefcb87d9a39899691324f1 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -25,7 +25,7 @@ __global__ void fill_idx(T* idx, size_t len) {
 }

 template <typename T>
-void show_tensor(T* input, size_t len, cudaStream_t stream, std::string name) {
+void show_tensor(T* input, size_t len, gpuStream_t stream, std::string name) {
   T tmp[len];
   cudaMemcpyAsync(&tmp, input, sizeof(T) * len, cudaMemcpyDeviceToHost, stream);
   cudaStreamSynchronize(stream);
@@ -270,7 +270,7 @@ void HeterComm<KeyType, ValType, GradType>::build_ps(int num, KeyType* h_keys,
   std::vector<std::shared_ptr<memory::Allocation>> d_key_bufs;
   std::vector<std::shared_ptr<memory::Allocation>> d_val_bufs;

-  cudaStream_t streams[stream_num];
+  gpuStream_t streams[stream_num];
   for (int i = 0; i < stream_num; ++i) {
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&(streams[i])));
     auto d_k_buf = memory::AllocShared(place, chunk_size * sizeof(KeyType));
diff --git a/paddle/fluid/framework/fleet/heter_ps/heter_resource.h b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h
index 938164dd194119e86b188cff85a71c053594aef0..ad7649a8a33cb77e0581429d1da202b2d218b0dc 100644
--- a/paddle/fluid/framework/fleet/heter_ps/heter_resource.h
+++ b/paddle/fluid/framework/fleet/heter_ps/heter_resource.h
@@ -34,16 +34,16 @@ class GPUResource {
   int dev_id() const { return dev_id_; }
   int index() const { return index_; }

-  cudaStream_t local_stream(int num) { return local_streams_[num]; }
-  cudaStream_t remote_stream() { return remote_stream_; }
-  cudaStream_t comm_stream(int num) { return comm_streams_[num]; }
+  gpuStream_t local_stream(int num) { return local_streams_[num]; }
+  gpuStream_t remote_stream() { return remote_stream_; }
+  gpuStream_t comm_stream(int num) { return comm_streams_[num]; }

   int dev_id_;
   int index_;
   std::vector<int> dev_ids_;
-  cudaStream_t remote_stream_;
-  std::vector<cudaStream_t> local_streams_;
-  std::vector<cudaStream_t> comm_streams_;
+  gpuStream_t remote_stream_;
+  std::vector<gpuStream_t> local_streams_;
+  std::vector<gpuStream_t> comm_streams_;
 };

 class HeterPsResource {
@@ -56,9 +56,9 @@ class HeterPsResource {
   int total_gpu();
   int get_index_by_devid(int devid);
   int dev_id(int num);
-  cudaStream_t local_stream(int gpu_num, int stream_num);
-  cudaStream_t remote_stream(int gpu_num);
-  cudaStream_t comm_stream(int gpu_num, int stream_num);
+  gpuStream_t local_stream(int gpu_num, int stream_num);
+  gpuStream_t remote_stream(int gpu_num);
+  gpuStream_t comm_stream(int gpu_num, int stream_num);

   std::vector<std::shared_ptr<GPUResource>> resources_;
   std::vector<int> dev_ids_;
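Note that show_tensor and build_ps still call cudaMemcpyAsync, cudaStreamSynchronize, and cudaStreamCreate in untouched context lines, so a ROCm build of heter_ps relies on those call sites being handled separately (hipification or a follow-up change). A guarded form of the stream lifecycle calls, in the same style this patch uses elsewhere (the helper names are illustrative, not Paddle APIs):

    // Guarded stream creation/destruction, mirroring the patch's #ifdef style.
    inline void gpu_stream_create(gpuStream_t* s) {
    #ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(s));
    #else
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(s));
    #endif
    }

    inline void gpu_stream_destroy(gpuStream_t s) {
    #ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamDestroy(s));
    #else
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(s));
    #endif
    }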
diff --git a/paddle/fluid/framework/fleet/heter_wrapper.cc b/paddle/fluid/framework/fleet/heter_wrapper.cc
index 8e232560ab6876995a735b6901a5459265f9cb05..a0667e9adbb0000a979f268779e03224aa8eb93c 100644
--- a/paddle/fluid/framework/fleet/heter_wrapper.cc
+++ b/paddle/fluid/framework/fleet/heter_wrapper.cc
@@ -114,7 +114,7 @@ void HeterWrapper::SerializeToReq(const std::string& varname, Scope* scope,
       memcpy(data_ptr, tensor->data<void>(),
              tensor->numel() * SizeOfType(tensor->type()));
     } else {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       memory::Copy(platform::CPUPlace(), data_ptr,
                    BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
                    tensor->data<void>(),
@@ -129,11 +129,11 @@ void HeterWrapper::SerializeToReq(const std::string& varname, Scope* scope,
   }
 }

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 void HeterWrapper::DeSerializeToTensor(Scope* scope,
                                        const VariableMessage& req_var,
                                        platform::Place place,
-                                       cudaStream_t stream) {
+                                       gpuStream_t stream) {
   // const VariableMessage& req_var = request->vars();
   auto* var = scope->FindVar(req_var.varname());
   auto* tensor = var->GetMutable<LoDTensor>();
@@ -157,7 +157,7 @@ void HeterWrapper::DeSerializeToTensor(Scope* scope,
   void* tensor_data =
       tensor->mutable_data(place, ToVarType(req_var.data_type()));

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data,
                platform::CPUPlace(), req_var.data().data(),
                tensor->numel() * SizeOfType(tensor->type()), stream);
diff --git a/paddle/fluid/framework/fleet/heter_wrapper.h b/paddle/fluid/framework/fleet/heter_wrapper.h
index 55ad218198e67982c2624ff35907ec7237f01fc7..871d2e251b41016d548fa1e257560aca9db030d7 100644
--- a/paddle/fluid/framework/fleet/heter_wrapper.h
+++ b/paddle/fluid/framework/fleet/heter_wrapper.h
@@ -86,9 +86,9 @@ class HeterWrapper {

   framework::proto::VarType::Type ToVarType(VariableMessage::Type type);

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var,
-                           platform::Place place, cudaStream_t stream);
+                           platform::Place place, gpuStream_t stream);
 #endif
   void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var,
                            platform::Place place);
diff --git a/paddle/fluid/framework/fleet/nccl_wrapper.cc b/paddle/fluid/framework/fleet/nccl_wrapper.cc
index 8ba94f4fd7a79646ba69732371ed01456c6be41f..3ac95632de6bf63d6420054529df37238cd3c24b 100644
--- a/paddle/fluid/framework/fleet/nccl_wrapper.cc
+++ b/paddle/fluid/framework/fleet/nccl_wrapper.cc
@@ -21,7 +21,7 @@ std::shared_ptr<NCCLWrapper> NCCLWrapper::s_instance_ = NULL;
 bool NCCLWrapper::is_initialized_ = false;

 void NCCLWrapper::InitNCCL() {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclCommInitRank(
       &(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_,
       nccl_info_.my_global_rank_));
@@ -30,14 +30,14 @@ void NCCLWrapper::InitNCCL() {
 }

 void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   nccl_info_.nccl_id_ = nccl_info.nccl_id_;
 #endif
   return;
 }

 NCCLInfo NCCLWrapper::GetNCCLId() {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   PADDLE_ENFORCE_CUDA_SUCCESS(
       platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_)));
 #endif
@@ -46,19 +46,23 @@ NCCLInfo NCCLWrapper::GetNCCLId() {

 void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
                               const int ranks) {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   nccl_info_.local_rank_ = local_rank;
   nccl_info_.my_global_rank_ = global_rank;
   nccl_info_.global_ranks_ = ranks;
   platform::SetDeviceId(local_rank);
+#ifdef PADDLE_WITH_RCCL
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&(nccl_info_.stream_)));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&(nccl_info_.stream_)));
+#endif
 #endif
   return;
 }

 void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
                           const std::vector<std::string>& var_names) {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   for (auto& name : var_names) {
     auto var = scope.FindVar(name);
     LoDTensor* tensor = var->GetMutable<LoDTensor>();
@@ -66,7 +70,11 @@ void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
         reinterpret_cast<void*>(tensor->data<float>()), total_size, ncclFloat,
         root_rank, nccl_info_.comm_, nccl_info_.stream_));
+#ifdef PADDLE_WITH_RCCL
+    hipStreamSynchronize(nccl_info_.stream_);
+#else
     cudaStreamSynchronize(nccl_info_.stream_);
+#endif
   }
 #endif
   return;
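Only the guards, the stream type, and the stream synchronization change in nccl_wrapper.cc; the ncclCommInitRank, ncclGetUniqueId, and ncclBcast calls stay as they are because RCCL mirrors NCCL's API (rccl.h declares the same ncclComm_t, ncclUniqueId, and ncclXxx entry points). A minimal sketch of a broadcast shared verbatim by both backends (assumes an already-initialized communicator, as in SyncVar above):

    // Same broadcast call for NCCL and RCCL; only the stream sync is guarded.
    void bcast_floats(float* d_buf, size_t count, int root, ncclComm_t comm,
                      gpuStream_t stream) {
      ncclBcast(reinterpret_cast<void*>(d_buf), count, ncclFloat, root, comm,
                stream);
    #ifdef PADDLE_WITH_RCCL
      hipStreamSynchronize(stream);
    #else
      cudaStreamSynchronize(stream);
    #endif
    }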
diff --git a/paddle/fluid/framework/fleet/nccl_wrapper.h b/paddle/fluid/framework/fleet/nccl_wrapper.h
index 3725a225dbecfec0b6c6b934b259d895eb09c9cb..e12bfd8b27dd6634127d1b8716587565eeec0a49 100644
--- a/paddle/fluid/framework/fleet/nccl_wrapper.h
+++ b/paddle/fluid/framework/fleet/nccl_wrapper.h
@@ -25,9 +25,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/variable_helper.h"
-#if defined(PADDLE_WITH_NCCL)
+#ifdef PADDLE_WITH_NCCL
 #include "paddle/fluid/platform/dynload/nccl.h"
 #endif
+#ifdef PADDLE_WITH_RCCL
+#include "paddle/fluid/platform/dynload/rccl.h"
+#endif
 #include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN

 namespace paddle {
@@ -48,10 +51,10 @@ class NCCLInfo {
   int local_rank_;
   int global_ranks_;
   int my_global_rank_;
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
   ncclUniqueId nccl_id_;
   ncclComm_t comm_;
-  cudaStream_t stream_;
+  gpuStream_t stream_;
 #endif
 };
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
index 32eb9418b659b9b5b35d6de081dc4f2b6fc733f5..728188e7022821d0fc0244aa460814bf3d771141 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
@@ -26,7 +26,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB)
+#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
+    (defined PADDLE_WITH_PSLIB)

 #include
 #include
diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
index 98e0028e42758a24e2b301b4f33d072d19c9f9ed..8a536fe0b828db9879b93557a1fedc7add2054a7 100644
--- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h
@@ -14,7 +14,8 @@ limitations under the License. */

 #pragma once

-#if (defined PADDLE_WITH_NCCL) && (defined PADDLE_WITH_PSLIB)
+#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
+    (defined PADDLE_WITH_PSLIB)

 #include
 #include