From 572c466d19073fd278f91387cdb4825f7a787333 Mon Sep 17 00:00:00 2001 From: WangXi Date: Tue, 19 Jan 2021 19:21:55 +0800 Subject: [PATCH] [Prepare for MultiProcess xpu] unified gen nccl id, refine imperative reducer (#30455) --- paddle/fluid/imperative/all_reduce.cc | 16 ++ paddle/fluid/imperative/all_reduce.h | 15 -- paddle/fluid/imperative/nccl_context.cc | 230 ++++++------------ paddle/fluid/imperative/nccl_context.h | 77 ++---- paddle/fluid/imperative/parallel_context.h | 75 ++++++ paddle/fluid/imperative/reducer.cc | 214 ++++++++++------ paddle/fluid/imperative/reducer.h | 72 ++---- .../imperative/tests/nccl_context_test.cc | 2 + paddle/fluid/imperative/tests/test_group.cc | 103 ++++++++ .../fluid/operators/collective/CMakeLists.txt | 5 +- .../operators/collective/c_gen_nccl_id_op.cc | 33 ++- .../operators/collective/gen_nccl_id_op.cc | 51 +++- paddle/fluid/platform/CMakeLists.txt | 2 +- .../gen_comm_id_helper.cc} | 93 +++---- .../gen_comm_id_helper.h} | 32 ++- paddle/fluid/platform/nccl_helper.h | 2 +- .../tests/unittests/test_gen_nccl_id_op.py | 21 +- 17 files changed, 599 insertions(+), 444 deletions(-) create mode 100644 paddle/fluid/imperative/parallel_context.h rename paddle/fluid/{operators/collective/gen_nccl_id_op_helper.cc => platform/gen_comm_id_helper.cc} (79%) rename paddle/fluid/{operators/collective/gen_nccl_id_op_helper.h => platform/gen_comm_id_helper.h} (50%) diff --git a/paddle/fluid/imperative/all_reduce.cc b/paddle/fluid/imperative/all_reduce.cc index 57b620ff4b5..3321800aa19 100644 --- a/paddle/fluid/imperative/all_reduce.cc +++ b/paddle/fluid/imperative/all_reduce.cc @@ -16,8 +16,24 @@ #include "paddle/fluid/imperative/all_reduce.h" +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/imperative/nccl_context.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/nccl_helper.h" +#include "paddle/fluid/string/string_helper.h" + namespace paddle { namespace imperative { + static const platform::Place &GetVarPlace(const framework::Variable &src) { if (src.IsType()) { return src.Get().place(); diff --git a/paddle/fluid/imperative/all_reduce.h b/paddle/fluid/imperative/all_reduce.h index 7c6b77167b6..2185c19b696 100644 --- a/paddle/fluid/imperative/all_reduce.h +++ b/paddle/fluid/imperative/all_reduce.h @@ -16,21 +16,6 @@ #ifdef PADDLE_WITH_NCCL -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/imperative/nccl_context.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/nccl_helper.h" -#include "paddle/fluid/string/string_helper.h" - namespace paddle { namespace framework { class Variable; diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index 7c9718e78a4..04d2a148ea3 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -14,175 +14,54 @@ #include "paddle/fluid/imperative/nccl_context.h" -namespace paddle { -namespace imperative { -#if defined(PADDLE_WITH_NCCL) -void NCCLParallelContext::RecvNCCLID( - const std::string &ep, - std::vector &nccl_ids) { // NOLINT - int nrings = nccl_ids.size(); - 
auto addr = paddle::string::Split(ep, ':'); - PADDLE_ENFORCE_EQ( - addr.size(), 2UL, - platform::errors::InvalidArgument( - "The endpoint should contain host and port, but got %s.", ep)); - std::string host = addr[0]; - int port = std::stoi(addr[1]); - - int server_fd, new_socket; - struct sockaddr_in address; - int addrlen = sizeof(address); - char buffer[1024] = {0}; - int opt = 0; - // creating socket fd - if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) { - PADDLE_THROW( - platform::errors::Unavailable("Create server file descriptor failed.")); - } +#include +#include +#include - if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt))) { - PADDLE_THROW(platform::errors::Unavailable("Set socket options failed.")); - } - - address.sin_family = AF_INET; - address.sin_addr.s_addr = INADDR_ANY; - address.sin_port = htons(port); - - int try_times = 0; - int retry_time = 0; - while (true) { - if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) { - retry_time = 3 * (try_times + 1); - LOG(WARNING) << "Socket bind worker " << ep - << (try_times < 9 - ? " failed, try again after " + - std::to_string(retry_time) + " seconds." - : " failed, try again after " + - std::to_string(retry_time) + - " seconds. Bind on endpoint " + ep + - " failed. Please confirm whether the " - "communication port or GPU card is occupied."); - std::this_thread::sleep_for(std::chrono::seconds(retry_time)); - ++try_times; - continue; - } - break; - } - - VLOG(3) << "listening on: " << ep; - if (listen(server_fd, 3) < 0) { - PADDLE_THROW(platform::errors::Unavailable( - "Listen on server file descriptor failed.")); - } - - if ((new_socket = - accept(server_fd, reinterpret_cast(&address), - reinterpret_cast(&addrlen))) < 0) { - PADDLE_THROW(platform::errors::Unavailable( - "Accept the new socket file descriptor failed.")); - } - - if (read(new_socket, buffer, 1024) < 0) { - PADDLE_THROW(platform::errors::Unavailable("Read from socket failed.")); - } - - VLOG(3) << "recevived the ncclUniqueId"; - - memcpy(&nccl_ids[0], buffer, nrings * NCCL_UNIQUE_ID_BYTES); - - VLOG(3) << "closing the socket server: " << ep; - close(server_fd); -} - -void NCCLParallelContext::SendNCCLID( - const std::string &ep, const std::vector &nccl_ids) { - int nrings = nccl_ids.size(); - auto addr = paddle::string::Split(ep, ':'); - PADDLE_ENFORCE_EQ( - addr.size(), 2UL, - platform::errors::InvalidArgument( - "The endpoint should contain host and port, but got %s.", ep)); - std::string host = addr[0]; - int port = std::stoi(addr[1]); - int sock = 0; - struct sockaddr_in serv_addr; - char buffer[1024] = {0}; - - memcpy(buffer, &nccl_ids[0], nrings * NCCL_UNIQUE_ID_BYTES); - - if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - PADDLE_THROW(platform::errors::Unavailable("Create socket failed.")); - } - - memset(&serv_addr, '0', sizeof(serv_addr)); - serv_addr.sin_family = AF_INET; - serv_addr.sin_port = htons(port); +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/imperative/all_reduce.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/dynload/nccl.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" +#endif - char *ip = NULL; - struct hostent *hp; - if ((hp = gethostbyname(host.c_str())) == NULL) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Fail to get host by name %s.", host)); - } - int i = 0; - while (hp->h_addr_list[i] != NULL) { - ip = inet_ntoa(*(struct in_addr *)hp->h_addr_list[i]); - VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip; - 
break; - } - if (inet_pton(AF_INET, ip, &serv_addr.sin_addr) <= 0) { - PADDLE_THROW(platform::errors::Unavailable("Open address %s failed.", ep)); - } +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" +#include "paddle/fluid/string/string_helper.h" - int try_times = 0; - int retry_time = 0; - while (true) { - if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) { - retry_time = 3 * (try_times + 1); - LOG(WARNING) - << "Socket connect worker " << ep - << (try_times < 9 - ? " failed, try again after " + std::to_string(retry_time) + - " seconds." - : " failed, try again after " + std::to_string(retry_time) + - " seconds. Maybe that some process is occupied the " - "GPUs of this node now, and you should kill those " - "process manually."); - std::this_thread::sleep_for(std::chrono::seconds(retry_time)); - ++try_times; - continue; - } - VLOG(3) << "sending the ncclUniqueId to " << ep; - send(sock, buffer, NCCL_UNIQUE_ID_BYTES * nrings, 0); - break; - } - close(sock); -} +namespace paddle { +namespace imperative { +#if defined(PADDLE_WITH_NCCL) void NCCLParallelContext::BcastNCCLId( std::vector &nccl_ids, // NOLINT int root) { if (strategy_.local_rank_ == root) { - for (auto ep : strategy_.trainer_endpoints_) { - if (ep != strategy_.current_endpoint_) SendNCCLID(ep, nccl_ids); + std::vector other_trainers; + for (auto &ep : strategy_.trainer_endpoints_) { + if (ep != strategy_.current_endpoint_) { + other_trainers.push_back(ep); + } } + platform::SendBroadCastCommID(other_trainers, &nccl_ids); } else { - RecvNCCLID(strategy_.current_endpoint_, nccl_ids); + platform::RecvBroadCastCommID(strategy_.current_endpoint_, &nccl_ids); } } void NCCLParallelContext::Init() { std::vector nccl_ids; nccl_ids.resize(strategy_.nrings_); + if (strategy_.local_rank_ == 0) { // generate the unique ncclid on the root worker for (size_t i = 0; i < nccl_ids.size(); ++i) { platform::dynload::ncclGetUniqueId(&nccl_ids[i]); } - BcastNCCLId(nccl_ids, 0); - } else { - BcastNCCLId(nccl_ids, 0); } + BcastNCCLId(nccl_ids, 0); int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device; for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) { @@ -193,6 +72,12 @@ void NCCLParallelContext::Init() { platform::NCCLCommContext::Instance().CreateNCCLComm( &nccl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id); + + compute_events_.emplace_back( + platform::CudaEventResourcePool::Instance().New( + BOOST_GET_CONST(platform::CUDAPlace, place_).device)); + comm_events_.emplace_back(platform::CudaEventResourcePool::Instance().New( + BOOST_GET_CONST(platform::CUDAPlace, place_).device)); } } @@ -206,11 +91,54 @@ void NCCLParallelContext::AllReduceByStream(const framework::Variable &src, AllReduce(src, dst, strategy_, ring_id, use_calc_stream); } -paddle::platform::CUDADeviceContext *NCCLParallelContext::GetDeviceContext( +paddle::platform::DeviceContext *NCCLParallelContext::GetDeviceContext( int ring_id) { - return platform::NCCLCommContext::Instance() - .Get(ring_id, place_) - ->dev_context(); + return static_cast( + platform::NCCLCommContext::Instance() + .Get(ring_id, place_) + ->dev_context()); +} + +void NCCLParallelContext::WaitCompute(int ring_id) { + PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange( + "ring id must >= 0, but got %d", ring_id)); + PADDLE_ENFORCE_LT(ring_id, compute_events_.size(), + platform::errors::OutOfRange( + "ring id 
must < compute events size," + "but got ring id = %d, compute events size = %d", + ring_id, compute_events_.size())); + + auto compute_stream = static_cast( + platform::DeviceContextPool::Instance().Get(place_)) + ->stream(); + auto comm_stream = + platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream(); + auto event = compute_events_[ring_id].get(); + + // compute_stream-->event-->comm_stream + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, compute_stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(comm_stream, event, 0)); +} + +void NCCLParallelContext::WaitComm(int ring_id) { + PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange( + "ring id must >= 0, but got %d", ring_id)); + PADDLE_ENFORCE_LT(ring_id, comm_events_.size(), + platform::errors::OutOfRange( + "ring id must < comm events size," + "but got ring id = %d, comm events size = %d", + ring_id, comm_events_.size())); + + auto compute_stream = static_cast( + platform::DeviceContextPool::Instance().Get(place_)) + ->stream(); + auto comm_stream = + platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream(); + auto event = comm_events_[ring_id].get(); + + // comm_stream-->event-->compute_stream + PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, comm_stream)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(compute_stream, event, 0)); } #endif diff --git a/paddle/fluid/imperative/nccl_context.h b/paddle/fluid/imperative/nccl_context.h index b0e857a8df4..8dec0e216c5 100644 --- a/paddle/fluid/imperative/nccl_context.h +++ b/paddle/fluid/imperative/nccl_context.h @@ -13,73 +13,20 @@ // limitations under the License. #pragma once -// network header files -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) -#include -#include -#include -#include -#include -#endif - +#include #include -#include #include -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/platform/device_context.h" - #if defined(PADDLE_WITH_NCCL) -#include "paddle/fluid/imperative/all_reduce.h" +#include "paddle/fluid/platform/cuda_resource_pool.h" #include "paddle/fluid/platform/dynload/nccl.h" -#include "paddle/fluid/platform/nccl_helper.h" #endif -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/split.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/fluid/imperative/parallel_context.h" namespace paddle { namespace imperative { -struct ParallelStrategy { - int nranks_{1}; - int local_rank_{0}; - std::vector trainer_endpoints_{}; - std::string current_endpoint_{""}; - // TODO(shenliang03): support multi stream communication - int nrings_{1}; -}; - -class ParallelContext { - public: - explicit ParallelContext(const ParallelStrategy& strategy, - const platform::Place& place) - : strategy_(strategy), place_(place) {} - - virtual ~ParallelContext() {} - - virtual void Init() = 0; - - virtual void AllReduceByStream(const framework::Variable& src, - framework::Variable* dst, int ring_id = 0, - bool use_calc_stream = false) = 0; -#if defined(PADDLE_WITH_NCCL) - virtual paddle::platform::CUDADeviceContext* GetDeviceContext( - int ring_id) = 0; -#endif - - inline int GetNRings() { return strategy_.nrings_; } - - protected: - ParallelStrategy strategy_; - platform::Place place_; -}; - #if defined(PADDLE_WITH_NCCL) class NCCLParallelContext : public ParallelContext { public: @@ -87,7 
+34,7 @@ class NCCLParallelContext : public ParallelContext { const platform::Place& place) : ParallelContext(strategy, place) {} - ~NCCLParallelContext() {} + ~NCCLParallelContext() override = default; void BcastNCCLId(std::vector& nccl_ids, int root); // NOLINT @@ -97,14 +44,18 @@ class NCCLParallelContext : public ParallelContext { framework::Variable* dst, int ring_id, bool use_calc_stream) override; - paddle::platform::CUDADeviceContext* GetDeviceContext(int ring_id) override; + paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override; + + void WaitCompute(int ring_id) override; + + void WaitComm(int ring_id) override; - protected: - void RecvNCCLID(const std::string& endpoint, - std::vector& nccl_ids); // NOLINT + private: + // used for comm wait compute, compute_stream-->event-->comm_stream[ring_id] + std::vector> compute_events_; - void SendNCCLID(const std::string& endpoint, - const std::vector& nccl_ids); + // used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream + std::vector> comm_events_; }; #endif diff --git a/paddle/fluid/imperative/parallel_context.h b/paddle/fluid/imperative/parallel_context.h new file mode 100644 index 00000000000..55af297e493 --- /dev/null +++ b/paddle/fluid/imperative/parallel_context.h @@ -0,0 +1,75 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include + +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace platform { +class DeviceContext; +} // namespace platform + +namespace framework { +class Variable; +} // namespace framework + +} // namespace paddle + +namespace paddle { +namespace imperative { + +struct ParallelStrategy { + int nranks_{1}; + int local_rank_{0}; + std::vector trainer_endpoints_{}; + std::string current_endpoint_{""}; + int nrings_{1}; +}; + +class ParallelContext { + public: + explicit ParallelContext(const ParallelStrategy& strategy, + const platform::Place& place) + : strategy_(strategy), place_(place) {} + + virtual ~ParallelContext() = default; + + virtual void Init() = 0; + + virtual void AllReduceByStream(const framework::Variable& src, + framework::Variable* dst, int ring_id, + bool use_calc_stream) = 0; + + virtual paddle::platform::DeviceContext* GetDeviceContext(int ring_id) = 0; + + // comm_stream[ring_id] wait compute_stream. + // if CPU, should do nothing. + virtual void WaitCompute(int ring_id) = 0; + + // compute_stream wait comm_stream[ring_id] + // if CPU, should do nothing. 
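+ // (The CUDA implementation records an event on comm_stream[ring_id] and
+ // makes the compute stream wait on it; see NCCLParallelContext::WaitComm.)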
+ virtual void WaitComm(int ring_id) = 0; + + inline int GetNRings() const { return strategy_.nrings_; } + + protected: + ParallelStrategy strategy_; + platform::Place place_; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 10e8b398318..6801cac9526 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -14,60 +14,170 @@ #include "paddle/fluid/imperative/reducer.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/imperative/op_base.h" +#include "paddle/fluid/imperative/variable_wrapper.h" +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/string/string_helper.h" + +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/operators/math/concat_and_split.h" +#include "paddle/fluid/operators/strided_memcpy.h" +#endif + +#include "paddle/fluid/imperative/parallel_context.h" + namespace paddle { namespace imperative { #if defined(PADDLE_WITH_NCCL) std::shared_ptr Reducer::s_instance_ = NULL; -// context is used to select the stream for concat -void Group::ConcatTensors(const platform::CUDADeviceContext &context) { - VLOG(3) << "Before concat, set output tensor size is " << all_length_; - auto tensor = dense_contents_.GetMutable(); - tensor->Resize(framework::make_ddim({all_length_})) - .mutable_data(context.GetPlace(), dtype_); +template +static void ConcatTensorsForAllReduce( + const DeviceContext &context, + const std::vector &dense_tensors_, + framework::Variable *p_dense_contents) { + operators::math::ConcatFunctor concat_functor_; + concat_functor_(context, dense_tensors_, 0, + p_dense_contents->GetMutable()); +} + +template +static void SplitTensorsForAllReduce( + const DeviceContext &context, framework::Variable *p_dense_contents, + std::vector *p_dense_tensors) { + auto *in = p_dense_contents->GetMutable(); + std::vector outs; + std::vector shape_refer; - switch (dtype_) { + outs.reserve(p_dense_tensors->size()); + shape_refer.reserve(p_dense_tensors->size()); + + for (auto &tensor : *p_dense_tensors) { + outs.emplace_back(&tensor); + shape_refer.emplace_back(&tensor); + } + // Sometimes direct copies will be faster + if (p_dense_tensors->size() < 10) { + operators::StridedMemcpyWithAxis0(context, *in, shape_refer, &outs); + } else { + operators::math::SplitFunctor split_functor_; + split_functor_(context, *in, shape_refer, 0, &outs); + } +} + +// context is used to select the stream for concat +template +static void ConcatTensorsWithType( + const DeviceContext &context, + const std::vector &dense_tensors_, + framework::Variable *p_dense_contents, + framework::proto::VarType::Type type) { + switch (type) { case framework::proto::VarType::FP16: - ConcatTensorsForAllReduce(context, dense_tensors_, - &dense_contents_); + ConcatTensorsForAllReduce( + context, dense_tensors_, p_dense_contents); break; case framework::proto::VarType::FP32: - ConcatTensorsForAllReduce(context, dense_tensors_, - &dense_contents_); + ConcatTensorsForAllReduce(context, dense_tensors_, + p_dense_contents); break; case framework::proto::VarType::FP64: - ConcatTensorsForAllReduce(context, dense_tensors_, - &dense_contents_); + ConcatTensorsForAllReduce(context, dense_tensors_, + p_dense_contents); break; default: PADDLE_THROW(platform::errors::Unimplemented( "Data type (%s) is not supported when it concats tensors for " 
"allreduce.", - framework::DataTypeToString(dtype_))); + framework::DataTypeToString(type))); } } // context is used to select the stream for split -void Group::SplitTensors(const platform::CUDADeviceContext &context) { - switch (dtype_) { +template +static void SplitTensorsWithType( + const DeviceContext &context, framework::Variable *p_dense_contents, + std::vector *p_dense_tensors, + framework::proto::VarType::Type type) { + switch (type) { case framework::proto::VarType::FP16: - SplitTensorsForAllReduce(context, &dense_contents_, - &dense_tensors_); + SplitTensorsForAllReduce( + context, p_dense_contents, p_dense_tensors); break; case framework::proto::VarType::FP32: - SplitTensorsForAllReduce(context, &dense_contents_, - &dense_tensors_); + SplitTensorsForAllReduce(context, p_dense_contents, + p_dense_tensors); break; case framework::proto::VarType::FP64: - SplitTensorsForAllReduce(context, &dense_contents_, - &dense_tensors_); + SplitTensorsForAllReduce(context, p_dense_contents, + p_dense_tensors); break; default: PADDLE_THROW(platform::errors::Unimplemented( "Data type (%s) is not supported when it splits tensors for " "allreduce.", - framework::DataTypeToString(dtype_))); + framework::DataTypeToString(type))); + } +} + +void Group::ConcatTensors(const platform::DeviceContext &context) { + VLOG(3) << "Before concat, set output tensor size is " << all_length_; + auto tensor = dense_contents_.GetMutable(); + tensor->Resize(framework::make_ddim({all_length_})) + .mutable_data(context.GetPlace(), dtype_); + + auto place = context.GetPlace(); + if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_NCCL + ConcatTensorsWithType( + static_cast(context), + dense_tensors_, &dense_contents_, dtype_); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't concat grad tensors since it's not compiled with NCCL," + "Please recompile or reinstall Paddle with NCCL support.")); +#endif + } else if (platform::is_cpu_place(place)) { + ConcatTensorsWithType( + static_cast(context), + dense_tensors_, &dense_contents_, dtype_); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Concat grad tensor not supported on place (%s)", place)); + } +} + +void Group::SplitTensors(const platform::DeviceContext &context) { + auto place = context.GetPlace(); + if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_NCCL + SplitTensorsWithType( + static_cast(context), + &dense_contents_, &dense_tensors_, dtype_); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't split grad tensor since it's not compiled with NCCL," + "Please recompile or reinstall Paddle with NCCL support.")); +#endif + } else if (platform::is_cpu_place(place)) { + SplitTensorsWithType( + static_cast(context), + &dense_contents_, &dense_tensors_, dtype_); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Split grad tensor not supported on place (%s)", place)); } } @@ -115,44 +225,13 @@ Reducer::Reducer(const std::vector> &vars, }))); var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index; } - // create streams - compute_stream_ = static_cast( - platform::DeviceContextPool::Instance().Get(place_)) - ->stream(); - for (int i = 0; i < nrings_; ++i) { - comm_streams_.emplace_back( - platform::NCCLCommContext::Instance().Get(i, place_)->stream()); - comm_events_.emplace_back(platform::CudaEventResourcePool::Instance().New( - BOOST_GET_CONST(platform::CUDAPlace, place_).device)); - } - CreateGroupEvents(group_indices.size()); std::call_once(once_flag_, []() { 
std::atexit([]() { Reducer::GetInstance()->ReleaseReducer(); }); }); } -void Reducer::ReleaseReducer() { - for (auto &event : group_events_) { - event.reset(); - } - for (auto &event : comm_events_) { - event.reset(); - } -} - -void Reducer::CreateGroupEvents(int group_num) { - // release old events - for (auto &event : group_events_) { - event.reset(); - } - group_events_.clear(); - group_events_.resize(group_num); - for (auto &event : group_events_) { - event = platform::CudaEventResourcePool::Instance().New( - BOOST_GET_CONST(platform::CUDAPlace, place_).device); - } -} +void Reducer::ReleaseReducer() { parallel_ctx_.reset(); } void Reducer::InitializeDenseGroups( const std::vector &variable_indices_, Group *p_group) { @@ -455,18 +534,18 @@ void Reducer::MarkGroupReady(size_t group_index) { return; } - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaEventRecord(group_events_[group_index].get(), compute_stream_)); - - for (int i = 0; i < nrings_; ++i) { - PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent( - comm_streams_[i], group_events_[group_index].get(), 0)); - } - for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0; ++next_group_) { auto &group = groups_[next_group_]; int run_order = next_group_ % nrings_; + + // For CUDA or XPU, compute_stream --> comm_stream. + // For CPU, do nothing. + // NOTE. Because concat uses the comm_stream, + // so we expose WaitCompute() interface and call + // it here. + parallel_ctx_->WaitCompute(run_order); + if (group.is_sparse_) { if (group.sparse_contents_ != nullptr) { VLOG(3) << "sparse group [" << next_group_ @@ -526,20 +605,13 @@ void Reducer::FinalizeBackward() { all_group_ready_ = false; // Must prevent compute_stream_ starting until all comm streams have finished for (int i = 0; i < nrings_; ++i) { - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaEventRecord(comm_events_[i].get(), comm_streams_[i])); - } - for (int i = 0; i < nrings_; ++i) { - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaStreamWaitEvent(compute_stream_, comm_events_[i].get(), 0)); + parallel_ctx_->WaitComm(i); } if (NeedRebuildGroup()) { VLOG(3) << "Start rebuilding the groups"; auto rebuild_group_indices = RebuildGruops(); - auto rebuild_group_number = rebuild_group_indices.size(); group_indices_ = std::move(rebuild_group_indices); - CreateGroupEvents(rebuild_group_number); InitializeGroups(group_indices_); } diff --git a/paddle/fluid/imperative/reducer.h b/paddle/fluid/imperative/reducer.h index 62b61616026..9bb528bbdef 100644 --- a/paddle/fluid/imperative/reducer.h +++ b/paddle/fluid/imperative/reducer.h @@ -24,60 +24,27 @@ #include #include #include + #include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/imperative/layer.h" -#include "paddle/fluid/imperative/op_base.h" -#include "paddle/fluid/imperative/variable_wrapper.h" -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable.h" -#if defined(PADDLE_WITH_NCCL) -#include "paddle/fluid/imperative/all_reduce.h" -#include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/strided_memcpy.h" -#include "paddle/fluid/platform/cuda_resource_pool.h" -#endif +namespace paddle { +namespace platform { +class DeviceContext; +} // namespace platform + +namespace imperative { +class ParallelContext; +class VarBase; +class VariableWrapper; +} // namespace imperative +} // namespace paddle namespace paddle { namespace imperative { #if defined(PADDLE_WITH_NCCL) -template -void 
ConcatTensorsForAllReduce( - const platform::CUDADeviceContext& context, - const std::vector& dense_tensors_, - framework::Variable* p_dense_contents) { - operators::math::ConcatFunctor - concat_functor_; - concat_functor_(context, dense_tensors_, 0, - p_dense_contents->GetMutable()); -} - -template -void SplitTensorsForAllReduce(const platform::CUDADeviceContext& context, - framework::Variable* p_dense_contents, - std::vector* p_dense_tensors) { - auto* in = p_dense_contents->GetMutable(); - std::vector outs; - std::vector shape_refer; - - outs.reserve(p_dense_tensors->size()); - shape_refer.reserve(p_dense_tensors->size()); - - for (auto& tensor : *p_dense_tensors) { - outs.emplace_back(&tensor); - shape_refer.emplace_back(&tensor); - } - // Sometimes direct copies will be faster - if (p_dense_tensors->size() < 10) { - operators::StridedMemcpyWithAxis0(context, *in, shape_refer, &outs); - } else { - operators::math::SplitFunctor - split_functor_; - split_functor_(context, *in, shape_refer, 0, &outs); - } -} - class Group { public: // Here, we use dense_contents_ & sparse_contents_ to @@ -104,10 +71,10 @@ class Group { framework::proto::VarType::Type dtype_; // context is used to select the stream for concat - void ConcatTensors(const platform::CUDADeviceContext& context); + void ConcatTensors(const platform::DeviceContext& context); // context is used to select the stream for split - void SplitTensors(const platform::CUDADeviceContext& context); + void SplitTensors(const platform::DeviceContext& context); friend std::ostream& operator<<(std::ostream&, const Group&); }; @@ -155,8 +122,6 @@ class Reducer { std::vector> RebuildGruops(); - void CreateGroupEvents(int group_num); - inline bool NeedRebuildGroup() { return !has_rebuilt_group_; } // Reducer Singleton @@ -193,11 +158,6 @@ class Reducer { std::shared_ptr parallel_ctx_; std::vector variable_locators_; - // Following variables are to help sync stream - std::vector> group_events_; - std::vector> comm_events_; - cudaStream_t compute_stream_; - std::vector comm_streams_; int nrings_ = 1; // Following variables are to help rebuild group diff --git a/paddle/fluid/imperative/tests/nccl_context_test.cc b/paddle/fluid/imperative/tests/nccl_context_test.cc index 649746a5bd2..ab4d4add069 100644 --- a/paddle/fluid/imperative/tests/nccl_context_test.cc +++ b/paddle/fluid/imperative/tests/nccl_context_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include // NOLINT + #include "paddle/fluid/imperative/nccl_context.h" #include "gtest/gtest.h" diff --git a/paddle/fluid/imperative/tests/test_group.cc b/paddle/fluid/imperative/tests/test_group.cc index 243f78704e7..146ed9396b9 100644 --- a/paddle/fluid/imperative/tests/test_group.cc +++ b/paddle/fluid/imperative/tests/test_group.cc @@ -60,6 +60,109 @@ TEST(TestGroup, TestPrintGroupMessage) { ASSERT_STREQ(stream2.str().c_str(), head.c_str()); } +template +void GroupConcatSplit(Place place, size_t size) { + platform::CPUPlace cpu_place; + Group group; + + // [[0.0], [0.0, 1.0], [0.0, 1.0, 2.0] .. 
] + std::vector vars; + vars.resize(size); + for (size_t i = 0; i < size; ++i) { + auto len = i + 1; + auto* tensor = vars[i].GetMutable(); + tensor->Resize({static_cast(len)}); + auto* data = tensor->mutable_data(place); + + std::vector value; + for (size_t j = 0; j < len; ++j) { + value.push_back(static_cast(1.0 * j)); + } + + if (std::is_same::value) { + paddle::memory::Copy(place, data, cpu_place, value.data(), + sizeof(T) * value.size(), 0); + } else { + paddle::memory::Copy(place, data, cpu_place, value.data(), + sizeof(T) * value.size()); + } + + framework::Tensor tmp; + tmp.ShareDataWith(*tensor).Resize({static_cast(len)}); + group.dense_tensors_.push_back(std::move(tmp)); + group.all_length_ += len; + group.dtype_ = tensor->type(); + } + + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); + + { // concat + group.ConcatTensors(*dev_ctx); + + auto* tensor = group.dense_contents_.GetMutable(); + framework::Tensor tmp; + framework::TensorCopySync(*tensor, cpu_place, &tmp); + auto* data = tmp.data(); + size_t offset = 0; + for (size_t i = 0; i < size; ++i) { + auto len = i + 1; + for (size_t j = 0; j < len; ++j) { + EXPECT_EQ(data[offset + j], static_cast(1.0 * j)); + // [[-0.0], [-0.0, -1.0], [-0.0, -1.0, -2.0] .. ] + data[offset + j] = -data[offset + j]; + } + offset += len; + } + framework::TensorCopySync(tmp, place, tensor); + } + + { // split + group.SplitTensors(*dev_ctx); + for (size_t i = 0; i < size; ++i) { + auto len = i + 1; + auto& tensor = group.dense_tensors_[i]; + framework::Tensor tmp; + framework::TensorCopySync(tensor, cpu_place, &tmp); + auto* data = tmp.data(); + + for (size_t j = 0; j < len; ++j) { + EXPECT_EQ(data[j], static_cast(-1.0 * j)); + } + } + } +} + +TEST(TestGroup, TestConcatSplit) { + platform::CUDAPlace cuda_place(0); + platform::CPUPlace cpu_place; + + int size = 3; + GroupConcatSplit(cpu_place, size); + GroupConcatSplit(cpu_place, size); + GroupConcatSplit(cpu_place, size); + + GroupConcatSplit(cuda_place, size); + GroupConcatSplit(cuda_place, size); + GroupConcatSplit(cuda_place, size); + + size = 15; + GroupConcatSplit(cpu_place, size); + GroupConcatSplit(cpu_place, size); + GroupConcatSplit(cpu_place, size); + + GroupConcatSplit(cuda_place, size); + GroupConcatSplit(cuda_place, size); + GroupConcatSplit(cuda_place, size); +} + +TEST(TestGroup, TestConcatSplitException) { + platform::CUDAPinnedPlace place; + + int size = 3; + ASSERT_ANY_THROW(GroupConcatSplit(place, size)); +} #endif } // namespace imperative diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt index 09d4adee947..2b3c80839f2 100644 --- a/paddle/fluid/operators/collective/CMakeLists.txt +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -15,9 +15,8 @@ register_operators(EXCLUDES c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DE if(WITH_NCCL) set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper) - cc_library(gen_nccl_id_op_helper SRCS gen_nccl_id_op_helper.cc DEPS nccl_common) - op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} gen_nccl_id_op_helper) - op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} gen_nccl_id_op_helper) + op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) + op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) endif() if(WITH_GLOO) diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index 26f639ebc98..9e540112b84 100644 --- 
a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -23,11 +23,32 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" namespace paddle { namespace operators { +static void GenNCCLID(std::vector* nccl_ids) { + for (size_t i = 0; i < nccl_ids->size(); ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclGetUniqueId(&(*nccl_ids)[i])); + } +} + +static void CopyNCCLIDToVar(const std::vector& nccl_ids, + std::function func, + const framework::Scope& scope) { + for (size_t i = 0; i < nccl_ids.size(); ++i) { + std::string var_name = func(i); + auto var = scope.FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Variable with name %s is not found", + var_name.c_str())); + auto nccl_id = var->GetMutable(); + memcpy(nccl_id, &nccl_ids[i], sizeof(ncclUniqueId)); + } +} + class CGenNCCLIdOp : public framework::OperatorBase { public: CGenNCCLIdOp(const std::string& type, @@ -45,14 +66,20 @@ class CGenNCCLIdOp : public framework::OperatorBase { return Output("Out"); }; + std::vector nccl_ids; + nccl_ids.resize(1); + if (rank == 0) { + GenNCCLID(&nccl_ids); std::vector endpoint_list = Attr>("other_endpoints"); - SendBroadCastNCCLID(endpoint_list, 1, func, local_scope); + platform::SendBroadCastCommID(endpoint_list, &nccl_ids); } else { std::string endpoint = Attr("endpoint"); - RecvBroadCastNCCLID(endpoint, 1, func, local_scope); + platform::RecvBroadCastCommID(endpoint, &nccl_ids); } + + CopyNCCLIDToVar(nccl_ids, func, scope); scope.DeleteScope(&local_scope); } }; diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op.cc b/paddle/fluid/operators/collective/gen_nccl_id_op.cc index a985da5d5d0..85fd9452bff 100644 --- a/paddle/fluid/operators/collective/gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/gen_nccl_id_op.cc @@ -27,11 +27,32 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/split.h" -#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h" +#include "paddle/fluid/platform/gen_comm_id_helper.h" namespace paddle { namespace operators { +static void GenNCCLID(std::vector* nccl_ids) { + for (size_t i = 0; i < nccl_ids->size(); ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclGetUniqueId(&(*nccl_ids)[i])); + } +} + +static void CopyNCCLIDToVar(const std::vector& nccl_ids, + std::function func, + const framework::Scope& scope) { + for (size_t i = 0; i < nccl_ids.size(); ++i) { + std::string var_name = func(i); + auto var = scope.FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("Variable with name %s is not found", + var_name.c_str())); + auto nccl_id = var->GetMutable(); + memcpy(nccl_id, &nccl_ids[i], sizeof(ncclUniqueId)); + } +} + class GenNCCLIdOp : public framework::OperatorBase { public: GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs, @@ -98,19 +119,25 @@ class GenNCCLIdOp : public framework::OperatorBase { << ", trainers:" << ss.str(); int server_fd = -1; + std::vector nccl_ids; + nccl_ids.resize(nccl_comm_num); /// 1. 
init flat std::function func = platform::GetFlatNCCLVarName; + // broadcast unique id if (trainer_id == 0) { + GenNCCLID(&nccl_ids); + // server endpoints std::vector flat_endpoints; flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1, trainers.end()); - SendBroadCastNCCLID(flat_endpoints, nccl_comm_num, func, scope); + platform::SendBroadCastCommID(flat_endpoints, &nccl_ids); } else { - server_fd = CreateListenSocket(endpoint); - RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope); + server_fd = platform::CreateListenSocket(endpoint); + platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids); } + CopyNCCLIDToVar(nccl_ids, func, scope); /// 2. hierarchical inter ncclid func = platform::GetHierarchicalInterNCCLVarName; @@ -127,10 +154,13 @@ class GenNCCLIdOp : public framework::OperatorBase { } VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str(); - SendBroadCastNCCLID(inter_endpoints, nccl_comm_num, func, scope); + GenNCCLID(&nccl_ids); + platform::SendBroadCastCommID(inter_endpoints, &nccl_ids); + CopyNCCLIDToVar(nccl_ids, func, scope); } else if (inter_trainer_id > 0) { VLOG(1) << "Hierarchical inter ring"; - RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope); + platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids); + CopyNCCLIDToVar(nccl_ids, func, scope); } /// 3. hierarchical exter ncclid @@ -146,15 +176,18 @@ class GenNCCLIdOp : public framework::OperatorBase { } VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str(); - SendBroadCastNCCLID(exter_endpoints, nccl_comm_num, func, scope); + GenNCCLID(&nccl_ids); + platform::SendBroadCastCommID(exter_endpoints, &nccl_ids); + CopyNCCLIDToVar(nccl_ids, func, scope); } else if (exter_trainer_id > 0) { VLOG(1) << "Hierarchical exter ring"; - RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope); + platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids); + CopyNCCLIDToVar(nccl_ids, func, scope); } // close socket server if (trainer_id != 0) { - CloseSocket(server_fd); + platform::CloseSocket(server_fd); } } }; diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 6ae1f52ec03..f2a8309f00c 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -101,7 +101,7 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS}) -cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce) +cc_library(collective_helper SRCS collective_helper.cc gen_comm_id_helper.cc DEPS framework_proto device_context enforce) if(WITH_GPU) cc_library(cuda_resource_pool SRCS cuda_resource_pool.cc DEPS gpu_info) diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc similarity index 79% rename from paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc rename to paddle/fluid/platform/gen_comm_id_helper.cc index a0df244000b..08f0af5fc91 100644 --- a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc +++ b/paddle/fluid/platform/gen_comm_id_helper.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h" +#ifdef PADDLE_WITH_NCCL +#include "paddle/fluid/platform/gen_comm_id_helper.h" #include #include @@ -31,7 +32,7 @@ limitations under the License. */ #include "paddle/fluid/string/split.h" namespace paddle { -namespace operators { +namespace platform { constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; @@ -257,26 +258,29 @@ static int ConnectAddr(const std::string& ep, const char* head) { return sock; } -static void RecvNCCLID(int conn, ncclUniqueId* nccl_id) { +template +static void RecvCommID(int conn, CommUniqueId* nccl_id) { char buffer[1024] = {0}; - static_assert(NCCL_UNIQUE_ID_BYTES <= 1024, + static_assert(sizeof(CommUniqueId) <= 1024, "nccl id bytes must <= buffer size"); - CHECK_SYS_CALL(SocketRecv(conn, buffer, NCCL_UNIQUE_ID_BYTES), "recv ncc id"); - memcpy(nccl_id, buffer, NCCL_UNIQUE_ID_BYTES); + CHECK_SYS_CALL(SocketRecv(conn, buffer, sizeof(CommUniqueId)), + "recv comm unique id"); + memcpy(nccl_id, buffer, sizeof(CommUniqueId)); } -static void SendNCCLID(int conn, ncclUniqueId* nccl_id) { +template +static void SendCommID(int conn, CommUniqueId* nccl_id) { char buffer[1024] = {0}; - memcpy(buffer, nccl_id, NCCL_UNIQUE_ID_BYTES); + memcpy(buffer, nccl_id, sizeof(CommUniqueId)); - CHECK_SYS_CALL(SocketSend(conn, buffer, NCCL_UNIQUE_ID_BYTES), - "send nccl id"); + CHECK_SYS_CALL(SocketSend(conn, buffer, sizeof(CommUniqueId)), + "send comm unique id"); } -void SendBroadCastNCCLID(std::vector servers, int nccl_comm_num, - std::function func, - const framework::Scope& scope) { +template +void SendBroadCastCommID(std::vector servers, + std::vector* nccl_ids) { // connect with server std::vector connects; for (auto server : servers) { @@ -286,23 +290,13 @@ void SendBroadCastNCCLID(std::vector servers, int nccl_comm_num, } VLOG(3) << "connecting completed..."; - for (int i = 0; i < nccl_comm_num; ++i) { - std::string var_name = func(i); - auto var = scope.FindVar(var_name); - PADDLE_ENFORCE_NOT_NULL( - var, platform::errors::NotFound("Variable with name %s is not found", - var_name.c_str())); - auto nccl_id = var->GetMutable(); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(nccl_id)); - + for (size_t i = 0; i < nccl_ids->size(); ++i) { int j = 0; for (auto conn : connects) { - VLOG(3) << "sending nccl_id_var: " << var_name << " to " << servers[j] - << " nccl_comm_no: " << i; - SendNCCLID(conn, nccl_id); + VLOG(3) << "sending comm_id to " << servers[j] << " nccl_comm_no: " << i; + SendCommID(conn, &(*nccl_ids)[i]); ++j; } - VLOG(3) << "sending completed..."; } // close client @@ -311,34 +305,43 @@ void SendBroadCastNCCLID(std::vector servers, int nccl_comm_num, } } -void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, - std::function func, - const framework::Scope& scope) { +template +void RecvBroadCastCommID(std::string endpoint, + std::vector* nccl_ids) { int server = CreateListenSocket(endpoint); - RecvBroadCastNCCLID(server, endpoint, nccl_comm_num, func, scope); + RecvBroadCastCommID(server, endpoint, nccl_ids); CloseSocket(server); } -void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, - std::function func, - const framework::Scope& scope) { +template +void RecvBroadCastCommID(int server_fd, std::string endpoint, + std::vector* nccl_ids) { int client = SocketAccept(server_fd, COMM_HEAD); - for (int i = 0; i < nccl_comm_num; ++i) { - std::string var_name = func(i); - auto var = scope.FindVar(var_name); - PADDLE_ENFORCE_NOT_NULL( - var, 
platform::errors::NotFound("Variable with name %s is not found", - var_name.c_str())); - auto nccl_id = var->GetMutable(); - - VLOG(3) << "trainer: " << endpoint << " receiving nccl_id_var: " << var_name - << " from trainer 0, nccl_comm_no: " << i; - RecvNCCLID(client, nccl_id); + for (size_t i = 0; i < nccl_ids->size(); ++i) { + VLOG(3) << "trainer: " << endpoint + << " receiving comm_id from trainer 0, nccl_comm_no: " << i; + RecvCommID(client, &(*nccl_ids)[i]); } + VLOG(3) << "receiving completed..."; CloseSocket(client); } -} // namespace operators +/// template instantiation +#define INSTANT_TEMPLATE(Type) \ + template void SendBroadCastCommID(std::vector servers, \ + std::vector * nccl_ids); \ + template void RecvBroadCastCommID(std::string endpoint, \ + std::vector * nccl_ids); + +#ifdef PADDLE_WITH_NCCL +INSTANT_TEMPLATE(ncclUniqueId) +#endif +#ifdef PADDLE_WITH_XPU_BKCL +INSTANT_TEMPLATE(bkclUniqueId) +#endif +} // namespace platform } // namespace paddle + +#endif diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h b/paddle/fluid/platform/gen_comm_id_helper.h similarity index 50% rename from paddle/fluid/operators/collective/gen_nccl_id_op_helper.h rename to paddle/fluid/platform/gen_comm_id_helper.h index 38751805191..5384d704708 100644 --- a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h +++ b/paddle/fluid/platform/gen_comm_id_helper.h @@ -14,35 +14,31 @@ limitations under the License. */ #pragma once +#ifdef PADDLE_WITH_NCCL #include #include #include namespace paddle { -namespace framework { -class Scope; -} // namespace framework -} // namespace paddle - -namespace paddle { -namespace operators { +namespace platform { int CreateListenSocket(const std::string& ep); void CloseSocket(int fd); -void SendBroadCastNCCLID(std::vector servers, int nccl_comm_num, - std::function func, - const framework::Scope& scope); +template +void SendBroadCastCommID(std::vector servers, + std::vector* nccl_ids); -// server listen on endpoint, then recv nccl id -void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, - std::function func, - const framework::Scope& scope); +template +void RecvBroadCastCommID(std::string endpoint, + std::vector* nccl_ids); // recv nccl id from socket -void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, - std::function func, - const framework::Scope& scope); -} // namespace operators +template +void RecvBroadCastCommID(int server_fd, std::string endpoint, + std::vector* nccl_ids); +} // namespace platform } // namespace paddle + +#endif diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index e6c5f06c4c4..faa1a7c5ee8 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
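+// NOTE: #pragma once is hoisted above the PADDLE_WITH_NCCL guard below, so the
+// header stays include-guarded even in builds without NCCL.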
-#ifdef PADDLE_WITH_NCCL #pragma once +#ifdef PADDLE_WITH_NCCL #include #include #include diff --git a/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py b/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py index bd186e09006..17df3347dc4 100644 --- a/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_gen_nccl_id_op.py @@ -14,10 +14,11 @@ import unittest import os +import copy from launch_function_helper import wait, _find_free_port -from multiprocessing import Pool, Process +from threading import Thread -os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10") +os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10,gen_comm_id*=10") import paddle from paddle.fluid import core @@ -29,8 +30,8 @@ def run_gen_ncc_id(attr): nccl_comm_num = attr['nccl_comm_num'] use_hallreduce = attr['use_hierarchical_allreduce'] - startup_program = paddle.static.default_startup_program() - main_program = paddle.static.default_main_program() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() with paddle.static.program_guard(main_program, startup_program): nccl_id_var = startup_program.global_block().create_var( @@ -60,9 +61,10 @@ def run_gen_ncc_id(attr): attrs=attr) place = paddle.CPUPlace() - exe = paddle.static.Executor(place) - exe.run(startup_program) + scope = paddle.static.Scope() + with paddle.static.scope_guard(scope): + exe.run(startup_program) class TestGenNcclIdOp(unittest.TestCase): @@ -97,16 +99,19 @@ class TestGenNcclIdOp(unittest.TestCase): procs = [] for i in range(nranks): attr['trainer_id'] = i - p = Process(target=run_gen_ncc_id, args=(attr, )) + # NOTE. multiprocessing cannot be covered by coverage + p = Thread(target=run_gen_ncc_id, args=(copy.copy(attr), )) p.start() procs.append(p) - wait(procs, timeout=120) + for p in procs: + p.join() def test_flat(self): print(">>> test gen flat nccl id") self.gen_nccl_id(2) print("<<< end test gen flat nccl id") + print() def test_hierarchical(self): print(">>> test gen hierarchical nccl id") -- GitLab
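Reviewer note: a minimal usage sketch, not part of the patch, of how the unified broadcast helpers now declared in paddle/fluid/platform/gen_comm_id_helper.h fit together. The wrapper function name and endpoint layout are illustrative assumptions; only SendBroadCastCommID, RecvBroadCastCommID and platform::dynload::ncclGetUniqueId come from this patch, and an NCCL build (PADDLE_WITH_NCCL) is assumed.

    #include <string>
    #include <vector>

    #include "paddle/fluid/platform/dynload/nccl.h"
    #include "paddle/fluid/platform/gen_comm_id_helper.h"

    // Hypothetical wrapper: rank 0 generates one ncclUniqueId per ring and
    // pushes it to every other trainer; the other ranks listen on their own
    // endpoint until the ids arrive. Mirrors CGenNCCLIdOp::RunImpl and
    // NCCLParallelContext::Init in this patch.
    void BroadcastNCCLIds(int rank, int nrings,
                          const std::vector<std::string>& all_endpoints) {
      std::vector<ncclUniqueId> ids(nrings);
      if (rank == 0) {
        for (auto& id : ids) {
          // root generates the ids (production code wraps this call in
          // PADDLE_ENFORCE_CUDA_SUCCESS)
          paddle::platform::dynload::ncclGetUniqueId(&id);
        }
        // broadcast to every endpoint except our own (index 0 == rank 0)
        std::vector<std::string> others(all_endpoints.begin() + 1,
                                        all_endpoints.end());
        paddle::platform::SendBroadCastCommID(others, &ids);
      } else {
        // opens a listen socket on our endpoint, then receives nrings ids
        paddle::platform::RecvBroadCastCommID(all_endpoints[rank], &ids);
      }
    }

The explicit instantiations at the bottom of gen_comm_id_helper.cc cover ncclUniqueId (and bkclUniqueId for XPU builds), so these template calls link without the definitions being visible in the header.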