Unverified commit 572c466d authored by WangXi, committed by GitHub

[Prepare for MultiProcess xpu] unified gen nccl id, refine imperative reducer (#30455)

Parent 549855ac
...@@ -16,8 +16,24 @@
#include "paddle/fluid/imperative/all_reduce.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <string>
#include <utility>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace imperative {

static const platform::Place &GetVarPlace(const framework::Variable &src) {
  if (src.IsType<framework::LoDTensor>()) {
    return src.Get<framework::LoDTensor>().place();
......
...@@ -16,21 +16,6 @@
#ifdef PADDLE_WITH_NCCL
#include <cuda.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <string>
#include <utility>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
class Variable;
......
...@@ -14,175 +14,54 @@
#include "paddle/fluid/imperative/nccl_context.h"

#include <string>
#include <utility>
#include <vector>

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/imperative/all_reduce.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#endif

#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace imperative {
#if defined(PADDLE_WITH_NCCL)
void NCCLParallelContext::RecvNCCLID(
const std::string &ep,
std::vector<ncclUniqueId> &nccl_ids) { // NOLINT
int nrings = nccl_ids.size();
auto addr = paddle::string::Split(ep, ':');
PADDLE_ENFORCE_EQ(
addr.size(), 2UL,
platform::errors::InvalidArgument(
"The endpoint should contain host and port, but got %s.", ep));
std::string host = addr[0];
int port = std::stoi(addr[1]);
int server_fd, new_socket;
struct sockaddr_in address;
int addrlen = sizeof(address);
char buffer[1024] = {0};
int opt = 0;
// creating socket fd
if ((server_fd = socket(AF_INET, SOCK_STREAM, 0)) == 0) {
PADDLE_THROW(
platform::errors::Unavailable("Create server file descriptor failed."));
}
  if (setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt))) {
    PADDLE_THROW(platform::errors::Unavailable("Set socket options failed."));
  }

  address.sin_family = AF_INET;
  address.sin_addr.s_addr = INADDR_ANY;
address.sin_port = htons(port);
int try_times = 0;
int retry_time = 0;
while (true) {
if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0) {
retry_time = 3 * (try_times + 1);
LOG(WARNING) << "Socket bind worker " << ep
<< (try_times < 9
? " failed, try again after " +
std::to_string(retry_time) + " seconds."
: " failed, try again after " +
std::to_string(retry_time) +
" seconds. Bind on endpoint " + ep +
" failed. Please confirm whether the "
"communication port or GPU card is occupied.");
std::this_thread::sleep_for(std::chrono::seconds(retry_time));
++try_times;
continue;
}
break;
}
VLOG(3) << "listening on: " << ep;
if (listen(server_fd, 3) < 0) {
PADDLE_THROW(platform::errors::Unavailable(
"Listen on server file descriptor failed."));
}
if ((new_socket =
accept(server_fd, reinterpret_cast<struct sockaddr *>(&address),
reinterpret_cast<socklen_t *>(&addrlen))) < 0) {
PADDLE_THROW(platform::errors::Unavailable(
"Accept the new socket file descriptor failed."));
}
if (read(new_socket, buffer, 1024) < 0) {
PADDLE_THROW(platform::errors::Unavailable("Read from socket failed."));
}
VLOG(3) << "recevived the ncclUniqueId";
memcpy(&nccl_ids[0], buffer, nrings * NCCL_UNIQUE_ID_BYTES);
VLOG(3) << "closing the socket server: " << ep;
close(server_fd);
}
void NCCLParallelContext::SendNCCLID(
const std::string &ep, const std::vector<ncclUniqueId> &nccl_ids) {
int nrings = nccl_ids.size();
auto addr = paddle::string::Split(ep, ':');
PADDLE_ENFORCE_EQ(
addr.size(), 2UL,
platform::errors::InvalidArgument(
"The endpoint should contain host and port, but got %s.", ep));
std::string host = addr[0];
int port = std::stoi(addr[1]);
int sock = 0;
struct sockaddr_in serv_addr;
char buffer[1024] = {0};
memcpy(buffer, &nccl_ids[0], nrings * NCCL_UNIQUE_ID_BYTES);
if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
PADDLE_THROW(platform::errors::Unavailable("Create socket failed."));
}
memset(&serv_addr, '0', sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_port = htons(port);
  char *ip = NULL;
  struct hostent *hp;
  if ((hp = gethostbyname(host.c_str())) == NULL) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Fail to get host by name %s.", host));
  }
int i = 0;
while (hp->h_addr_list[i] != NULL) {
ip = inet_ntoa(*(struct in_addr *)hp->h_addr_list[i]);
VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip;
break;
}
if (inet_pton(AF_INET, ip, &serv_addr.sin_addr) <= 0) {
PADDLE_THROW(platform::errors::Unavailable("Open address %s failed.", ep));
}
  int try_times = 0;
  int retry_time = 0;
  while (true) {
if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0) {
retry_time = 3 * (try_times + 1);
LOG(WARNING)
<< "Socket connect worker " << ep
<< (try_times < 9
? " failed, try again after " + std::to_string(retry_time) +
" seconds."
: " failed, try again after " + std::to_string(retry_time) +
" seconds. Maybe that some process is occupied the "
"GPUs of this node now, and you should kill those "
"process manually.");
std::this_thread::sleep_for(std::chrono::seconds(retry_time));
++try_times;
continue;
}
VLOG(3) << "sending the ncclUniqueId to " << ep;
send(sock, buffer, NCCL_UNIQUE_ID_BYTES * nrings, 0);
break;
}
close(sock);
}
void NCCLParallelContext::BcastNCCLId(
    std::vector<ncclUniqueId> &nccl_ids,  // NOLINT
    int root) {
  if (strategy_.local_rank_ == root) {
    for (auto ep : strategy_.trainer_endpoints_) {
      if (ep != strategy_.current_endpoint_) SendNCCLID(ep, nccl_ids);
    }
  } else {
    RecvNCCLID(strategy_.current_endpoint_, nccl_ids);
  }
}

void NCCLParallelContext::BcastNCCLId(
    std::vector<ncclUniqueId> &nccl_ids,  // NOLINT
    int root) {
  if (strategy_.local_rank_ == root) {
    std::vector<std::string> other_trainers;
    for (auto &ep : strategy_.trainer_endpoints_) {
      if (ep != strategy_.current_endpoint_) {
        other_trainers.push_back(ep);
      }
    }
    platform::SendBroadCastCommID(other_trainers, &nccl_ids);
  } else {
    platform::RecvBroadCastCommID(strategy_.current_endpoint_, &nccl_ids);
  }
}
void NCCLParallelContext::Init() {
  std::vector<ncclUniqueId> nccl_ids;
  nccl_ids.resize(strategy_.nrings_);

  if (strategy_.local_rank_ == 0) {
    // generate the unique ncclid on the root worker
    for (size_t i = 0; i < nccl_ids.size(); ++i) {
      platform::dynload::ncclGetUniqueId(&nccl_ids[i]);
    }
  }
  BcastNCCLId(nccl_ids, 0);

  int gpu_id = BOOST_GET_CONST(platform::CUDAPlace, place_).device;
  for (int ring_id = 0; ring_id < strategy_.nrings_; ring_id++) {
...@@ -193,6 +72,12 @@ void NCCLParallelContext::Init() {
    platform::NCCLCommContext::Instance().CreateNCCLComm(
        &nccl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, gpu_id,
        ring_id);
compute_events_.emplace_back(
platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device));
comm_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device));
  }
}
...@@ -206,11 +91,54 @@ void NCCLParallelContext::AllReduceByStream(const framework::Variable &src,
  AllReduce(src, dst, strategy_, ring_id, use_calc_stream);
}
paddle::platform::CUDADeviceContext *NCCLParallelContext::GetDeviceContext(
    int ring_id) {
  return platform::NCCLCommContext::Instance()
      .Get(ring_id, place_)
      ->dev_context();
}

paddle::platform::DeviceContext *NCCLParallelContext::GetDeviceContext(
    int ring_id) {
  return static_cast<platform::DeviceContext *>(
      platform::NCCLCommContext::Instance()
          .Get(ring_id, place_)
          ->dev_context());
}
void NCCLParallelContext::WaitCompute(int ring_id) {
PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange(
"ring id must >= 0, but got %d", ring_id));
PADDLE_ENFORCE_LT(ring_id, compute_events_.size(),
platform::errors::OutOfRange(
"ring id must < compute events size,"
"but got ring id = %d, compute events size = %d",
ring_id, compute_events_.size()));
auto compute_stream = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream();
auto event = compute_events_[ring_id].get();
// compute_stream-->event-->comm_stream
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, compute_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(comm_stream, event, 0));
}
void NCCLParallelContext::WaitComm(int ring_id) {
PADDLE_ENFORCE_GE(ring_id, 0, platform::errors::OutOfRange(
"ring id must >= 0, but got %d", ring_id));
PADDLE_ENFORCE_LT(ring_id, comm_events_.size(),
platform::errors::OutOfRange(
"ring id must < comm events size,"
"but got ring id = %d, comm events size = %d",
ring_id, comm_events_.size()));
auto compute_stream = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
auto comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place_)->stream();
auto event = comm_events_[ring_id].get();
// comm_stream-->event-->compute_stream
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, comm_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(compute_stream, event, 0));
} }
#endif
......
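The new WaitCompute()/WaitComm() methods above order the compute and communication streams with CUDA events instead of blocking the host. As a standalone sketch (the function name is illustrative, not from the diff), the underlying pattern is simply:

```cpp
#include <cuda_runtime.h>

// Record an event on the producing stream and make the consuming stream wait
// on it, so no host-side synchronization is required.
void StreamWaitSketch(cudaStream_t producer, cudaStream_t consumer,
                      cudaEvent_t event) {
  cudaEventRecord(event, producer);         // producer --> event
  cudaStreamWaitEvent(consumer, event, 0);  // event --> consumer
}
```

WaitCompute() applies this with the compute stream as producer and the ring's comm stream as consumer; WaitComm() reverses the roles.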
...@@ -13,73 +13,20 @@
// limitations under the License.

#pragma once

#include <memory>

// network header files
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdlib.h>
#include <sys/socket.h>
#endif
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/imperative/all_reduce.h"
#include "paddle/fluid/platform/cuda_resource_pool.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace imperative {
struct ParallelStrategy {
int nranks_{1};
int local_rank_{0};
std::vector<std::string> trainer_endpoints_{};
std::string current_endpoint_{""};
// TODO(shenliang03): support multi stream communication
int nrings_{1};
};
class ParallelContext {
public:
explicit ParallelContext(const ParallelStrategy& strategy,
const platform::Place& place)
: strategy_(strategy), place_(place) {}
virtual ~ParallelContext() {}
virtual void Init() = 0;
virtual void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id = 0,
bool use_calc_stream = false) = 0;
#if defined(PADDLE_WITH_NCCL)
virtual paddle::platform::CUDADeviceContext* GetDeviceContext(
int ring_id) = 0;
#endif
inline int GetNRings() { return strategy_.nrings_; }
protected:
ParallelStrategy strategy_;
platform::Place place_;
};
#if defined(PADDLE_WITH_NCCL)
class NCCLParallelContext : public ParallelContext {
 public:
...@@ -87,7 +34,7 @@ class NCCLParallelContext : public ParallelContext {
                      const platform::Place& place)
      : ParallelContext(strategy, place) {}

  ~NCCLParallelContext() {}
  ~NCCLParallelContext() override = default;

  void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, int root);  // NOLINT

...@@ -97,14 +44,18 @@ class NCCLParallelContext : public ParallelContext {
                         framework::Variable* dst, int ring_id,
                         bool use_calc_stream) override;

  paddle::platform::CUDADeviceContext* GetDeviceContext(int ring_id) override;
  paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override;
void WaitCompute(int ring_id) override;
void WaitComm(int ring_id) override;
 protected:
  void RecvNCCLID(const std::string& endpoint,
                  std::vector<ncclUniqueId>& nccl_ids);  // NOLINT

  void SendNCCLID(const std::string& endpoint,
                  const std::vector<ncclUniqueId>& nccl_ids);

 private:
  // used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
  std::vector<std::shared_ptr<platform::CudaEventObject>> compute_events_;

  // used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
  std::vector<std::shared_ptr<platform::CudaEventObject>> comm_events_;
};
#endif
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // namespace platform
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace imperative {
struct ParallelStrategy {
int nranks_{1};
int local_rank_{0};
std::vector<std::string> trainer_endpoints_{};
std::string current_endpoint_{""};
int nrings_{1};
};
class ParallelContext {
public:
explicit ParallelContext(const ParallelStrategy& strategy,
const platform::Place& place)
: strategy_(strategy), place_(place) {}
virtual ~ParallelContext() = default;
virtual void Init() = 0;
virtual void AllReduceByStream(const framework::Variable& src,
framework::Variable* dst, int ring_id,
bool use_calc_stream) = 0;
virtual paddle::platform::DeviceContext* GetDeviceContext(int ring_id) = 0;
// comm_stream[ring_id] wait compute_stream.
// if CPU, should do nothing.
virtual void WaitCompute(int ring_id) = 0;
// compute_stream wait comm_stream[ring_id]
// if CPU, should do nothing.
virtual void WaitComm(int ring_id) = 0;
inline int GetNRings() const { return strategy_.nrings_; }
protected:
ParallelStrategy strategy_;
platform::Place place_;
};
} // namespace imperative
} // namespace paddle
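The comments above say the two Wait hooks should do nothing when there is no separate communication stream ("if CPU, should do nothing"). The following is a hypothetical sketch, not part of this PR, of what such a subclass could look like; the class name and the trivial bodies are illustrative only.

```cpp
#include "paddle/fluid/imperative/parallel_context.h"
#include "paddle/fluid/platform/device_context.h"

// Hypothetical CPU-style context: no separate comm stream, so the Wait hooks
// are no-ops and the device context comes straight from the global pool.
class CPULikeParallelContextSketch : public paddle::imperative::ParallelContext {
 public:
  using paddle::imperative::ParallelContext::ParallelContext;

  void Init() override {}

  void AllReduceByStream(const paddle::framework::Variable& src,
                         paddle::framework::Variable* dst, int ring_id,
                         bool use_calc_stream) override {
    // a real implementation would reduce gradients across trainers here
  }

  paddle::platform::DeviceContext* GetDeviceContext(int ring_id) override {
    return paddle::platform::DeviceContextPool::Instance().Get(place_);
  }

  // no separate communication stream, so there is nothing to wait for
  void WaitCompute(int ring_id) override {}
  void WaitComm(int ring_id) override {}
};
```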
...@@ -14,60 +14,170 @@
#include "paddle/fluid/imperative/reducer.h"
#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/string/string_helper.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#endif
#include "paddle/fluid/imperative/parallel_context.h"
namespace paddle {
namespace imperative {

#if defined(PADDLE_WITH_NCCL)
std::shared_ptr<Reducer> Reducer::s_instance_ = NULL;
// context is used to select the stream for concat
void Group::ConcatTensors(const platform::CUDADeviceContext &context) {
  VLOG(3) << "Before concat, set output tensor size is " << all_length_;
  auto tensor = dense_contents_.GetMutable<framework::LoDTensor>();
  tensor->Resize(framework::make_ddim({all_length_}))
      .mutable_data(context.GetPlace(), dtype_);

  switch (dtype_) {
    case framework::proto::VarType::FP16:
      ConcatTensorsForAllReduce<platform::float16>(context, dense_tensors_,
                                                   &dense_contents_);
      break;
    case framework::proto::VarType::FP32:
      ConcatTensorsForAllReduce<float>(context, dense_tensors_,
                                       &dense_contents_);
      break;
    case framework::proto::VarType::FP64:
      ConcatTensorsForAllReduce<double>(context, dense_tensors_,
                                        &dense_contents_);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when it concats tensors for "
          "allreduce.",
          framework::DataTypeToString(dtype_)));
  }
}

// context is used to select the stream for split
void Group::SplitTensors(const platform::CUDADeviceContext &context) {
  switch (dtype_) {
    case framework::proto::VarType::FP16:
      SplitTensorsForAllReduce<platform::float16>(context, &dense_contents_,
                                                  &dense_tensors_);
      break;
    case framework::proto::VarType::FP32:
      SplitTensorsForAllReduce<float>(context, &dense_contents_,
                                      &dense_tensors_);
      break;
    case framework::proto::VarType::FP64:
      SplitTensorsForAllReduce<double>(context, &dense_contents_,
                                       &dense_tensors_);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when it splits tensors for "
          "allreduce.",
          framework::DataTypeToString(dtype_)));
  }
}

template <typename DeviceContext, typename T>
static void ConcatTensorsForAllReduce(
    const DeviceContext &context,
    const std::vector<framework::Tensor> &dense_tensors_,
    framework::Variable *p_dense_contents) {
  operators::math::ConcatFunctor<DeviceContext, T> concat_functor_;
  concat_functor_(context, dense_tensors_, 0,
                  p_dense_contents->GetMutable<framework::LoDTensor>());
}

template <typename DeviceContext, typename T>
static void SplitTensorsForAllReduce(
    const DeviceContext &context, framework::Variable *p_dense_contents,
    std::vector<framework::Tensor> *p_dense_tensors) {
  auto *in = p_dense_contents->GetMutable<framework::LoDTensor>();
  std::vector<framework::Tensor *> outs;
  std::vector<const framework::Tensor *> shape_refer;

  outs.reserve(p_dense_tensors->size());
  shape_refer.reserve(p_dense_tensors->size());

  for (auto &tensor : *p_dense_tensors) {
    outs.emplace_back(&tensor);
    shape_refer.emplace_back(&tensor);
  }

  // Sometimes direct copies will be faster
  if (p_dense_tensors->size() < 10) {
    operators::StridedMemcpyWithAxis0<T>(context, *in, shape_refer, &outs);
  } else {
    operators::math::SplitFunctor<DeviceContext, T> split_functor_;
    split_functor_(context, *in, shape_refer, 0, &outs);
  }
}

// context is used to select the stream for concat
template <typename DeviceContext>
static void ConcatTensorsWithType(
    const DeviceContext &context,
    const std::vector<framework::Tensor> &dense_tensors_,
    framework::Variable *p_dense_contents,
    framework::proto::VarType::Type type) {
  switch (type) {
    case framework::proto::VarType::FP16:
      ConcatTensorsForAllReduce<DeviceContext, platform::float16>(
          context, dense_tensors_, p_dense_contents);
      break;
    case framework::proto::VarType::FP32:
      ConcatTensorsForAllReduce<DeviceContext, float>(context, dense_tensors_,
                                                      p_dense_contents);
      break;
    case framework::proto::VarType::FP64:
      ConcatTensorsForAllReduce<DeviceContext, double>(context, dense_tensors_,
                                                       p_dense_contents);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when it concats tensors for "
          "allreduce.",
          framework::DataTypeToString(type)));
  }
}

// context is used to select the stream for split
template <typename DeviceContext>
static void SplitTensorsWithType(
    const DeviceContext &context, framework::Variable *p_dense_contents,
    std::vector<framework::Tensor> *p_dense_tensors,
    framework::proto::VarType::Type type) {
  switch (type) {
    case framework::proto::VarType::FP16:
      SplitTensorsForAllReduce<DeviceContext, platform::float16>(
          context, p_dense_contents, p_dense_tensors);
      break;
    case framework::proto::VarType::FP32:
      SplitTensorsForAllReduce<DeviceContext, float>(context, p_dense_contents,
                                                     p_dense_tensors);
      break;
    case framework::proto::VarType::FP64:
      SplitTensorsForAllReduce<DeviceContext, double>(context, p_dense_contents,
                                                      p_dense_tensors);
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Data type (%s) is not supported when it splits tensors for "
          "allreduce.",
          framework::DataTypeToString(type)));
  }
}

void Group::ConcatTensors(const platform::DeviceContext &context) {
VLOG(3) << "Before concat, set output tensor size is " << all_length_;
auto tensor = dense_contents_.GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim({all_length_}))
.mutable_data(context.GetPlace(), dtype_);
auto place = context.GetPlace();
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_NCCL
ConcatTensorsWithType(
static_cast<const platform::CUDADeviceContext &>(context),
dense_tensors_, &dense_contents_, dtype_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't concat grad tensors since it's not compiled with NCCL,"
"Please recompile or reinstall Paddle with NCCL support."));
#endif
} else if (platform::is_cpu_place(place)) {
ConcatTensorsWithType(
static_cast<const platform::CPUDeviceContext &>(context),
dense_tensors_, &dense_contents_, dtype_);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Concat grad tensor not supported on place (%s)", place));
}
}
void Group::SplitTensors(const platform::DeviceContext &context) {
auto place = context.GetPlace();
if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_NCCL
SplitTensorsWithType(
static_cast<const platform::CUDADeviceContext &>(context),
&dense_contents_, &dense_tensors_, dtype_);
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split grad tensor since it's not compiled with NCCL,"
"Please recompile or reinstall Paddle with NCCL support."));
#endif
} else if (platform::is_cpu_place(place)) {
SplitTensorsWithType(
static_cast<const platform::CPUDeviceContext &>(context),
&dense_contents_, &dense_tensors_, dtype_);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Split grad tensor not supported on place (%s)", place));
  }
}
...@@ -115,44 +225,13 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
        })));
    var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index;
  }
// create streams
compute_stream_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_))
->stream();
for (int i = 0; i < nrings_; ++i) {
comm_streams_.emplace_back(
platform::NCCLCommContext::Instance().Get(i, place_)->stream());
comm_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device));
}
CreateGroupEvents(group_indices.size());
  std::call_once(once_flag_, []() {
    std::atexit([]() { Reducer::GetInstance()->ReleaseReducer(); });
  });
}

void Reducer::ReleaseReducer() { parallel_ctx_.reset(); }

void Reducer::ReleaseReducer() {
for (auto &event : group_events_) {
event.reset();
}
for (auto &event : comm_events_) {
event.reset();
}
}
void Reducer::CreateGroupEvents(int group_num) {
// release old events
for (auto &event : group_events_) {
event.reset();
}
group_events_.clear();
group_events_.resize(group_num);
for (auto &event : group_events_) {
event = platform::CudaEventResourcePool::Instance().New(
BOOST_GET_CONST(platform::CUDAPlace, place_).device);
}
}
void Reducer::InitializeDenseGroups(
    const std::vector<size_t> &variable_indices_, Group *p_group) {
...@@ -455,18 +534,18 @@ void Reducer::MarkGroupReady(size_t group_index) {
    return;
  }
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventRecord(group_events_[group_index].get(), compute_stream_));
for (int i = 0; i < nrings_; ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(
comm_streams_[i], group_events_[group_index].get(), 0));
}
  for (; next_group_ < groups_.size() && groups_[next_group_].pending_ == 0;
       ++next_group_) {
    auto &group = groups_[next_group_];
    int run_order = next_group_ % nrings_;
// For CUDA or XPU, compute_stream --> comm_stream.
// For CPU, do nothing.
// NOTE. Because concat uses the comm_stream,
// so we expose WaitCompute() interface and call
// it here.
parallel_ctx_->WaitCompute(run_order);
    if (group.is_sparse_) {
      if (group.sparse_contents_ != nullptr) {
        VLOG(3) << "sparse group [" << next_group_
...@@ -526,20 +605,13 @@ void Reducer::FinalizeBackward() {
  all_group_ready_ = false;
  // Must prevent compute_stream_ starting until all comm streams have finished
  for (int i = 0; i < nrings_; ++i) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaEventRecord(comm_events_[i].get(), comm_streams_[i]));
  }
  for (int i = 0; i < nrings_; ++i) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaStreamWaitEvent(compute_stream_, comm_events_[i].get(), 0));
  }

  for (int i = 0; i < nrings_; ++i) {
    parallel_ctx_->WaitComm(i);
  }
  if (NeedRebuildGroup()) {
    VLOG(3) << "Start rebuilding the groups";
    auto rebuild_group_indices = RebuildGruops();
    auto rebuild_group_number = rebuild_group_indices.size();
    group_indices_ = std::move(rebuild_group_indices);
    CreateGroupEvents(rebuild_group_number);
    InitializeGroups(group_indices_);
  }
......
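With the CUDA events and streams moved behind the ParallelContext interface, the reducer's synchronization shrinks to the two calls visible above. A condensed sketch of the resulting ordering (the free function and its parameters are illustrative, not part of the diff):

```cpp
#include <cstddef>

#include "paddle/fluid/imperative/parallel_context.h"

// Ready groups are flushed round-robin over the rings; the device-specific
// waiting is entirely delegated to the ParallelContext.
void AllReduceReadyGroupsSketch(paddle::imperative::ParallelContext* ctx,
                                int nrings, std::size_t ready_groups) {
  for (std::size_t g = 0; g < ready_groups; ++g) {
    int ring = static_cast<int>(g) % nrings;
    // comm stream of `ring` waits until the gradients have been produced
    ctx->WaitCompute(ring);
    // ... concat the group's grads and launch allreduce on `ring` ...
  }
  for (int i = 0; i < nrings; ++i) {
    // compute stream waits until every ring's allreduce has finished
    ctx->WaitComm(i);
  }
}
```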
...@@ -24,60 +24,27 @@
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/string/string_helper.h"

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/imperative/all_reduce.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/cuda_resource_pool.h"
#endif

namespace paddle {
namespace platform {
class DeviceContext;
}  // namespace platform

namespace imperative {
class ParallelContext;
class VarBase;
class VariableWrapper;
}  // namespace imperative
}  // namespace paddle

namespace paddle {
namespace imperative {

#if defined(PADDLE_WITH_NCCL)
template <typename T>
void ConcatTensorsForAllReduce(
const platform::CUDADeviceContext& context,
const std::vector<framework::Tensor>& dense_tensors_,
framework::Variable* p_dense_contents) {
operators::math::ConcatFunctor<platform::CUDADeviceContext, T>
concat_functor_;
concat_functor_(context, dense_tensors_, 0,
p_dense_contents->GetMutable<framework::LoDTensor>());
}
template <typename T>
void SplitTensorsForAllReduce(const platform::CUDADeviceContext& context,
framework::Variable* p_dense_contents,
std::vector<framework::Tensor>* p_dense_tensors) {
auto* in = p_dense_contents->GetMutable<framework::LoDTensor>();
std::vector<framework::Tensor*> outs;
std::vector<const framework::Tensor*> shape_refer;
outs.reserve(p_dense_tensors->size());
shape_refer.reserve(p_dense_tensors->size());
for (auto& tensor : *p_dense_tensors) {
outs.emplace_back(&tensor);
shape_refer.emplace_back(&tensor);
}
// Sometimes direct copies will be faster
if (p_dense_tensors->size() < 10) {
operators::StridedMemcpyWithAxis0<T>(context, *in, shape_refer, &outs);
} else {
operators::math::SplitFunctor<platform::CUDADeviceContext, T>
split_functor_;
split_functor_(context, *in, shape_refer, 0, &outs);
}
}
class Group {
 public:
  // Here, we use dense_contents_ & sparse_contents_ to
...@@ -104,10 +71,10 @@ class Group {
  framework::proto::VarType::Type dtype_;

  // context is used to select the stream for concat
  void ConcatTensors(const platform::CUDADeviceContext& context);
  void ConcatTensors(const platform::DeviceContext& context);

  // context is used to select the stream for split
  void SplitTensors(const platform::CUDADeviceContext& context);
  void SplitTensors(const platform::DeviceContext& context);

  friend std::ostream& operator<<(std::ostream&, const Group&);
};
...@@ -155,8 +122,6 @@ class Reducer {
  std::vector<std::vector<size_t>> RebuildGruops();
void CreateGroupEvents(int group_num);
  inline bool NeedRebuildGroup() { return !has_rebuilt_group_; }

  // Reducer Singleton
...@@ -193,11 +158,6 @@ class Reducer {
  std::shared_ptr<imperative::ParallelContext> parallel_ctx_;
  std::vector<VariableLocator> variable_locators_;
// Following variables are to help sync stream
std::vector<std::shared_ptr<platform::CudaEventObject>> group_events_;
std::vector<std::shared_ptr<platform::CudaEventObject>> comm_events_;
cudaStream_t compute_stream_;
std::vector<cudaStream_t> comm_streams_;
  int nrings_ = 1;

  // Following variables are to help rebuild group
......
...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thread> // NOLINT
#include "paddle/fluid/imperative/nccl_context.h"

#include "gtest/gtest.h"
......
...@@ -60,6 +60,109 @@ TEST(TestGroup, TestPrintGroupMessage) {
  ASSERT_STREQ(stream2.str().c_str(), head.c_str());
}
template <typename T, typename Place>
void GroupConcatSplit(Place place, size_t size) {
platform::CPUPlace cpu_place;
Group group;
// [[0.0], [0.0, 1.0], [0.0, 1.0, 2.0] .. ]
std::vector<framework::Variable> vars;
vars.resize(size);
for (size_t i = 0; i < size; ++i) {
auto len = i + 1;
auto* tensor = vars[i].GetMutable<framework::LoDTensor>();
tensor->Resize({static_cast<int64_t>(len)});
auto* data = tensor->mutable_data<T>(place);
std::vector<T> value;
for (size_t j = 0; j < len; ++j) {
value.push_back(static_cast<T>(1.0 * j));
}
if (std::is_same<Place, platform::CUDAPlace>::value) {
paddle::memory::Copy(place, data, cpu_place, value.data(),
sizeof(T) * value.size(), 0);
} else {
paddle::memory::Copy(place, data, cpu_place, value.data(),
sizeof(T) * value.size());
}
framework::Tensor tmp;
tmp.ShareDataWith(*tensor).Resize({static_cast<int64_t>(len)});
group.dense_tensors_.push_back(std::move(tmp));
group.all_length_ += len;
group.dtype_ = tensor->type();
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
{ // concat
group.ConcatTensors(*dev_ctx);
auto* tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
framework::Tensor tmp;
framework::TensorCopySync(*tensor, cpu_place, &tmp);
auto* data = tmp.data<T>();
size_t offset = 0;
for (size_t i = 0; i < size; ++i) {
auto len = i + 1;
for (size_t j = 0; j < len; ++j) {
EXPECT_EQ(data[offset + j], static_cast<T>(1.0 * j));
// [[-0.0], [-0.0, -1.0], [-0.0, -1.0, -2.0] .. ]
data[offset + j] = -data[offset + j];
}
offset += len;
}
framework::TensorCopySync(tmp, place, tensor);
}
{ // split
group.SplitTensors(*dev_ctx);
for (size_t i = 0; i < size; ++i) {
auto len = i + 1;
auto& tensor = group.dense_tensors_[i];
framework::Tensor tmp;
framework::TensorCopySync(tensor, cpu_place, &tmp);
auto* data = tmp.data<T>();
for (size_t j = 0; j < len; ++j) {
EXPECT_EQ(data[j], static_cast<T>(-1.0 * j));
}
}
}
}
TEST(TestGroup, TestConcatSplit) {
platform::CUDAPlace cuda_place(0);
platform::CPUPlace cpu_place;
int size = 3;
GroupConcatSplit<float>(cpu_place, size);
GroupConcatSplit<double>(cpu_place, size);
GroupConcatSplit<platform::float16>(cpu_place, size);
GroupConcatSplit<float>(cuda_place, size);
GroupConcatSplit<double>(cuda_place, size);
GroupConcatSplit<platform::float16>(cuda_place, size);
size = 15;
GroupConcatSplit<float>(cpu_place, size);
GroupConcatSplit<double>(cpu_place, size);
GroupConcatSplit<platform::float16>(cpu_place, size);
GroupConcatSplit<float>(cuda_place, size);
GroupConcatSplit<double>(cuda_place, size);
GroupConcatSplit<platform::float16>(cuda_place, size);
}
TEST(TestGroup, TestConcatSplitException) {
platform::CUDAPinnedPlace place;
int size = 3;
ASSERT_ANY_THROW(GroupConcatSplit<float>(place, size));
}
#endif
}  // namespace imperative
......
...@@ -15,9 +15,8 @@ register_operators(EXCLUDES c_gen_nccl_id_op gen_nccl_id_op DEPS ${COLLECTIVE_DE
if(WITH_NCCL)
  set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
  cc_library(gen_nccl_id_op_helper SRCS gen_nccl_id_op_helper.cc DEPS nccl_common)
  op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} gen_nccl_id_op_helper)
  op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} gen_nccl_id_op_helper)

  op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
  op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
endif()

if(WITH_GLOO)
......
...@@ -23,11 +23,32 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"

#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"

namespace paddle {
namespace operators {
static void GenNCCLID(std::vector<ncclUniqueId>* nccl_ids) {
for (size_t i = 0; i < nccl_ids->size(); ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclGetUniqueId(&(*nccl_ids)[i]));
}
}
static void CopyNCCLIDToVar(const std::vector<ncclUniqueId>& nccl_ids,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
for (size_t i = 0; i < nccl_ids.size(); ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto nccl_id = var->GetMutable<ncclUniqueId>();
memcpy(nccl_id, &nccl_ids[i], sizeof(ncclUniqueId));
}
}
class CGenNCCLIdOp : public framework::OperatorBase {
 public:
  CGenNCCLIdOp(const std::string& type,
...@@ -45,14 +66,20 @@ class CGenNCCLIdOp : public framework::OperatorBase {
      return Output("Out");
    };

    std::vector<ncclUniqueId> nccl_ids;
    nccl_ids.resize(1);

    if (rank == 0) {
      GenNCCLID(&nccl_ids);
      std::vector<std::string> endpoint_list =
          Attr<std::vector<std::string>>("other_endpoints");
      SendBroadCastNCCLID(endpoint_list, 1, func, local_scope);
      platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
    } else {
      std::string endpoint = Attr<std::string>("endpoint");
      RecvBroadCastNCCLID(endpoint, 1, func, local_scope);
      platform::RecvBroadCastCommID(endpoint, &nccl_ids);
    }
    CopyNCCLIDToVar(nccl_ids, func, scope);

    scope.DeleteScope(&local_scope);
  }
};
......
...@@ -27,11 +27,32 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"

#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"

namespace paddle {
namespace operators {
static void GenNCCLID(std::vector<ncclUniqueId>* nccl_ids) {
for (size_t i = 0; i < nccl_ids->size(); ++i) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclGetUniqueId(&(*nccl_ids)[i]));
}
}
static void CopyNCCLIDToVar(const std::vector<ncclUniqueId>& nccl_ids,
std::function<std::string(size_t)> func,
const framework::Scope& scope) {
for (size_t i = 0; i < nccl_ids.size(); ++i) {
std::string var_name = func(i);
auto var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
var_name.c_str()));
auto nccl_id = var->GetMutable<ncclUniqueId>();
memcpy(nccl_id, &nccl_ids[i], sizeof(ncclUniqueId));
}
}
class GenNCCLIdOp : public framework::OperatorBase {
 public:
  GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
...@@ -98,19 +119,25 @@ class GenNCCLIdOp : public framework::OperatorBase {
            << ", trainers:" << ss.str();

    int server_fd = -1;

    std::vector<ncclUniqueId> nccl_ids;
    nccl_ids.resize(nccl_comm_num);

    /// 1. init flat
    std::function<std::string(size_t)> func = platform::GetFlatNCCLVarName;
    // broadcast unique id
    if (trainer_id == 0) {
      GenNCCLID(&nccl_ids);
      // server endpoints
      std::vector<std::string> flat_endpoints;
      flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1,
                            trainers.end());
      SendBroadCastNCCLID(flat_endpoints, nccl_comm_num, func, scope);
      platform::SendBroadCastCommID(flat_endpoints, &nccl_ids);
    } else {
      server_fd = CreateListenSocket(endpoint);
      server_fd = platform::CreateListenSocket(endpoint);
      RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope);
      platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
    }
    CopyNCCLIDToVar(nccl_ids, func, scope);

    /// 2. hierarchical inter ncclid
    func = platform::GetHierarchicalInterNCCLVarName;
...@@ -127,10 +154,13 @@ class GenNCCLIdOp : public framework::OperatorBase {
      }
      VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str();

      SendBroadCastNCCLID(inter_endpoints, nccl_comm_num, func, scope);
      GenNCCLID(&nccl_ids);
      platform::SendBroadCastCommID(inter_endpoints, &nccl_ids);
      CopyNCCLIDToVar(nccl_ids, func, scope);
    } else if (inter_trainer_id > 0) {
      VLOG(1) << "Hierarchical inter ring";
      RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope);
      platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
      CopyNCCLIDToVar(nccl_ids, func, scope);
    }

    /// 3. hierarchical exter ncclid
...@@ -146,15 +176,18 @@ class GenNCCLIdOp : public framework::OperatorBase {
      }
      VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str();

      SendBroadCastNCCLID(exter_endpoints, nccl_comm_num, func, scope);
      GenNCCLID(&nccl_ids);
      platform::SendBroadCastCommID(exter_endpoints, &nccl_ids);
      CopyNCCLIDToVar(nccl_ids, func, scope);
    } else if (exter_trainer_id > 0) {
      VLOG(1) << "Hierarchical exter ring";
      RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope);
      platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
      CopyNCCLIDToVar(nccl_ids, func, scope);
    }

    // close socket server
    if (trainer_id != 0) {
      CloseSocket(server_fd);
      platform::CloseSocket(server_fd);
    }
  }
};
......
...@@ -101,7 +101,7 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool
    place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
    ${dgc_deps} dlpack cudnn_workspace_helper ${XPU_CTX_DEPS})

cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
cc_library(collective_helper SRCS collective_helper.cc gen_comm_id_helper.cc DEPS framework_proto device_context enforce)

if(WITH_GPU)
  cc_library(cuda_resource_pool SRCS cuda_resource_pool.cc DEPS gpu_info)
......
...@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h"
#ifdef PADDLE_WITH_NCCL
#include "paddle/fluid/platform/gen_comm_id_helper.h"

#include <arpa/inet.h>
#include <netdb.h>
...@@ -31,7 +32,7 @@ limitations under the License. */
#include "paddle/fluid/string/split.h"

namespace paddle {
namespace operators {
namespace platform {

constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
...@@ -257,26 +258,29 @@ static int ConnectAddr(const std::string& ep, const char* head) {
  return sock;
}

static void RecvNCCLID(int conn, ncclUniqueId* nccl_id) {
template <typename CommUniqueId>
static void RecvCommID(int conn, CommUniqueId* nccl_id) {
  char buffer[1024] = {0};
  static_assert(NCCL_UNIQUE_ID_BYTES <= 1024,
  static_assert(sizeof(CommUniqueId) <= 1024,
                "nccl id bytes must <= buffer size");

  CHECK_SYS_CALL(SocketRecv(conn, buffer, NCCL_UNIQUE_ID_BYTES), "recv ncc id");
  memcpy(nccl_id, buffer, NCCL_UNIQUE_ID_BYTES);
  CHECK_SYS_CALL(SocketRecv(conn, buffer, sizeof(CommUniqueId)),
                 "recv comm unique id");
  memcpy(nccl_id, buffer, sizeof(CommUniqueId));
}

static void SendNCCLID(int conn, ncclUniqueId* nccl_id) {
template <typename CommUniqueId>
static void SendCommID(int conn, CommUniqueId* nccl_id) {
  char buffer[1024] = {0};

  memcpy(buffer, nccl_id, NCCL_UNIQUE_ID_BYTES);
  CHECK_SYS_CALL(SocketSend(conn, buffer, NCCL_UNIQUE_ID_BYTES),
                 "send nccl id");
  memcpy(buffer, nccl_id, sizeof(CommUniqueId));
  CHECK_SYS_CALL(SocketSend(conn, buffer, sizeof(CommUniqueId)),
                 "send comm unique id");
}

void SendBroadCastNCCLID(std::vector<std::string> servers, int nccl_comm_num,
                         std::function<std::string(size_t)> func,
                         const framework::Scope& scope) {
template <typename CommUniqueId>
void SendBroadCastCommID(std::vector<std::string> servers,
                         std::vector<CommUniqueId>* nccl_ids) {
  // connect with server
  std::vector<int> connects;
  for (auto server : servers) {
...@@ -286,23 +290,13 @@ void SendBroadCastNCCLID(std::vector<std::string> servers, int nccl_comm_num,
  }
  VLOG(3) << "connecting completed...";

  for (int i = 0; i < nccl_comm_num; ++i) {
  for (size_t i = 0; i < nccl_ids->size(); ++i) {
    std::string var_name = func(i);
    auto var = scope.FindVar(var_name);
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound("Variable with name %s is not found",
                                        var_name.c_str()));
    auto nccl_id = var->GetMutable<ncclUniqueId>();
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(nccl_id));

    int j = 0;
    for (auto conn : connects) {
      VLOG(3) << "sending nccl_id_var: " << var_name << " to " << servers[j]
              << " nccl_comm_no: " << i;
      SendNCCLID(conn, nccl_id);
      VLOG(3) << "sending comm_id to " << servers[j] << " nccl_comm_no: " << i;
      SendCommID(conn, &(*nccl_ids)[i]);
      ++j;
    }
    VLOG(3) << "sending completed...";
  }

  // close client
...@@ -311,34 +305,43 @@ void SendBroadCastNCCLID(std::vector<std::string> servers, int nccl_comm_num,
  }
}

void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
                         std::function<std::string(size_t)> func,
                         const framework::Scope& scope) {
template <typename CommUniqueId>
void RecvBroadCastCommID(std::string endpoint,
                         std::vector<CommUniqueId>* nccl_ids) {
  int server = CreateListenSocket(endpoint);
  RecvBroadCastNCCLID(server, endpoint, nccl_comm_num, func, scope);
  RecvBroadCastCommID(server, endpoint, nccl_ids);
  CloseSocket(server);
}

void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
                         std::function<std::string(size_t)> func,
                         const framework::Scope& scope) {
template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
                         std::vector<CommUniqueId>* nccl_ids) {
  int client = SocketAccept(server_fd, COMM_HEAD);

  for (int i = 0; i < nccl_comm_num; ++i) {
  for (size_t i = 0; i < nccl_ids->size(); ++i) {
    std::string var_name = func(i);
    auto var = scope.FindVar(var_name);
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound("Variable with name %s is not found",
                                        var_name.c_str()));
    auto nccl_id = var->GetMutable<ncclUniqueId>();
    VLOG(3) << "trainer: " << endpoint << " receiving nccl_id_var: " << var_name
            << " from trainer 0, nccl_comm_no: " << i;
    RecvNCCLID(client, nccl_id);
    VLOG(3) << "trainer: " << endpoint
            << " receiving comm_id from trainer 0, nccl_comm_no: " << i;
    RecvCommID(client, &(*nccl_ids)[i]);
  }
  VLOG(3) << "receiving completed...";
  CloseSocket(client);
}

/// template instantiation
#define INSTANT_TEMPLATE(Type)                                               \
  template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
                                          std::vector<Type> * nccl_ids);    \
  template void RecvBroadCastCommID<Type>(std::string endpoint,             \
                                          std::vector<Type> * nccl_ids);

#ifdef PADDLE_WITH_NCCL
INSTANT_TEMPLATE(ncclUniqueId)
#endif
#ifdef PADDLE_WITH_XPU_BKCL
INSTANT_TEMPLATE(bkclUniqueId)
#endif

}  // namespace operators
}  // namespace platform
}  // namespace paddle
#endif
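Because the helper is now templated on the unique-id type and explicitly instantiated for ncclUniqueId (and, behind PADDLE_WITH_XPU_BKCL, for bkclUniqueId), a caller only has to pick the id type. A small usage sketch, assuming a build with PADDLE_WITH_NCCL; the function name and parameters are illustrative, not from the diff:

```cpp
#include <string>
#include <vector>

#include <nccl.h>

#include "paddle/fluid/platform/gen_comm_id_helper.h"

void BroadcastOneRingId(bool is_root, const std::string& my_endpoint,
                        const std::vector<std::string>& other_endpoints) {
  std::vector<ncclUniqueId> ids(1);  // one ring -> one unique id
  if (is_root) {
    // the root generates the id (as the ops above do via GenNCCLID) ...
    ncclGetUniqueId(&ids[0]);
    // ... and pushes it to every other trainer
    paddle::platform::SendBroadCastCommID(other_endpoints, &ids);
  } else {
    // everyone else listens on its own endpoint and receives the id
    paddle::platform::RecvBroadCastCommID(my_endpoint, &ids);
  }
}
```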
...@@ -14,35 +14,31 @@ limitations under the License. */
#pragma once

#ifdef PADDLE_WITH_NCCL
#include <functional>
#include <string>
#include <vector>

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace operators {
namespace platform {

int CreateListenSocket(const std::string& ep);

void CloseSocket(int fd);

void SendBroadCastNCCLID(std::vector<std::string> servers, int nccl_comm_num,
                         std::function<std::string(size_t)> func,
                         const framework::Scope& scope);
template <typename CommUniqueId>
void SendBroadCastCommID(std::vector<std::string> servers,
                         std::vector<CommUniqueId>* nccl_ids);

// server listen on endpoint, then recv nccl id
void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
                         std::function<std::string(size_t)> func,
                         const framework::Scope& scope);
template <typename CommUniqueId>
void RecvBroadCastCommID(std::string endpoint,
                         std::vector<CommUniqueId>* nccl_ids);

// recv nccl id from socket
void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
                         std::function<std::string(size_t)> func,
                         const framework::Scope& scope);
template <typename CommUniqueId>
void RecvBroadCastCommID(int server_fd, std::string endpoint,
                         std::vector<CommUniqueId>* nccl_ids);

}  // namespace operators
}  // namespace platform
}  // namespace paddle

#endif
...@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_NCCL
#pragma once

#ifdef PADDLE_WITH_NCCL
#include <stdio.h>
#include <memory>
#include <string>
......
...@@ -14,10 +14,11 @@
import unittest
import os
import copy

from launch_function_helper import wait, _find_free_port
from multiprocessing import Pool, Process
from threading import Thread

os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10")
os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10,gen_comm_id*=10")

import paddle
from paddle.fluid import core
...@@ -29,8 +30,8 @@ def run_gen_ncc_id(attr):
    nccl_comm_num = attr['nccl_comm_num']
    use_hallreduce = attr['use_hierarchical_allreduce']

    startup_program = paddle.static.default_startup_program()
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.Program()
    main_program = paddle.static.Program()

    with paddle.static.program_guard(main_program, startup_program):
        nccl_id_var = startup_program.global_block().create_var(
...@@ -60,9 +61,10 @@ def run_gen_ncc_id(attr):
            attrs=attr)

    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    scope = paddle.static.Scope()
    with paddle.static.scope_guard(scope):
        exe.run(startup_program)


class TestGenNcclIdOp(unittest.TestCase):
...@@ -97,16 +99,19 @@ class TestGenNcclIdOp(unittest.TestCase):
        procs = []
        for i in range(nranks):
            attr['trainer_id'] = i
            p = Process(target=run_gen_ncc_id, args=(attr, ))
            # NOTE. multiprocessing cannot be covered by coverage
            p = Thread(target=run_gen_ncc_id, args=(copy.copy(attr), ))
            p.start()
            procs.append(p)

        wait(procs, timeout=120)
        for p in procs:
            p.join()

    def test_flat(self):
        print(">>> test gen flat nccl id")
        self.gen_nccl_id(2)
        print("<<< end test gen flat nccl id")
        print()

    def test_hierarchical(self):
        print(">>> test gen hierarchical nccl id")
......