Unverified commit 04e24e58, authored by LiYuRio, committed by GitHub

Create comm_context and modified static init (#49536)

* comm_context and static init

* refactor: move to phi/core/distributed

* refactor: avoid mutable_data usage

* fix: windows sock

* fix: device without nccl
Co-authored-by: Wen Sun <syl1887415157@126.com>
Parent commit: 67fc8e93
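For orientation before the file-by-file diff: the flow this commit enables can be driven from Python roughly as sketched below. This is a minimal, hedged illustration based on the _init_parallel_env helper and the CommContextManager binding added later in this diff; the environment-variable defaults, the device-id choice, and ring id 0 are illustrative assumptions, and it presumes Paddle's compiled extension module is importable as paddle.fluid.core.

# Hedged sketch of the new static-init path (illustrative values, not part of the commit).
import os
from paddle.fluid import core  # compiled extension exposing TCPStore / CommContextManager

master_addr, master_port = os.getenv("PADDLE_MASTER", "127.0.0.1:6170").split(":")
rank = int(os.getenv("PADDLE_TRAINER_ID", "0"))
world_size = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
dev_id = rank  # assumption: one GPU per rank

# Every rank connects to the same TCPStore; rank 0 hosts it.
store = core.TCPStore(master_addr, int(master_port), rank == 0, world_size)

# Rank 0 generates the ncclUniqueId and publishes it through the store; the other
# ranks read it back, then all ranks build the NCCLCommContext for ring 0.
core.CommContextManager.create_nccl_comm_context(store, dev_id, 0, rank, world_size)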
add_subdirectory(auto_parallel)
add_subdirectory(collective)
-add_subdirectory(store)
if(WITH_PYTHON)
  py_proto_compile(ps_py_proto SRCS the_one_ps.proto)
  add_custom_target(
......
@@ -12,7 +12,7 @@ if(WITH_DISTRIBUTE)
  cc_library(
    process_group_gloo
    SRCS process_group_gloo.cc
-    DEPS phi_api eager_api gloo_wrapper)
+    DEPS phi_api eager_api gloo_wrapper tcp_store)
endif()
if(WITH_NCCL OR WITH_RCCL)
@@ -20,6 +20,7 @@ if(WITH_NCCL OR WITH_RCCL)
    process_group_nccl
    SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc
    DEPS process_group
+         tcp_store
         place
         enforce
         collective_helper
@@ -32,7 +33,12 @@ if(WITH_XPU_BKCL)
  cc_library(
    process_group_bkcl
    SRCS process_group_bkcl.cc bkcl_tools.cc common.cc
-    DEPS process_group place enforce collective_helper device_context
+    DEPS process_group
+         tcp_store
+         place
+         enforce
+         collective_helper
+         device_context
         dense_tensor)
endif()
@@ -47,6 +53,11 @@ if(WITH_CUSTOM_DEVICE)
  cc_library(
    process_group_custom
    SRCS process_group_custom.cc custom_ccl_tools.cc common.cc
-    DEPS process_group phi_backends place enforce collective_helper
+    DEPS process_group
+         tcp_store
+         phi_backends
+         place
+         enforce
+         collective_helper
         device_context)
endif()
@@ -72,7 +72,8 @@ bool ProcessGroupBKCL::BKCLTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupBKCL::BKCLTask::Synchronize() { Wait(kWaitTimeout); }
-ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr<Store>& store,
+ProcessGroupBKCL::ProcessGroupBKCL(
+    const std::shared_ptr<phi::distributed::Store>& store,
    int rank,
    int size,
    int gid)
@@ -606,7 +607,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
}
std::shared_ptr<ProcessGroupBKCL> ProcessGroupBKCL::CreateProcessGroupBKCL(
-    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
+    const std::shared_ptr<phi::distributed::Store>& store,
+    int rank,
+    int size,
+    int gid) {
  auto process_group =
      std::make_shared<ProcessGroupBKCL>(store, rank, size, gid);
  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
......
@@ -21,11 +21,11 @@
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
-#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/distributed/store/store.h"
#if defined(PADDLE_WITH_XPU)
#include "paddle/fluid/distributed/collective/bkcl_tools.h"
@@ -67,13 +67,16 @@ class ProcessGroupBKCL : public ProcessGroupWithStream {
  };
 public:
-  ProcessGroupBKCL(const std::shared_ptr<Store>& store,
+  ProcessGroupBKCL(const std::shared_ptr<phi::distributed::Store>& store,
                   int rank,
                   int size,
                   int gid);
  static std::shared_ptr<ProcessGroupBKCL> CreateProcessGroupBKCL(
-      const std::shared_ptr<Store>& store, int rank, int size, int gid);
+      const std::shared_ptr<phi::distributed::Store>& store,
+      int rank,
+      int size,
+      int gid);
  std::string GetBackendName() const override {
    return std::string(BKCL_BACKEND_NAME);
@@ -176,7 +179,7 @@ class ProcessGroupBKCL : public ProcessGroupWithStream {
  void SyncCalcStream(const Place& place);
 private:
-  std::shared_ptr<Store> store_;
+  std::shared_ptr<phi::distributed::Store> store_;
  std::mutex mutex_;
  std::shared_ptr<XPUEventManager> calc_event_;  // event on calc stream
  std::unordered_map<std::string, phi::XPUContext*> place_to_calc_ctx_;
......
@@ -98,7 +98,8 @@ bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); }
-ProcessGroupCustom::ProcessGroupCustom(const std::shared_ptr<Store>& store,
+ProcessGroupCustom::ProcessGroupCustom(
+    const std::shared_ptr<phi::distributed::Store>& store,
    const std::string& device_type,
    int rank,
    int size,
@@ -438,7 +439,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
std::shared_ptr<ProcessGroupCustom>
ProcessGroupCustom::CreateProcessGroupCustom(
-    const std::shared_ptr<Store>& store,
+    const std::shared_ptr<phi::distributed::Store>& store,
    const std::string& device_type,
    int rank,
    int size,
......
@@ -24,17 +24,18 @@
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
-#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/core/distributed/store/store.h"
namespace paddle {
namespace distributed {
using Place = paddle::platform::Place;
using CustomDeviceContext = paddle::platform::CustomDeviceContext;
class ProcessGroupCustom : public ProcessGroupWithoutStream {
 public:
  class CustomTask : public ProcessGroup::Task,
@@ -64,14 +65,14 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream {
    const std::string device_type_;
  };
-  ProcessGroupCustom(const std::shared_ptr<Store>& store,
+  ProcessGroupCustom(const std::shared_ptr<phi::distributed::Store>& store,
                     const std::string& device_type,
                     int rank,
                     int size,
                     int gid);
  static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom(
-      const std::shared_ptr<Store>& store,
+      const std::shared_ptr<phi::distributed::Store>& store,
      const std::string& device_type,
      int rank,
      int size,
@@ -127,7 +128,7 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream {
      CommType opType,
      const std::vector<phi::DenseTensor>& inputs);
-  std::shared_ptr<Store> store_;
+  std::shared_ptr<phi::distributed::Store> store_;
  std::shared_ptr<CustomCCLCommManager> custom_comm_;
  std::mutex mutex_;
  std::unordered_map<std::string,
......
@@ -177,7 +177,7 @@ ProcessGroupGloo::GlooTask::GlooTask(
    : ProcessGroup::Task(rank, inputs, comm_type) {}
ProcessGroupGloo::ProcessGroupGloo(
-    const std::shared_ptr<distributed::Store>& store,
+    const std::shared_ptr<phi::distributed::Store>& store,
    int rank,
    int world_size,
    int gid,
@@ -601,10 +601,11 @@ ProcessGroupGloo::createDefaultDevice() {
      0,
      platform::errors::Fatal("Get hostname error for createDefaultDevice."));
  ::addrinfo* result;
-  result = tcputils::get_addr_info(hostname.data(), "", 0, AF_UNSPEC);
+  result = phi::distributed::tcputils::get_addr_info(
+      hostname.data(), "", 0, AF_UNSPEC);
  ::addrinfo* cur;
  for (cur = result; cur != nullptr; cur = cur->ai_next) {
-    SocketType socket =
+    phi::distributed::SocketType socket =
        ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);
    if (socket == -1) {
      continue;
@@ -628,7 +629,10 @@ ProcessGroupGloo::createDefaultDevice() {
}
std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo(
-    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
+    const std::shared_ptr<phi::distributed::Store>& store,
+    int rank,
+    int size,
+    int gid) {
  std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
  auto opts = GlooOptions::create();
  char* ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
......
@@ -20,8 +20,8 @@
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
-#include "paddle/fluid/distributed/store/store.h"
-#include "paddle/fluid/distributed/store/tcp_store.h"
+#include "paddle/phi/core/distributed/store/store.h"
+#include "paddle/phi/core/distributed/store/tcp_store.h"
#ifdef PADDLE_WITH_GLOO
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
@@ -52,7 +52,7 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
  class GlooStore : public ::gloo::rendezvous::Store {
   public:
-    explicit GlooStore(const std::shared_ptr<paddle::distributed::Store>& store)
+    explicit GlooStore(const std::shared_ptr<phi::distributed::Store>& store)
        : _store(store) {}
    ~GlooStore() = default;
@@ -86,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
    }
   protected:
-    std::shared_ptr<paddle::distributed::Store> _store;
+    std::shared_ptr<phi::distributed::Store> _store;
  };
  class GlooOptions {
@@ -99,14 +99,14 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
    std::shared_ptr<::gloo::transport::Device> device;
  };
-  ProcessGroupGloo(const std::shared_ptr<paddle::distributed::Store>& store,
+  ProcessGroupGloo(const std::shared_ptr<phi::distributed::Store>& store,
                   int rank,
                   int world_size,
                   int gid,
                   std::shared_ptr<GlooOptions> options);
  static std::shared_ptr<ProcessGroupGloo> CreateProcessGroupGloo(
-      const std::shared_ptr<paddle::distributed::Store>& store,
+      const std::shared_ptr<phi::distributed::Store>& store,
      int rank,
      int world_size,
      int gid);
......
@@ -86,7 +86,8 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
-ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
+ProcessGroupNCCL::ProcessGroupNCCL(
+    const std::shared_ptr<phi::distributed::Store>& store,
    int rank,
    int size,
    int gid)
@@ -1151,7 +1152,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
}
std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
-    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
+    const std::shared_ptr<phi::distributed::Store>& store,
+    int rank,
+    int size,
+    int gid) {
  auto process_group =
      std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
......
@@ -22,10 +22,10 @@
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
-#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/distributed/store/store.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/nccl_tools.h"
@@ -33,7 +33,7 @@
#ifdef PADDLE_WITH_RCCL
#include "paddle/phi/backends/dynload/rccl.h"
-#elif PADDLE_WITH_NCCL
+#else
#include "paddle/phi/backends/dynload/nccl.h"
#endif
@@ -76,9 +76,12 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
 public:
  static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
-      const std::shared_ptr<Store>& store, int rank, int size, int gid);
+      const std::shared_ptr<phi::distributed::Store>& store,
+      int rank,
+      int size,
+      int gid);
-  ProcessGroupNCCL(const std::shared_ptr<Store>& store,
+  ProcessGroupNCCL(const std::shared_ptr<phi::distributed::Store>& store,
                   int rank,
                   int size,
                   int gid);
@@ -243,7 +246,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
      const std::vector<Place>& places);
 private:
-  std::shared_ptr<Store> store_;
+  std::shared_ptr<phi::distributed::Store> store_;
  std::unordered_map<std::string, platform::DeviceEvent>
      place_to_calc_event_;  // event on calc stream
......
@@ -34,7 +34,8 @@ register_operators(
  ${COLLECTIVE_DEPS})
if(WITH_NCCL OR WITH_RCCL)
-  set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
+  set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper
+      comm_context_manager nccl_comm_context)
  op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
  op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
endif()
......
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
-#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
@@ -31,35 +32,27 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
    auto x = ctx.Input<phi::DenseTensor>("X");
    auto out = ctx.Output<phi::DenseTensor>("Out");
-    int numel = x->numel();
-    ncclDataType_t dtype =
-        platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
    int rid = ctx.Attr<int>("ring_id");
-    auto place = ctx.GetPlace();
+    const auto& place = ctx.GetPlace();
-    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    ctx.device_context().Alloc<T>(out);
-    if (map->has(rid)) {
-      // Use ProcessGroup
-      distributed::ProcessGroup* pg = map->get(rid);
-      std::vector<phi::DenseTensor> in_tensor;
-      std::vector<phi::DenseTensor> out_tensor;
-      in_tensor.push_back(*x);
-      out_tensor.push_back(*out);
-      auto task = pg->Broadcast(in_tensor, out_tensor);
-      task->Wait();
-      return;
-    }
-    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    gpuStream_t stream = nullptr;
-    if (ctx.Attr<bool>("use_calc_stream")) {
-      // should ExecutionContext for calc stream.
-      stream = ctx.cuda_device_context().stream();
-    } else {
-      stream = comm->stream();
-    }
    int root = ctx.Attr<int>("root");
+    gpuStream_t stream = ctx.cuda_device_context().stream();
+    const auto& comm_context_manager =
+        phi::distributed::CommContextManager::GetInstance();
+    if (comm_context_manager.Has(rid)) {
+      auto* comm_context = static_cast<phi::distributed::NCCLCommContext*>(
+          comm_context_manager.Get(rid));
+      comm_context->Broadcast(out, *x, root, stream);
+    } else {
+      // NOTE(liyurui): This will be removed after moving this operator to phi.
+      int numel = x->numel();
+      ncclDataType_t dtype =
+          platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
+      auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
      if (root == comm->rank()) {
        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
            reinterpret_cast<void*>(const_cast<T*>(x->data<T>())),
@@ -70,7 +63,6 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
            stream));
        VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
                << x->numel();
        if (out != x) {
          framework::TensorCopy(
              *static_cast<const phi::DenseTensor*>(x),
@@ -79,18 +71,13 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
              static_cast<phi::DenseTensor*>(out));
        }
      } else {
-        PADDLE_ENFORCE_GPU_SUCCESS(
-            platform::dynload::ncclBcast(out->mutable_data<T>(place),
-                                         numel,
-                                         dtype,
-                                         root,
-                                         comm->comm(),
-                                         stream));
+        PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
+            out->data<T>(), numel, dtype, root, comm->comm(), stream));
        VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received "
                << phi::product(out->dims());
      }
+    }
-    out->Resize(x->dims());
    out->set_lod(x->lod());
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
......
@@ -37,6 +37,7 @@ set(PYBIND_DEPS
    global_utils
    phi_utils
    tcp_store
+    comm_context_manager
    new_profiler
    auto_parallel
    jit_layer
......
@@ -21,25 +21,40 @@ limitations under the License. */
#include <pybind11/stl.h>
#include <chrono>
+#include <memory>
#include <string>
-#include "paddle/fluid/distributed/store/tcp_store.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
+#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
-using TCPStore = paddle::distributed::TCPStore;
+void BindCommContextManager(py::module *m) {
+  auto CommContextManager =
+      py::class_<phi::distributed::CommContextManager,
+                 std::shared_ptr<phi::distributed::CommContextManager>>(
+          *m, "CommContextManager")
+#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
+          .def_static(
+              "create_nccl_comm_context",
+              &phi::distributed::CommContextManager::CreateNCCLCommContext,
+              py::call_guard<py::gil_scoped_release>())
+#endif
+          .def("set_store", &phi::distributed::CommContextManager::SetStore);
+}
+using TCPStore = phi::distributed::TCPStore;
void BindTCPStore(py::module *m) {
-  auto Store =
-      py::class_<distributed::Store, std::shared_ptr<distributed::Store>>(
-          *m, "Store")
+  auto Store = py::class_<phi::distributed::Store,
+                          std::shared_ptr<phi::distributed::Store>>(*m, "Store")
          .def(py::init<>())
          .def(
              "set",
-              [](distributed::Store &self,
+              [](phi::distributed::Store &self,
                 const std::string &key,
                 const std::string &value) {
                std::vector<uint8_t> data(value.begin(), value.end());
@@ -50,7 +65,7 @@ void BindTCPStore(py::module *m) {
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get",
-              [](distributed::Store &self,
+              [](phi::distributed::Store &self,
                 const std::string &key) -> py::bytes {
                auto data = self.get(key);
                return py::bytes(reinterpret_cast<char *>(data.data()),
@@ -59,10 +74,10 @@ void BindTCPStore(py::module *m) {
              py::arg("key"),
              py::call_guard<py::gil_scoped_release>())
          .def("add",
-               &distributed::Store::add,
+               &phi::distributed::Store::add,
               py::call_guard<py::gil_scoped_release>())
          .def("wait",
-               &distributed::Store::wait,
+               &phi::distributed::Store::wait,
               py::call_guard<py::gil_scoped_release>());
  py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store)
......
@@ -26,6 +26,7 @@ namespace paddle {
namespace pybind {
void BindTCPStore(pybind11::module* m);
+void BindCommContextManager(pybind11::module* m);
}  // namespace pybind
}  // namespace paddle
@@ -46,7 +46,6 @@ limitations under the License. */
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/distributed/collective/process_group_gloo.h"
-#include "paddle/fluid/distributed/store/tcp_store.h"
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
......
@@ -1871,6 +1871,7 @@ All parameter, weight, gradient are variables in Paddle.
  BindGlobalValueGetterSetter(&m);
  BindFleetExecutor(&m);
  BindTCPStore(&m);
+  BindCommContextManager(&m);
  BindAutoParallel(&m);
  BindJitProperty(&m);
......
# compatible utils used for fluid op system
add_subdirectory(compat)
+add_subdirectory(distributed)
if(WITH_GPU)
  proto_library(external_error_proto SRCS external_error.proto)
......
add_subdirectory(store)
set(COMM_CONTEXT_MANAGER_DEPS tcp_store)
if(WITH_NCCL OR WITH_RCCL)
cc_library(
nccl_comm_context
SRCS nccl_comm_context.cc
DEPS dense_tensor)
list(APPEND COMM_CONTEXT_MANAGER_DEPS nccl_comm_context)
endif()
cc_library(
comm_context_manager
SRCS comm_context_manager.cc
DEPS ${COMM_CONTEXT_MANAGER_DEPS})
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/macros.h"
namespace phi {
namespace distributed {
class CommContext {
public:
CommContext(int rank, int size) : rank_(rank), size_(size) {}
virtual ~CommContext() = default;
protected:
int rank_;
int size_;
private:
DISABLE_COPY_AND_ASSIGN(CommContext);
};
} // namespace distributed
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include <memory>
#include <string>
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/enforce.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
namespace phi {
namespace distributed {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void CommContextManager::CreateNCCLCommContext(
const std::shared_ptr<Store>& store,
int dev_id,
int ring_id,
int rank,
int size) {
phi::backends::gpu::SetDeviceId(dev_id);
ncclUniqueId nccl_id;
if (rank == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id));
}
std::string unique_key = "NCCLCommContext/" + std::to_string(ring_id);
if (rank == 0) {
std::vector<uint8_t> nccl_id_wrapper(
reinterpret_cast<uint8_t*>(&nccl_id),
reinterpret_cast<uint8_t*>(&nccl_id) + NCCL_UNIQUE_ID_BYTES);
store->set(unique_key, nccl_id_wrapper);
} else {
const auto& nccl_id_wrapper = store->get(unique_key);
std::memcpy(&nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
}
auto nccl_comm_context =
std::make_unique<NCCLCommContext>(rank, size, nccl_id);
auto& comm_context_manager = CommContextManager::GetInstance();
comm_context_manager.SetStore(store);
comm_context_manager.Emplace(ring_id, std::move(nccl_comm_context));
}
#endif
CommContext* CommContextManager::Emplace(
int ring_id, std::unique_ptr<CommContext> comm_context) {
PADDLE_ENFORCE_EQ(
id_to_comm_context_.find(ring_id),
id_to_comm_context_.end(),
errors::AlreadyExists("Ring id %d already exists in the map.", ring_id));
id_to_comm_context_.emplace(ring_id, std::move(comm_context));
return id_to_comm_context_.at(ring_id).get();
}
CommContext* CommContextManager::Get(int ring_id) const {
PADDLE_ENFORCE_NE(
id_to_comm_context_.find(ring_id),
id_to_comm_context_.end(),
errors::NotFound("Can not find ring id %d in map.", ring_id));
return id_to_comm_context_.at(ring_id).get();
}
bool CommContextManager::Has(int ring_id) const {
return id_to_comm_context_.find(ring_id) != id_to_comm_context_.end();
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <unordered_map>
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/macros.h"
namespace phi {
namespace distributed {
class Store;
class CommContextManager {
public:
CommContextManager() = default;
~CommContextManager() = default;
static CommContextManager& GetInstance() {
static CommContextManager instance;
return instance;
}
void SetStore(const std::shared_ptr<Store>& store) { store_ = store; }
CommContext* Emplace(int ring_id, std::unique_ptr<CommContext> comm_context);
CommContext* Get(int ring_id) const;
bool Has(int ring_id) const;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
static void CreateNCCLCommContext(const std::shared_ptr<Store>& store,
int dev_id,
int ring_id,
int rank,
int size);
#endif
private:
DISABLE_COPY_AND_ASSIGN(CommContextManager);
std::unordered_map<int, std::unique_ptr<CommContext>> id_to_comm_context_;
std::shared_ptr<Store> store_;
};
} // namespace distributed
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
namespace distributed {
NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id)
: CommContext(rank, size) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_));
}
void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
gpuStream_t stream) {
phi::dynload::ncclBroadcast(in_tensor.data(),
out_tensor->data(),
in_tensor.numel(),
ToNCCLDataType(in_tensor.type()),
root,
nccl_comm_,
stream);
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/macros.h"
#if defined(PADDLE_WITH_RCCL)
#include "paddle/phi/backends/dynload/rccl.h"
#else
#include "paddle/phi/backends/dynload/nccl.h"
#endif
namespace phi {
class DenseTensor;
namespace distributed {
class NCCLCommContext final : public CommContext {
public:
NCCLCommContext(int rank, int size, ncclUniqueId nccl_id);
void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
gpuStream_t stream);
private:
DISABLE_COPY_AND_ASSIGN(NCCLCommContext);
ncclComm_t nccl_comm_;
};
} // namespace distributed
} // namespace phi
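To connect the new pieces above: once CreateNCCLCommContext has registered a context for a ring id, an operator kernel is expected to look it up through the singleton manager and issue collectives on it, falling back to the legacy path when nothing was registered. The snippet below is a condensed, hedged sketch of that lookup, mirroring the c_broadcast_op.cu.cc change earlier in this diff; the free-function name and its parameters are illustrative and not part of the commit.

// Condensed sketch (not part of the commit): kernel-side lookup of the
// per-ring NCCL context registered during static initialization.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"

void BroadcastViaCommContext(int ring_id, int root, gpuStream_t stream,
                             const phi::DenseTensor& in, phi::DenseTensor* out) {
  const auto& mgr = phi::distributed::CommContextManager::GetInstance();
  if (mgr.Has(ring_id)) {
    // New path: the context was created once by CreateNCCLCommContext.
    auto* ctx =
        static_cast<phi::distributed::NCCLCommContext*>(mgr.Get(ring_id));
    ctx->Broadcast(out, in, root, stream);
  } else {
    // Legacy fallback (platform::NCCLCommContext::Instance()) would go here.
  }
}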
cc_library(
  tcp_store
-  SRCS tcp_store.cc tcp_utils.cc socket.cpp
+  SRCS tcp_store.cc tcp_utils.cc socket.cpp store.cc
  DEPS enforce glog)
if(NOT WIN32)
......
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "paddle/fluid/distributed/store/socket.h"
+#include "paddle/phi/core/distributed/store/socket.h"
#ifndef _WIN32
#include <arpa/inet.h>
@@ -23,7 +23,7 @@
#include <errno.h>
#include <stdio.h>
-namespace paddle {
+namespace phi {
namespace distributed {
#ifdef _WIN32
@@ -75,5 +75,5 @@ std::string GetSockName(int fd) {
  return std::string(out);
}
-};  // namespace distributed
-};  // namespace paddle
+}  // namespace distributed
+}  // namespace phi
@@ -16,11 +16,11 @@
#include <string>
-namespace paddle {
+namespace phi {
namespace distributed {
int GetSockName(int fd, char* out, int out_len);
std::string GetSockName(int fd);
-};  // namespace distributed
-};  // namespace paddle
+}  // namespace distributed
+}  // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
namespace distributed {
int64_t Store::add(const std::string& key, int64_t value) {
PADDLE_THROW(
errors::InvalidArgument("Implement the add method in the subclass."));
}
std::vector<uint8_t> Store::get(const std::string& key) {
PADDLE_THROW(
errors::InvalidArgument("Implement the get method in the subclass."));
}
void Store::wait(const std::string& key) {
PADDLE_THROW(
errors::InvalidArgument("Implement the wait method in the subclass."));
}
void Store::set(const std::string& key, const std::vector<uint8_t>& value) {
PADDLE_THROW(
errors::InvalidArgument("Implement the set method in the subclass."));
}
} // namespace distributed
} // namespace phi
@@ -18,9 +18,7 @@
#include <string>
#include <vector>
-#include "paddle/fluid/distributed/store/tcp_utils.h"
-namespace paddle {
+namespace phi {
namespace distributed {
class Store {
@@ -29,22 +27,10 @@ class Store {
  explicit Store(const int timeout) : _timeout(timeout) {}
  virtual ~Store() = default;
-  virtual int64_t add(const std::string& key, int64_t value) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "Implement the add method in the subclass."));
-  }
-  virtual std::vector<uint8_t> get(const std::string& key) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "Implement the add method in the subclass."));
-  }
-  virtual void wait(const std::string& key) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "Implement the add method in the subclass."));
-  }
-  virtual void set(const std::string& key, const std::vector<uint8_t>& value) {
-    PADDLE_THROW(platform::errors::InvalidArgument(
-        "Implement the add method in the subclass."));
-  }
+  virtual int64_t add(const std::string& key, int64_t value);
+  virtual std::vector<uint8_t> get(const std::string& key);
+  virtual void wait(const std::string& key);
+  virtual void set(const std::string& key, const std::vector<uint8_t>& value);
  virtual int timeout() { return _timeout; }
@@ -53,4 +39,4 @@ class Store {
};
}  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -12,17 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "paddle/fluid/distributed/store/tcp_store.h"
+#include "paddle/phi/core/distributed/store/tcp_store.h"
#include <chrono>
#include <iostream>
#include <thread>
-#include "paddle/fluid/distributed/store/tcp_utils.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/distributed/store/tcp_utils.h"
#include "paddle/phi/core/flags.h"
-namespace paddle {
+namespace phi {
namespace distributed {
namespace detail {
@@ -90,7 +89,7 @@ void MasterDaemon::_do_get(SocketType socket) {
  PADDLE_ENFORCE_NE(
      iter,
      _store.end(),
-      platform::errors::InvalidArgument("Key %s not found in TCPStore.", key));
+      phi::errors::InvalidArgument("Key %s not found in TCPStore.", key));
  std::vector<uint8_t> value = iter->second;
  tcputils::send_vector<uint8_t>(socket, value);
}
@@ -100,7 +99,7 @@ void MasterDaemon::InitControlFd() {
  PADDLE_ENFORCE_NE(
      pipe(_control_fd.data()),
      -1,
-      platform::errors::Fatal("failed to cread control pipe errno:%d", errno));
+      phi::errors::Fatal("failed to cread control pipe errno:%d", errno));
}
void MasterDaemon::CloseControlFd() {
  for (int fd : _control_fd) {
@@ -112,10 +111,10 @@ void MasterDaemon::CloseControlFd() {
void MasterDaemon::StopByControlFd() {
  VLOG(4) << ("begin to run StopByControlFd");
  if (_control_fd[1] != -1) {
-    PADDLE_ENFORCE_NE(::write(_control_fd[1], "\0", 1),
-                      -1,
-                      platform::errors::Fatal(
-                          "failed to write control pipe errno:%d", errno));
+    PADDLE_ENFORCE_NE(
+        ::write(_control_fd[1], "\0", 1),
+        -1,
+        phi::errors::Fatal("failed to write control pipe errno:%d", errno));
    // close the write end of the pipe
    ::close(_control_fd[1]);
    _control_fd[1] = -1;
@@ -125,7 +124,7 @@ void MasterDaemon::StopByControlFd() {
void MasterDaemon::InitControlFd() {
  ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL);
  PADDLE_ENFORCE(ghStopEvent_,
-                 platform::errors::Fatal("failed to cread control pipe"));
+                 phi::errors::Fatal("failed to cread control pipe"));
}
void MasterDaemon::CloseControlFd() { CloseHandle(ghStopEvent_); }
void MasterDaemon::StopByControlFd() { SetEvent(ghStopEvent_); }
@@ -231,8 +230,8 @@ void MasterDaemon::run() {
      // The control pipe receive shutdown event, and begin to close it.
      if (fds[1].revents != 0) {
        if (fds[1].revents & ~(POLLIN | POLLHUP)) {
-          PADDLE_THROW(paddle::platform::errors::Fatal("Undefined event type:%d",
-                                                       fds[1].revents));
+          PADDLE_THROW(
+              phi::errors::Fatal("Undefined event type:%d", fds[1].revents));
        }
        VLOG(0)
            << "receive shutdown event and so quit from MasterDaemon run loop";
@@ -312,9 +311,7 @@ TCPStore::TCPStore(std::string host,
    : Store(timeout), _is_master(is_master), _num_workers(num_workers) {
  _timeout = timeout;
  PADDLE_ENFORCE_GT(
-      timeout,
-      0,
-      platform::errors::InvalidArgument("timeout must >= %d", timeout));
+      timeout, 0, phi::errors::InvalidArgument("timeout must >= %d", timeout));
  VLOG(3) << "input timeout" << timeout << ", member timeout:" << _timeout;
  if (_is_master) {
@@ -355,7 +352,7 @@ void TCPStore::waitWorkers() {
      PADDLE_ENFORCE_EQ(
          completed,
          _num_workers,
-          platform::errors::InvalidArgument(
+          phi::errors::InvalidArgument(
              "TCPStore timeouted and not all workers got ready."));
    }
  } while (true);
@@ -398,4 +395,4 @@ void TCPStore::wait(const std::string& key) {
TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; }
}  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -30,11 +30,11 @@
#include <thread>
#include <unordered_map>
-#include "paddle/fluid/distributed/store/socket.h"
-#include "paddle/fluid/distributed/store/store.h"
-#include "paddle/fluid/distributed/store/tcp_utils.h"
+#include "paddle/phi/core/distributed/store/socket.h"
+#include "paddle/phi/core/distributed/store/store.h"
+#include "paddle/phi/core/distributed/store/tcp_utils.h"
-namespace paddle {
+namespace phi {
namespace distributed {
enum class ReplyType { WAITING, STOP_WAIT };
@@ -143,4 +143,4 @@ class TCPStore : public Store {
};
}  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -12,15 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "paddle/fluid/distributed/store/tcp_utils.h"
+#include "paddle/phi/core/distributed/store/tcp_utils.h"
#include <cerrno>
#include <cstring>
#include <thread>
-#include "paddle/fluid/platform/enforce.h"
-namespace paddle {
+namespace phi {
namespace distributed {
namespace tcputils {
@@ -60,7 +58,7 @@ void close_socket(SocketType socket) {
                        : "");
  PADDLE_ENFORCE_EQ(n,
                    0,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                        "%s network %s:%s cannot be obtained. Details: %s.",
                        proto,
                        host,
@@ -73,7 +71,7 @@ void close_socket(SocketType socket) {
void free_addr_info(::addrinfo* hint) {
  PADDLE_ENFORCE_NOT_NULL(
      hint,
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
          "The parameter for free_addr_info cannot be null."));
  ::freeaddrinfo(hint);
}
@@ -91,10 +89,10 @@ SocketType tcp_connect(const std::string host,
  do {
    for (::addrinfo* cur = res; cur != nullptr; cur = cur->ai_next) {
      sockfd = ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);
-      PADDLE_ENFORCE_GT(sockfd,
+      PADDLE_ENFORCE_GT(
+          sockfd,
          0,
-          platform::errors::InvalidArgument(
-              "Create socket to connect %s:%s failed. "
+          phi::errors::InvalidArgument("Create socket to connect %s:%s failed. "
                                       "Details: %s. ",
                                       host,
                                       port,
@@ -125,7 +123,7 @@ SocketType tcp_connect(const std::string host,
  PADDLE_ENFORCE_GT(sockfd,
                    0,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                        "Network %s:%s cannot be connected.", host, port));
  VLOG(0) << "Successfully connected to " << host << ":" << port;
@@ -173,7 +171,7 @@ SocketType tcp_listen(const std::string host,
  PADDLE_ENFORCE_GT(sockfd,
                    0,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                        "Bind network on %s:%s failedd.", node, port));
  ::listen(sockfd, LISTENQ);
@@ -190,7 +188,7 @@ SocketType tcp_accept(SocketType socket) {
  PADDLE_ENFORCE_GT(
      new_socket,
      0,
-      platform::errors::InvalidArgument(
+      phi::errors::InvalidArgument(
          "The server failed to accept a new connection. Details: %s.",
          socket_error().message()));
#ifndef _WIN32
@@ -225,4 +223,4 @@ std::string receive_string(SocketType socket) {
}  // namespace tcputils
}  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -26,14 +26,15 @@
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <chrono>
#include <iostream>
#include <vector>
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/enforce.h"
// Utility functions for TCP socket.
-namespace paddle {
+namespace phi {
namespace distributed {
#ifdef _WIN32
@@ -82,7 +83,7 @@ void send_bytes(SocketType socket, const T* buffer, size_t len) {
    PADDLE_ENFORCE_GT(
        byte_sent,
        0,
-        platform::errors::InvalidArgument("TCP send error. Details: %s.",
+        phi::errors::InvalidArgument("TCP send error. Details: %s.",
                                     socket_error().message()));
    to_send -= byte_sent;
    ptr += byte_sent;
@@ -102,7 +103,7 @@ void receive_bytes(SocketType socket, T* buffer, size_t len) {
    PADDLE_ENFORCE_GT(
        byte_received,
        0,
-        platform::errors::InvalidArgument("TCP receive error. Details: %s.",
+        phi::errors::InvalidArgument("TCP receive error. Details: %s.",
                                     socket_error().message()));
    to_recv -= byte_received;
@@ -140,4 +141,4 @@ T receive_value(SocketType socket) {
}  // namespace tcputils
}  // namespace distributed
-}  // namespace paddle
+}  // namespace phi
@@ -13,14 +13,14 @@
// limitations under the License.
#include "gtest/gtest.h"
-#include "paddle/fluid/distributed/store/tcp_store.h"
-#include "paddle/fluid/distributed/store/tcp_utils.h"
+#include "paddle/phi/core/distributed/store/tcp_store.h"
+#include "paddle/phi/core/distributed/store/tcp_utils.h"
#ifdef _WIN32
#include <windows.h>
#endif
-namespace paddle {
+namespace phi {
namespace distributed {
TEST(MasterDaemon, init) {
@@ -48,6 +48,5 @@ TEST(TCPStore, init) {
      paddle::errors::Fatal("result of add is not right"));
}
*/
-};  // namespace distributed
-};  // namespace paddle
+}  // namespace distributed
+}  // namespace phi
@@ -211,4 +211,33 @@ inline int TransToProtoVarType(const DataType& dtype) {
  }
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
inline ncclDataType_t ToNCCLDataType(DataType type) {
if (type == DataType::FLOAT32) {
return ncclFloat;
} else if (type == DataType::FLOAT64) {
return ncclDouble;
} else if (type == DataType::INT32) {
return ncclInt;
} else if (type == DataType::INT64) {
return ncclInt64;
} else if (type == DataType::FLOAT16) {
return ncclFloat16;
} else if (type == DataType::UINT8) {
return ncclUint8;
} else if (type == DataType::INT8) {
return ncclInt8;
} else if (type == DataType::BOOL) {
return ncclUint8;
#if NCCL_VERSION_CODE >= 21000
} else if (type == DataType::BFLOAT16) {
return ncclBfloat16;
#endif
} else {
PADDLE_THROW(
errors::Unimplemented("This datatype in nccl is not supported."));
}
}
#endif
}  // namespace phi
@@ -13,6 +13,7 @@
# limitations under the License.
import datetime
+import os
import paddle
@@ -325,3 +326,25 @@ def is_available():
    """
    return core.is_compiled_with_dist()
+def _init_parallel_env(backend):
+    master_endpoint = os.getenv("PADDLE_MASTER", None)
+    if master_endpoint:
+        master_addr = master_endpoint.split(":")[0]
+        master_port = int(master_endpoint.split(":")[1])
+        global_env = _get_global_env()
+        rank = global_env.rank
+        world_size = global_env.world_size
+        dev_id = global_env.device_id
+        is_master = rank == 0
+        store = core.TCPStore(
+            master_addr,
+            master_port,
+            is_master,
+            world_size,
+        )
+        if backend == "nccl":
+            core.CommContextManager.create_nccl_comm_context(
+                store, dev_id, 0, rank, world_size
+            )
@@ -243,6 +243,7 @@ def init_parallel_env():
    _set_expected_place(place)
    group = None
    if backend in _valid_backend_list and in_dygraph_mode():
        if _default_group_name in _get_group_map_by_name():
            return _get_group_map_by_name()[_default_group_name]
......
@@ -30,6 +30,14 @@ class TestCollectiveBroadcastAPI(TestDistBase):
            "collective_broadcast_api.py", "broadcast", "nccl"
        )
+    def test_broadcast_nccl_with_comm_context(self):
+        self.check_with_place(
+            "collective_broadcast_api.py",
+            "broadcast",
+            "nccl",
+            need_envs={"USE_COMM_CONTEXT": "1"},
+        )
    def test_broadcast_gloo(self):
        self.check_with_place(
            "collective_broadcast_api.py", "broadcast", "gloo", "0"
......
@@ -108,6 +108,9 @@ class TestCollectiveAPIRunnerBase:
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
+        if args["use_comm_context"]:
+            paddle.distributed.collective._init_parallel_env(args["backend"])
+        else:
-        paddle.distributed.init_parallel_env()
+            paddle.distributed.init_parallel_env()
        if args['backend'] == 'nccl':
            device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
@@ -150,6 +153,7 @@ def runtime_main(test_class, col_type):
    args["path_id"] = int(os.getenv("PATH_ID"))
    args["static_mode"] = int(os.getenv("STATIC_MODE"))
    args["dtype"] = os.getenv("DTYPE")
+    args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0")))
    model.run_trainer(args)
@@ -162,6 +166,7 @@ class TestDistBase(unittest.TestCase):
            self._find_free_port(),
        )
        self._python_interp = sys.executable
+        self._master_endpoints = "127.0.0.1:%s" % (self._find_free_port())
        self.temp_dir = tempfile.TemporaryDirectory()
@@ -204,6 +209,7 @@ class TestDistBase(unittest.TestCase):
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": w0_ep,
+                "PADDLE_MASTER": self._master_endpoints,
            }
            env1 = {
@@ -212,6 +218,7 @@ class TestDistBase(unittest.TestCase):
                "PADDLE_TRAINERS_NUM": "2",
                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
                "PADDLE_CURRENT_ENDPOINT": w1_ep,
+                "PADDLE_MASTER": self._master_endpoints,
            }
        elif core.is_compiled_with_xpu():
            env0 = {
......