diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt index 2472a8b4a03ceab44235698982e28dd464be155c..34d480de9ee780db13c7428e0e42bb95e894b5b6 100755 --- a/paddle/fluid/distributed/CMakeLists.txt +++ b/paddle/fluid/distributed/CMakeLists.txt @@ -1,6 +1,5 @@ add_subdirectory(auto_parallel) add_subdirectory(collective) -add_subdirectory(store) if(WITH_PYTHON) py_proto_compile(ps_py_proto SRCS the_one_ps.proto) add_custom_target( diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index 23e1b48d151b17c5e7a98b75f84edcc2ec46b38c..39e4c00462da0c54c81198aab5d2799d585e1d5f 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -12,7 +12,7 @@ if(WITH_DISTRIBUTE) cc_library( process_group_gloo SRCS process_group_gloo.cc - DEPS phi_api eager_api gloo_wrapper) + DEPS phi_api eager_api gloo_wrapper tcp_store) endif() if(WITH_NCCL OR WITH_RCCL) @@ -20,6 +20,7 @@ if(WITH_NCCL OR WITH_RCCL) process_group_nccl SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc DEPS process_group + tcp_store place enforce collective_helper @@ -32,7 +33,12 @@ if(WITH_XPU_BKCL) cc_library( process_group_bkcl SRCS process_group_bkcl.cc bkcl_tools.cc common.cc - DEPS process_group place enforce collective_helper device_context + DEPS process_group + tcp_store + place + enforce + collective_helper + device_context dense_tensor) endif() @@ -47,6 +53,11 @@ if(WITH_CUSTOM_DEVICE) cc_library( process_group_custom SRCS process_group_custom.cc custom_ccl_tools.cc common.cc - DEPS process_group phi_backends place enforce collective_helper + DEPS process_group + tcp_store + phi_backends + place + enforce + collective_helper device_context) endif() diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.cc b/paddle/fluid/distributed/collective/process_group_bkcl.cc index e1f35ecd5e19228347a7da908ae42d32f7e98161..b8cd285c87ff391cb1768e155d4a6ac287e4a40e 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.cc +++ b/paddle/fluid/distributed/collective/process_group_bkcl.cc @@ -72,10 +72,11 @@ bool ProcessGroupBKCL::BKCLTask::Wait(std::chrono::milliseconds timeout) { // Same as Wait void ProcessGroupBKCL::BKCLTask::Synchronize() { Wait(kWaitTimeout); } -ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr& store, - int rank, - int size, - int gid) +ProcessGroupBKCL::ProcessGroupBKCL( + const std::shared_ptr& store, + int rank, + int size, + int gid) : ProcessGroupWithStream(rank, size, gid), store_(store) {} void ProcessGroupBKCL::GroupStart() { @@ -606,7 +607,10 @@ std::shared_ptr ProcessGroupBKCL::AllGather( } std::shared_ptr ProcessGroupBKCL::CreateProcessGroupBKCL( - const std::shared_ptr& store, int rank, int size, int gid) { + const std::shared_ptr& store, + int rank, + int size, + int gid) { auto process_group = std::make_shared(store, rank, size, gid); ProcessGroupIdMap::GetInstance().emplace(gid, process_group); diff --git a/paddle/fluid/distributed/collective/process_group_bkcl.h b/paddle/fluid/distributed/collective/process_group_bkcl.h index 15c908554bea7a729f85f0a939a7b748765d385c..cf8c983d8e66a884a54d2426b03685102140ebfb 100644 --- a/paddle/fluid/distributed/collective/process_group_bkcl.h +++ b/paddle/fluid/distributed/collective/process_group_bkcl.h @@ -21,11 +21,11 @@ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group_with_stream.h" -#include 
"paddle/fluid/distributed/store/store.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/store/store.h" #if defined(PADDLE_WITH_XPU) #include "paddle/fluid/distributed/collective/bkcl_tools.h" @@ -67,13 +67,16 @@ class ProcessGroupBKCL : public ProcessGroupWithStream { }; public: - ProcessGroupBKCL(const std::shared_ptr& store, + ProcessGroupBKCL(const std::shared_ptr& store, int rank, int size, int gid); static std::shared_ptr CreateProcessGroupBKCL( - const std::shared_ptr& store, int rank, int size, int gid); + const std::shared_ptr& store, + int rank, + int size, + int gid); std::string GetBackendName() const override { return std::string(BKCL_BACKEND_NAME); @@ -176,7 +179,7 @@ class ProcessGroupBKCL : public ProcessGroupWithStream { void SyncCalcStream(const Place& place); private: - std::shared_ptr store_; + std::shared_ptr store_; std::mutex mutex_; std::shared_ptr calc_event_; // event on calc stream std::unordered_map place_to_calc_ctx_; diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index 2fb23b455c3dbcdffe49bddccdd590df74224966..89ec97c08dd1db50e13738c3f26fa6c3a0a6ac1e 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -98,11 +98,12 @@ bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) { // Same as Wait void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); } -ProcessGroupCustom::ProcessGroupCustom(const std::shared_ptr& store, - const std::string& device_type, - int rank, - int size, - int gid) +ProcessGroupCustom::ProcessGroupCustom( + const std::shared_ptr& store, + const std::string& device_type, + int rank, + int size, + int gid) : ProcessGroupWithoutStream(rank, size, gid), store_(store), device_type_(device_type) {} @@ -438,7 +439,7 @@ std::shared_ptr ProcessGroupCustom::Broadcast( std::shared_ptr ProcessGroupCustom::CreateProcessGroupCustom( - const std::shared_ptr& store, + const std::shared_ptr& store, const std::string& device_type, int rank, int size, diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index 3169b9d5bc7469d5e307bd8327198507ae2675bc..82f236b331272c5c88fba0934e139d7d6007c34e 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -24,17 +24,18 @@ #include "paddle/fluid/distributed/collective/custom_ccl_tools.h" #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group_without_stream.h" -#include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/platform/device/npu/npu_stream.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/core/distributed/store/store.h" namespace paddle { namespace distributed { using Place = paddle::platform::Place; using CustomDeviceContext = paddle::platform::CustomDeviceContext; + class ProcessGroupCustom : public ProcessGroupWithoutStream { public: class CustomTask : public ProcessGroup::Task, @@ -64,14 +65,14 @@ class 
ProcessGroupCustom : public ProcessGroupWithoutStream { const std::string device_type_; }; - ProcessGroupCustom(const std::shared_ptr& store, + ProcessGroupCustom(const std::shared_ptr& store, const std::string& device_type, int rank, int size, int gid); static std::shared_ptr CreateProcessGroupCustom( - const std::shared_ptr& store, + const std::shared_ptr& store, const std::string& device_type, int rank, int size, @@ -127,7 +128,7 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream { CommType opType, const std::vector& inputs); - std::shared_ptr store_; + std::shared_ptr store_; std::shared_ptr custom_comm_; std::mutex mutex_; std::unordered_map& store, + const std::shared_ptr& store, int rank, int world_size, int gid, @@ -601,10 +601,11 @@ ProcessGroupGloo::createDefaultDevice() { 0, platform::errors::Fatal("Get hostname error for createDefaultDevice.")); ::addrinfo* result; - result = tcputils::get_addr_info(hostname.data(), "", 0, AF_UNSPEC); + result = phi::distributed::tcputils::get_addr_info( + hostname.data(), "", 0, AF_UNSPEC); ::addrinfo* cur; for (cur = result; cur != nullptr; cur = cur->ai_next) { - SocketType socket = + phi::distributed::SocketType socket = ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol); if (socket == -1) { continue; @@ -628,7 +629,10 @@ ProcessGroupGloo::createDefaultDevice() { } std::shared_ptr ProcessGroupGloo::CreateProcessGroupGloo( - const std::shared_ptr& store, int rank, int size, int gid) { + const std::shared_ptr& store, + int rank, + int size, + int gid) { std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; auto opts = GlooOptions::create(); char* ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str()); diff --git a/paddle/fluid/distributed/collective/process_group_gloo.h b/paddle/fluid/distributed/collective/process_group_gloo.h index 4a72a58ee19d0bdbf7c862a44f7cc9f50266db4d..5b41949f5210c99893cbe8e96062cb9a315c54e9 100644 --- a/paddle/fluid/distributed/collective/process_group_gloo.h +++ b/paddle/fluid/distributed/collective/process_group_gloo.h @@ -20,8 +20,8 @@ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group_without_stream.h" -#include "paddle/fluid/distributed/store/store.h" -#include "paddle/fluid/distributed/store/tcp_store.h" +#include "paddle/phi/core/distributed/store/store.h" +#include "paddle/phi/core/distributed/store/tcp_store.h" #ifdef PADDLE_WITH_GLOO #include "paddle/fluid/framework/fleet/gloo_wrapper.h" @@ -52,7 +52,7 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream { class GlooStore : public ::gloo::rendezvous::Store { public: - explicit GlooStore(const std::shared_ptr& store) + explicit GlooStore(const std::shared_ptr& store) : _store(store) {} ~GlooStore() = default; @@ -86,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream { } protected: - std::shared_ptr _store; + std::shared_ptr _store; }; class GlooOptions { @@ -99,14 +99,14 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream { std::shared_ptr<::gloo::transport::Device> device; }; - ProcessGroupGloo(const std::shared_ptr& store, + ProcessGroupGloo(const std::shared_ptr& store, int rank, int world_size, int gid, std::shared_ptr options); static std::shared_ptr CreateProcessGroupGloo( - const std::shared_ptr& store, + const std::shared_ptr& store, int rank, int world_size, int gid); diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index 
425edf6e3728ac738470e0477d4d53e95c230570..37fb6312e6356d6e8af34882711102d7af776c4b 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -86,10 +86,11 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) { // Same as Wait void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); } -ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr& store, - int rank, - int size, - int gid) +ProcessGroupNCCL::ProcessGroupNCCL( + const std::shared_ptr& store, + int rank, + int size, + int gid) : ProcessGroupWithStream(rank, size, gid), store_(store) {} void ProcessGroupNCCL::GroupStart() { @@ -1151,7 +1152,10 @@ std::shared_ptr ProcessGroupNCCL::Scatter( } std::shared_ptr ProcessGroupNCCL::CreateProcessGroupNCCL( - const std::shared_ptr& store, int rank, int size, int gid) { + const std::shared_ptr& store, + int rank, + int size, + int gid) { auto process_group = std::make_shared(store, rank, size, gid); ProcessGroupIdMap::GetInstance().emplace(gid, process_group); diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index 9d268cb03f530ed3a241be08f914d677db9379f5..cb83b0ddfe748a0d6ff3731dfcf8bec96db744dd 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -22,10 +22,10 @@ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group_with_stream.h" -#include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/platform/device_event.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/store/store.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/nccl_tools.h" @@ -33,7 +33,7 @@ #ifdef PADDLE_WITH_RCCL #include "paddle/phi/backends/dynload/rccl.h" -#elif PADDLE_WITH_NCCL +#else #include "paddle/phi/backends/dynload/nccl.h" #endif @@ -76,9 +76,12 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { public: static std::shared_ptr CreateProcessGroupNCCL( - const std::shared_ptr& store, int rank, int size, int gid); + const std::shared_ptr& store, + int rank, + int size, + int gid); - ProcessGroupNCCL(const std::shared_ptr& store, + ProcessGroupNCCL(const std::shared_ptr& store, int rank, int size, int gid); @@ -243,7 +246,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { const std::vector& places); private: - std::shared_ptr store_; + std::shared_ptr store_; std::unordered_map place_to_calc_event_; // event on calc stream diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt index e29b3f6639f1ed3884bada90584aab174b3e58f5..7221832191abdc3ff3223fbaba702058b21863ba 100644 --- a/paddle/fluid/operators/collective/CMakeLists.txt +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -34,7 +34,8 @@ register_operators( ${COLLECTIVE_DEPS}) if(WITH_NCCL OR WITH_RCCL) - set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper) + set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper + comm_context_manager nccl_comm_context) op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) endif() diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc 
b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc index f2ff11316aa7ec02bf5d38b6bf1ffe50bf463918..1652e6c982e6546d7afbee61d9eaadc844804d54 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_broadcast_op.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/nccl_comm_context.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #endif -#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/phi/api/include/tensor.h" namespace paddle { @@ -31,66 +32,52 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel { #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) auto x = ctx.Input("X"); auto out = ctx.Output("Out"); - int numel = x->numel(); - ncclDataType_t dtype = - platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); int rid = ctx.Attr("ring_id"); - auto place = ctx.GetPlace(); - auto map = distributed::ProcessGroupMapFromGid::getInstance(); - if (map->has(rid)) { - // Use ProcessGroup - distributed::ProcessGroup* pg = map->get(rid); - std::vector in_tensor; - std::vector out_tensor; - in_tensor.push_back(*x); - out_tensor.push_back(*out); - auto task = pg->Broadcast(in_tensor, out_tensor); - task->Wait(); - return; - } - - auto comm = platform::NCCLCommContext::Instance().Get(rid, place); - gpuStream_t stream = nullptr; - if (ctx.Attr("use_calc_stream")) { - // should ExecutionContext for calc stream. - stream = ctx.cuda_device_context().stream(); - } else { - stream = comm->stream(); - } + const auto& place = ctx.GetPlace(); + ctx.device_context().Alloc(out); int root = ctx.Attr("root"); - if (root == comm->rank()) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast( - reinterpret_cast(const_cast(x->data())), - numel, - dtype, - root, - comm->comm(), - stream)); - VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent " - << x->numel(); - if (out != x) { - framework::TensorCopy( - *static_cast(x), - place, - *platform::DeviceContextPool::Instance().Get(place), - static_cast(out)); - } + gpuStream_t stream = ctx.cuda_device_context().stream(); + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + if (comm_context_manager.Has(rid)) { + auto* comm_context = static_cast( + comm_context_manager.Get(rid)); + + comm_context->Broadcast(out, *x, root, stream); } else { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::ncclBcast(out->mutable_data(place), - numel, - dtype, - root, - comm->comm(), - stream)); - VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received " - << phi::product(out->dims()); + // NOTE(liyurui): This will be removed after moving this operator to phi. + int numel = x->numel(); + ncclDataType_t dtype = + platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype())); + auto comm = platform::NCCLCommContext::Instance().Get(rid, place); + if (root == comm->rank()) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast( + reinterpret_cast(const_cast(x->data())), + numel, + dtype, + root, + comm->comm(), + stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. 
sent " + << x->numel(); + if (out != x) { + framework::TensorCopy( + *static_cast(x), + place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(out)); + } + } else { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast( + out->data(), numel, dtype, root, comm->comm(), stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received " + << phi::product(out->dims()); + } } - out->Resize(x->dims()); out->set_lod(x->lod()); #else PADDLE_THROW(platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 5636f5678d774cff700b3f26d8e4c302a1e63dc6..bac22532cbb421401bb74970375c5d782fc6a55b 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -37,6 +37,7 @@ set(PYBIND_DEPS global_utils phi_utils tcp_store + comm_context_manager new_profiler auto_parallel jit_layer diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc index 3803d132515a5beb7d5bd84e7f45b8107f1236e8..5be251916146ded74ffa4d16cf9d685be6a4b7c8 100644 --- a/paddle/fluid/pybind/communication.cc +++ b/paddle/fluid/pybind/communication.cc @@ -21,49 +21,64 @@ limitations under the License. */ #include #include +#include #include -#include "paddle/fluid/distributed/store/tcp_store.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/store/tcp_store.h" namespace py = pybind11; namespace paddle { namespace pybind { -using TCPStore = paddle::distributed::TCPStore; +void BindCommContextManager(py::module *m) { + auto CommContextManager = + py::class_>( + *m, "CommContextManager") +#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) + .def_static( + "create_nccl_comm_context", + &phi::distributed::CommContextManager::CreateNCCLCommContext, + py::call_guard()) +#endif + .def("set_store", &phi::distributed::CommContextManager::SetStore); +} + +using TCPStore = phi::distributed::TCPStore; void BindTCPStore(py::module *m) { - auto Store = - py::class_>( - *m, "Store") - .def(py::init<>()) - .def( - "set", - [](distributed::Store &self, - const std::string &key, - const std::string &value) { - std::vector data(value.begin(), value.end()); - self.set(key, data); - }, - py::arg("key"), - py::arg("value"), - py::call_guard()) - .def( - "get", - [](distributed::Store &self, - const std::string &key) -> py::bytes { - auto data = self.get(key); - return py::bytes(reinterpret_cast(data.data()), - data.size()); - }, - py::arg("key"), - py::call_guard()) - .def("add", - &distributed::Store::add, - py::call_guard()) - .def("wait", - &distributed::Store::wait, - py::call_guard()); + auto Store = py::class_>(*m, "Store") + .def(py::init<>()) + .def( + "set", + [](phi::distributed::Store &self, + const std::string &key, + const std::string &value) { + std::vector data(value.begin(), value.end()); + self.set(key, data); + }, + py::arg("key"), + py::arg("value"), + py::call_guard()) + .def( + "get", + [](phi::distributed::Store &self, + const std::string &key) -> py::bytes { + auto data = self.get(key); + return py::bytes(reinterpret_cast(data.data()), + data.size()); + }, + py::arg("key"), + py::call_guard()) + .def("add", + &phi::distributed::Store::add, + py::call_guard()) + .def("wait", + &phi::distributed::Store::wait, + py::call_guard()); py::class_>(*m, "TCPStore", Store) .def(py::init([](std::string hostname, diff --git a/paddle/fluid/pybind/communication.h b/paddle/fluid/pybind/communication.h index 
17045ccfe65cae25471ceff3abf0129b2a21acb0..b22750afe9a817e1a32db5a14b04c09d950018dc 100644
--- a/paddle/fluid/pybind/communication.h
+++ b/paddle/fluid/pybind/communication.h
@@ -26,6 +26,7 @@ namespace paddle {
 namespace pybind {
 
 void BindTCPStore(pybind11::module* m);
+void BindCommContextManager(pybind11::module* m);
 
 }  // namespace pybind
 }  // namespace paddle
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc
index 8bac6f92e7b4d1a25b59417e0783b2dfa8478cf1..6bf409d527cca935f2d3e24b17c9a0f7147d61c1 100644
--- a/paddle/fluid/pybind/distributed_py.cc
+++ b/paddle/fluid/pybind/distributed_py.cc
@@ -46,7 +46,6 @@ limitations under the License. */
 #if defined(PADDLE_WITH_GLOO)
 #include "paddle/fluid/distributed/collective/process_group_gloo.h"
-#include "paddle/fluid/distributed/store/tcp_store.h"
 #endif
 
 #if defined(PADDLE_WITH_XPU_BKCL)
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index c457b14325e7977287cd8b59279211c6b486c659..1be0b13789aebffda0e7d3a534958630533ea079 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1871,6 +1871,7 @@ All parameter, weight, gradient are variables in Paddle.
   BindGlobalValueGetterSetter(&m);
   BindFleetExecutor(&m);
   BindTCPStore(&m);
+  BindCommContextManager(&m);
   BindAutoParallel(&m);
   BindJitProperty(&m);
diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt
index 08004673edfcee93645836ad5e1175fb7f9bba9c..e47e3a731c41465d569f3ab8b74e92fa172dc2a9 100644
--- a/paddle/phi/core/CMakeLists.txt
+++ b/paddle/phi/core/CMakeLists.txt
@@ -1,5 +1,6 @@
 # compatible utils used for fluid op system
 add_subdirectory(compat)
+add_subdirectory(distributed)
 
 if(WITH_GPU)
   proto_library(external_error_proto SRCS external_error.proto)
diff --git a/paddle/phi/core/distributed/CMakeLists.txt b/paddle/phi/core/distributed/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..92a5f1078715a03f66ae6fd3b5d211ffbd7a7073
--- /dev/null
+++ b/paddle/phi/core/distributed/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_subdirectory(store)
+
+set(COMM_CONTEXT_MANAGER_DEPS tcp_store)
+
+if(WITH_NCCL OR WITH_RCCL)
+  cc_library(
+    nccl_comm_context
+    SRCS nccl_comm_context.cc
+    DEPS dense_tensor)
+  list(APPEND COMM_CONTEXT_MANAGER_DEPS nccl_comm_context)
+endif()
+
+cc_library(
+  comm_context_manager
+  SRCS comm_context_manager.cc
+  DEPS ${COMM_CONTEXT_MANAGER_DEPS})
diff --git a/paddle/phi/core/distributed/comm_context.h b/paddle/phi/core/distributed/comm_context.h
new file mode 100644
index 0000000000000000000000000000000000000000..0a878e6e8bce64ab7fe2c086778ca72b2d38154c
--- /dev/null
+++ b/paddle/phi/core/distributed/comm_context.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/macros.h"
+
+namespace phi {
+namespace distributed {
+
+class CommContext {
+ public:
+  CommContext(int rank, int size) : rank_(rank), size_(size) {}
+  virtual ~CommContext() = default;
+
+ protected:
+  int rank_;
+  int size_;
+
+ private:
+  DISABLE_COPY_AND_ASSIGN(CommContext);
+};
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc
new file mode 100644
index 0000000000000000000000000000000000000000..818952736e009374ba004b71cf6266b20d1be397
--- /dev/null
+++ b/paddle/phi/core/distributed/comm_context_manager.cc
@@ -0,0 +1,87 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/distributed/comm_context_manager.h"
+
+#include <memory>
+#include <string>
+
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/core/distributed/store/store.h"
+#include "paddle/phi/core/enforce.h"
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#endif
+
+namespace phi {
+namespace distributed {
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+void CommContextManager::CreateNCCLCommContext(
+    const std::shared_ptr<Store>& store,
+    int dev_id,
+    int ring_id,
+    int rank,
+    int size) {
+  phi::backends::gpu::SetDeviceId(dev_id);
+  ncclUniqueId nccl_id;
+  if (rank == 0) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id));
+  }
+
+  std::string unique_key = "NCCLCommContext/" + std::to_string(ring_id);
+  if (rank == 0) {
+    std::vector<uint8_t> nccl_id_wrapper(
+        reinterpret_cast<uint8_t*>(&nccl_id),
+        reinterpret_cast<uint8_t*>(&nccl_id) + NCCL_UNIQUE_ID_BYTES);
+    store->set(unique_key, nccl_id_wrapper);
+  } else {
+    const auto& nccl_id_wrapper = store->get(unique_key);
+    std::memcpy(&nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
+  }
+
+  auto nccl_comm_context =
+      std::make_unique<NCCLCommContext>(rank, size, nccl_id);
+  auto& comm_context_manager = CommContextManager::GetInstance();
+  comm_context_manager.SetStore(store);
+  comm_context_manager.Emplace(ring_id, std::move(nccl_comm_context));
+}
+#endif
+
+CommContext* CommContextManager::Emplace(
+    int ring_id, std::unique_ptr<CommContext> comm_context) {
+  PADDLE_ENFORCE_EQ(
+      id_to_comm_context_.find(ring_id),
+      id_to_comm_context_.end(),
+      errors::AlreadyExists("Ring id %d already exists in the map.", ring_id));
+  id_to_comm_context_.emplace(ring_id, std::move(comm_context));
+  return id_to_comm_context_.at(ring_id).get();
+}
+
+CommContext* CommContextManager::Get(int ring_id) const {
+  PADDLE_ENFORCE_NE(
+      id_to_comm_context_.find(ring_id),
+      id_to_comm_context_.end(),
+      errors::NotFound("Can not find ring id %d in map.", ring_id));
+
+  return id_to_comm_context_.at(ring_id).get();
+}
+
+bool CommContextManager::Has(int ring_id) const {
+  return id_to_comm_context_.find(ring_id) != id_to_comm_context_.end();
+}
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h
new file mode 100644
index 0000000000000000000000000000000000000000..5d57856eeedf687c1dc1ca3717d8295659a01c97
--- /dev/null
+++ b/paddle/phi/core/distributed/comm_context_manager.h
@@ -0,0 +1,63 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "paddle/phi/core/distributed/comm_context.h"
+#include "paddle/phi/core/macros.h"
+
+namespace phi {
+namespace distributed {
+
+class Store;
+
+class CommContextManager {
+ public:
+  CommContextManager() = default;
+  ~CommContextManager() = default;
+
+  static CommContextManager& GetInstance() {
+    static CommContextManager instance;
+    return instance;
+  }
+
+  void SetStore(const std::shared_ptr<Store>& store) { store_ = store; }
+
+  CommContext* Emplace(int ring_id, std::unique_ptr<CommContext> comm_context);
+
+  CommContext* Get(int ring_id) const;
+
+  bool Has(int ring_id) const;
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  static void CreateNCCLCommContext(const std::shared_ptr<Store>& store,
+                                    int dev_id,
+                                    int ring_id,
+                                    int rank,
+                                    int size);
+#endif
+
+ private:
+  DISABLE_COPY_AND_ASSIGN(CommContextManager);
+
+  std::unordered_map<int, std::unique_ptr<CommContext>> id_to_comm_context_;
+  std::shared_ptr<Store> store_;
+};
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/nccl_comm_context.cc b/paddle/phi/core/distributed/nccl_comm_context.cc
new file mode 100644
index 0000000000000000000000000000000000000000..32c1a2e744fc5d9ff827833a00a9ea9c8ace1fad
--- /dev/null
+++ b/paddle/phi/core/distributed/nccl_comm_context.cc
@@ -0,0 +1,43 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { +namespace distributed { + +NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id) + : CommContext(rank, size) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_)); +} + +void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int root, + gpuStream_t stream) { + phi::dynload::ncclBroadcast(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + root, + nccl_comm_, + stream); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/nccl_comm_context.h b/paddle/phi/core/distributed/nccl_comm_context.h new file mode 100644 index 0000000000000000000000000000000000000000..8e1590cfc254dca8e756b62810f207d95f3747cb --- /dev/null +++ b/paddle/phi/core/distributed/nccl_comm_context.h @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/core/distributed/comm_context.h" +#include "paddle/phi/core/macros.h" + +#if defined(PADDLE_WITH_RCCL) +#include "paddle/phi/backends/dynload/rccl.h" +#else +#include "paddle/phi/backends/dynload/nccl.h" +#endif + +namespace phi { +class DenseTensor; +namespace distributed { + +class NCCLCommContext final : public CommContext { + public: + NCCLCommContext(int rank, int size, ncclUniqueId nccl_id); + + void Broadcast(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int root, + gpuStream_t stream); + + private: + DISABLE_COPY_AND_ASSIGN(NCCLCommContext); + + ncclComm_t nccl_comm_; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/fluid/distributed/store/CMakeLists.txt b/paddle/phi/core/distributed/store/CMakeLists.txt similarity index 73% rename from paddle/fluid/distributed/store/CMakeLists.txt rename to paddle/phi/core/distributed/store/CMakeLists.txt index 111a8e95d38bb87d9f5370608c200500c0f53321..ac5c8ae9f5c789e9ff6fff68f6899e3c21e93e85 100644 --- a/paddle/fluid/distributed/store/CMakeLists.txt +++ b/paddle/phi/core/distributed/store/CMakeLists.txt @@ -1,6 +1,6 @@ cc_library( tcp_store - SRCS tcp_store.cc tcp_utils.cc socket.cpp + SRCS tcp_store.cc tcp_utils.cc socket.cpp store.cc DEPS enforce glog) if(NOT WIN32) diff --git a/paddle/fluid/distributed/store/socket.cpp b/paddle/phi/core/distributed/store/socket.cpp similarity index 94% rename from paddle/fluid/distributed/store/socket.cpp rename to paddle/phi/core/distributed/store/socket.cpp index ca6dc0f02902af6f0b5ccf402f1a64def9e4f91f..122ab124dae82af9684a15ee97031f7e46f48616 100644 --- a/paddle/fluid/distributed/store/socket.cpp +++ b/paddle/phi/core/distributed/store/socket.cpp @@ -12,7 +12,7 @@ // See the 
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/distributed/store/socket.h"
+#include "paddle/phi/core/distributed/store/socket.h"
 
 #ifndef _WIN32
 #include
@@ -23,7 +23,7 @@
 #include
 #include
 
-namespace paddle {
+namespace phi {
 namespace distributed {
 
 #ifdef _WIN32
@@ -75,5 +75,5 @@ std::string GetSockName(int fd) {
   return std::string(out);
 }
 
-};  // namespace distributed
-};  // namespace paddle
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/fluid/distributed/store/socket.h b/paddle/phi/core/distributed/store/socket.h
similarity index 91%
rename from paddle/fluid/distributed/store/socket.h
rename to paddle/phi/core/distributed/store/socket.h
index f423d2643354bd4d7aa83defc7ea98e35dd26bb3..028e88786749b9629a81e56c0945aa60abd6ba05 100644
--- a/paddle/fluid/distributed/store/socket.h
+++ b/paddle/phi/core/distributed/store/socket.h
@@ -16,11 +16,11 @@
 
 #include
 
-namespace paddle {
+namespace phi {
 namespace distributed {
 
 int GetSockName(int fd, char* out, int out_len);
 std::string GetSockName(int fd);
 
-};  // namespace distributed
-};  // namespace paddle
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/store/store.cc b/paddle/phi/core/distributed/store/store.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7e7db8895b99f14a0b4a458f309c21a17c9f5bc3
--- /dev/null
+++ b/paddle/phi/core/distributed/store/store.cc
@@ -0,0 +1,42 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/core/distributed/store/store.h" +#include "paddle/phi/core/enforce.h" + +namespace phi { +namespace distributed { + +int64_t Store::add(const std::string& key, int64_t value) { + PADDLE_THROW( + errors::InvalidArgument("Implement the add method in the subclass.")); +} + +std::vector Store::get(const std::string& key) { + PADDLE_THROW( + errors::InvalidArgument("Implement the get method in the subclass.")); +} + +void Store::wait(const std::string& key) { + PADDLE_THROW( + errors::InvalidArgument("Implement the wait method in the subclass.")); +} + +void Store::set(const std::string& key, const std::vector& value) { + PADDLE_THROW( + errors::InvalidArgument("Implement the set method in the subclass.")); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/fluid/distributed/store/store.h b/paddle/phi/core/distributed/store/store.h similarity index 59% rename from paddle/fluid/distributed/store/store.h rename to paddle/phi/core/distributed/store/store.h index eb329276d67b1ac446276c08eab6ad57f7041cfb..fa509586eefdf210fd22a054eb3544e31b8819f6 100644 --- a/paddle/fluid/distributed/store/store.h +++ b/paddle/phi/core/distributed/store/store.h @@ -18,9 +18,7 @@ #include #include -#include "paddle/fluid/distributed/store/tcp_utils.h" - -namespace paddle { +namespace phi { namespace distributed { class Store { @@ -29,22 +27,10 @@ class Store { explicit Store(const int timeout) : _timeout(timeout) {} virtual ~Store() = default; - virtual int64_t add(const std::string& key, int64_t value) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Implement the add method in the subclass.")); - } - virtual std::vector get(const std::string& key) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Implement the add method in the subclass.")); - } - virtual void wait(const std::string& key) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Implement the add method in the subclass.")); - } - virtual void set(const std::string& key, const std::vector& value) { - PADDLE_THROW(platform::errors::InvalidArgument( - "Implement the add method in the subclass.")); - } + virtual int64_t add(const std::string& key, int64_t value); + virtual std::vector get(const std::string& key); + virtual void wait(const std::string& key); + virtual void set(const std::string& key, const std::vector& value); virtual int timeout() { return _timeout; } @@ -53,4 +39,4 @@ class Store { }; } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/fluid/distributed/store/tcp_store.cc b/paddle/phi/core/distributed/store/tcp_store.cc similarity index 92% rename from paddle/fluid/distributed/store/tcp_store.cc rename to paddle/phi/core/distributed/store/tcp_store.cc index 0ecfcef42458d18a482860579111031f9b3ab787..34aa24216826b368392670b504adff2c00a9336c 100644 --- a/paddle/fluid/distributed/store/tcp_store.cc +++ b/paddle/phi/core/distributed/store/tcp_store.cc @@ -12,17 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/distributed/store/tcp_store.h" +#include "paddle/phi/core/distributed/store/tcp_store.h" #include #include #include -#include "paddle/fluid/distributed/store/tcp_utils.h" -#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/core/distributed/store/tcp_utils.h" #include "paddle/phi/core/flags.h" -namespace paddle { +namespace phi { namespace distributed { namespace detail { @@ -90,7 +89,7 @@ void MasterDaemon::_do_get(SocketType socket) { PADDLE_ENFORCE_NE( iter, _store.end(), - platform::errors::InvalidArgument("Key %s not found in TCPStore.", key)); + phi::errors::InvalidArgument("Key %s not found in TCPStore.", key)); std::vector value = iter->second; tcputils::send_vector(socket, value); } @@ -100,7 +99,7 @@ void MasterDaemon::InitControlFd() { PADDLE_ENFORCE_NE( pipe(_control_fd.data()), -1, - platform::errors::Fatal("failed to cread control pipe errno:%d", errno)); + phi::errors::Fatal("failed to cread control pipe errno:%d", errno)); } void MasterDaemon::CloseControlFd() { for (int fd : _control_fd) { @@ -112,10 +111,10 @@ void MasterDaemon::CloseControlFd() { void MasterDaemon::StopByControlFd() { VLOG(4) << ("begin to run StopByControlFd"); if (_control_fd[1] != -1) { - PADDLE_ENFORCE_NE(::write(_control_fd[1], "\0", 1), - -1, - platform::errors::Fatal( - "failed to write control pipe errno:%d", errno)); + PADDLE_ENFORCE_NE( + ::write(_control_fd[1], "\0", 1), + -1, + phi::errors::Fatal("failed to write control pipe errno:%d", errno)); // close the write end of the pipe ::close(_control_fd[1]); _control_fd[1] = -1; @@ -125,7 +124,7 @@ void MasterDaemon::StopByControlFd() { void MasterDaemon::InitControlFd() { ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL); PADDLE_ENFORCE(ghStopEvent_, - platform::errors::Fatal("failed to cread control pipe")); + phi::errors::Fatal("failed to cread control pipe")); } void MasterDaemon::CloseControlFd() { CloseHandle(ghStopEvent_); } void MasterDaemon::StopByControlFd() { SetEvent(ghStopEvent_); } @@ -231,8 +230,8 @@ void MasterDaemon::run() { // The control pipe receive shutdown event, and begin to close it. 
if (fds[1].revents != 0) { if (fds[1].revents & ~(POLLIN | POLLHUP)) { - PADDLE_THROW(paddle::platform::errors::Fatal("Undefined event type:%d", - fds[1].revents)); + PADDLE_THROW( + phi::errors::Fatal("Undefined event type:%d", fds[1].revents)); } VLOG(0) << "receive shutdown event and so quit from MasterDaemon run loop"; @@ -312,9 +311,7 @@ TCPStore::TCPStore(std::string host, : Store(timeout), _is_master(is_master), _num_workers(num_workers) { _timeout = timeout; PADDLE_ENFORCE_GT( - timeout, - 0, - platform::errors::InvalidArgument("timeout must >= %d", timeout)); + timeout, 0, phi::errors::InvalidArgument("timeout must >= %d", timeout)); VLOG(3) << "input timeout" << timeout << ", member timeout:" << _timeout; if (_is_master) { @@ -355,7 +352,7 @@ void TCPStore::waitWorkers() { PADDLE_ENFORCE_EQ( completed, _num_workers, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "TCPStore timeouted and not all workers got ready.")); } } while (true); @@ -398,4 +395,4 @@ void TCPStore::wait(const std::string& key) { TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; } } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/fluid/distributed/store/tcp_store.h b/paddle/phi/core/distributed/store/tcp_store.h similarity index 95% rename from paddle/fluid/distributed/store/tcp_store.h rename to paddle/phi/core/distributed/store/tcp_store.h index 06f2ce55041b1f0a976caedcd13924c62b03e901..663275242d8ab4f28fc3ab296614accc09bc4e3b 100644 --- a/paddle/fluid/distributed/store/tcp_store.h +++ b/paddle/phi/core/distributed/store/tcp_store.h @@ -30,11 +30,11 @@ #include #include -#include "paddle/fluid/distributed/store/socket.h" -#include "paddle/fluid/distributed/store/store.h" -#include "paddle/fluid/distributed/store/tcp_utils.h" +#include "paddle/phi/core/distributed/store/socket.h" +#include "paddle/phi/core/distributed/store/store.h" +#include "paddle/phi/core/distributed/store/tcp_utils.h" -namespace paddle { +namespace phi { namespace distributed { enum class ReplyType { WAITING, STOP_WAIT }; @@ -143,4 +143,4 @@ class TCPStore : public Store { }; } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/fluid/distributed/store/tcp_utils.cc b/paddle/phi/core/distributed/store/tcp_utils.cc similarity index 89% rename from paddle/fluid/distributed/store/tcp_utils.cc rename to paddle/phi/core/distributed/store/tcp_utils.cc index 0a419aa95d7856f88d88108801408375f58b6613..d7b1fd3b972edfa43b5af60e74d46d76c841de13 100644 --- a/paddle/fluid/distributed/store/tcp_utils.cc +++ b/paddle/phi/core/distributed/store/tcp_utils.cc @@ -12,15 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/distributed/store/tcp_utils.h" +#include "paddle/phi/core/distributed/store/tcp_utils.h" #include #include #include -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { +namespace phi { namespace distributed { namespace tcputils { @@ -60,7 +58,7 @@ void close_socket(SocketType socket) { : ""); PADDLE_ENFORCE_EQ(n, 0, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "%s network %s:%s cannot be obtained. 
Details: %s.", proto, host, @@ -73,7 +71,7 @@ void close_socket(SocketType socket) { void free_addr_info(::addrinfo* hint) { PADDLE_ENFORCE_NOT_NULL( hint, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The parameter for free_addr_info cannot be null.")); ::freeaddrinfo(hint); } @@ -91,14 +89,14 @@ SocketType tcp_connect(const std::string host, do { for (::addrinfo* cur = res; cur != nullptr; cur = cur->ai_next) { sockfd = ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol); - PADDLE_ENFORCE_GT(sockfd, - 0, - platform::errors::InvalidArgument( - "Create socket to connect %s:%s failed. " - "Details: %s. ", - host, - port, - socket_error().message())); + PADDLE_ENFORCE_GT( + sockfd, + 0, + phi::errors::InvalidArgument("Create socket to connect %s:%s failed. " + "Details: %s. ", + host, + port, + socket_error().message())); if (::connect(sockfd, cur->ai_addr, cur->ai_addrlen) == 0) { retry = false; @@ -125,7 +123,7 @@ SocketType tcp_connect(const std::string host, PADDLE_ENFORCE_GT(sockfd, 0, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Network %s:%s cannot be connected.", host, port)); VLOG(0) << "Successfully connected to " << host << ":" << port; @@ -173,7 +171,7 @@ SocketType tcp_listen(const std::string host, PADDLE_ENFORCE_GT(sockfd, 0, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Bind network on %s:%s failedd.", node, port)); ::listen(sockfd, LISTENQ); @@ -190,7 +188,7 @@ SocketType tcp_accept(SocketType socket) { PADDLE_ENFORCE_GT( new_socket, 0, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The server failed to accept a new connection. Details: %s.", socket_error().message())); #ifndef _WIN32 @@ -225,4 +223,4 @@ std::string receive_string(SocketType socket) { } // namespace tcputils } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/fluid/distributed/store/tcp_utils.h b/paddle/phi/core/distributed/store/tcp_utils.h similarity index 90% rename from paddle/fluid/distributed/store/tcp_utils.h rename to paddle/phi/core/distributed/store/tcp_utils.h index 7aa38cf548fb256cc4f94866991f67214fa77b74..af11ad27f04254a11f86d6abd3fac26994d8b12f 100644 --- a/paddle/fluid/distributed/store/tcp_utils.h +++ b/paddle/phi/core/distributed/store/tcp_utils.h @@ -26,14 +26,15 @@ #include #include #endif + #include #include #include -#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/core/enforce.h" // Utility functions for TCP socket. -namespace paddle { +namespace phi { namespace distributed { #ifdef _WIN32 @@ -82,8 +83,8 @@ void send_bytes(SocketType socket, const T* buffer, size_t len) { PADDLE_ENFORCE_GT( byte_sent, 0, - platform::errors::InvalidArgument("TCP send error. Details: %s.", - socket_error().message())); + phi::errors::InvalidArgument("TCP send error. Details: %s.", + socket_error().message())); to_send -= byte_sent; ptr += byte_sent; } @@ -102,8 +103,8 @@ void receive_bytes(SocketType socket, T* buffer, size_t len) { PADDLE_ENFORCE_GT( byte_received, 0, - platform::errors::InvalidArgument("TCP receive error. Details: %s.", - socket_error().message())); + phi::errors::InvalidArgument("TCP receive error. 
Details: %s.", + socket_error().message())); to_recv -= byte_received; ptr += byte_received; @@ -140,4 +141,4 @@ T receive_value(SocketType socket) { } // namespace tcputils } // namespace distributed -} // namespace paddle +} // namespace phi diff --git a/paddle/fluid/distributed/store/test_tcp_store.cc b/paddle/phi/core/distributed/store/test_tcp_store.cc similarity index 88% rename from paddle/fluid/distributed/store/test_tcp_store.cc rename to paddle/phi/core/distributed/store/test_tcp_store.cc index 45bf56953598a9a4541b588c87e761f2a030e75c..e101f573db9a61f1800bc79614bf4fef65426aed 100644 --- a/paddle/fluid/distributed/store/test_tcp_store.cc +++ b/paddle/phi/core/distributed/store/test_tcp_store.cc @@ -13,14 +13,14 @@ // limitations under the License. #include "gtest/gtest.h" -#include "paddle/fluid/distributed/store/tcp_store.h" -#include "paddle/fluid/distributed/store/tcp_utils.h" +#include "paddle/phi/core/distributed/store/tcp_store.h" +#include "paddle/phi/core/distributed/store/tcp_utils.h" #ifdef _WIN32 #include #endif -namespace paddle { +namespace phi { namespace distributed { TEST(MasterDaemon, init) { @@ -48,6 +48,5 @@ TEST(TCPStore, init) { paddle::errors::Fatal("result of add is not right")); } */ - -}; // namespace distributed -}; // namespace paddle +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/utils/data_type.h b/paddle/phi/core/utils/data_type.h index edb841aeb1caaf2edb24430e167c49868b906581..16b73e0f2baa6738702971a89101df99ce68c99f 100644 --- a/paddle/phi/core/utils/data_type.h +++ b/paddle/phi/core/utils/data_type.h @@ -211,4 +211,33 @@ inline int TransToProtoVarType(const DataType& dtype) { } } +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +inline ncclDataType_t ToNCCLDataType(DataType type) { + if (type == DataType::FLOAT32) { + return ncclFloat; + } else if (type == DataType::FLOAT64) { + return ncclDouble; + } else if (type == DataType::INT32) { + return ncclInt; + } else if (type == DataType::INT64) { + return ncclInt64; + } else if (type == DataType::FLOAT16) { + return ncclFloat16; + } else if (type == DataType::UINT8) { + return ncclUint8; + } else if (type == DataType::INT8) { + return ncclInt8; + } else if (type == DataType::BOOL) { + return ncclUint8; +#if NCCL_VERSION_CODE >= 21000 + } else if (type == DataType::BFLOAT16) { + return ncclBfloat16; +#endif + } else { + PADDLE_THROW( + errors::Unimplemented("This datatype in nccl is not supported.")); + } +} +#endif + } // namespace phi diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 767038addb7e29a5209cfda1f8eec45d65747da5..90a56530154803f1e44526843ba650932d54b316 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -13,6 +13,7 @@ # limitations under the License. 
import datetime +import os import paddle @@ -325,3 +326,25 @@ def is_available(): """ return core.is_compiled_with_dist() + + +def _init_parallel_env(backend): + master_endpoint = os.getenv("PADDLE_MASTER", None) + if master_endpoint: + master_addr = master_endpoint.split(":")[0] + master_port = int(master_endpoint.split(":")[1]) + global_env = _get_global_env() + rank = global_env.rank + world_size = global_env.world_size + dev_id = global_env.device_id + is_master = rank == 0 + store = core.TCPStore( + master_addr, + master_port, + is_master, + world_size, + ) + if backend == "nccl": + core.CommContextManager.create_nccl_comm_context( + store, dev_id, 0, rank, world_size + ) diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 99a71146104ef6b2ba30e41031b2f2bb0c9c7219..7dad9831e744d7237cd6ce95416be3c9646f7dad 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -243,6 +243,7 @@ def init_parallel_env(): _set_expected_place(place) group = None + if backend in _valid_backend_list and in_dygraph_mode(): if _default_group_name in _get_group_map_by_name(): return _get_group_map_by_name()[_default_group_name] diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py index 40cc5c64b4e1517eb643efefa8b91ae900db2185..6a3fc9ba1e3be9dffff61e1c9e941164cab50e46 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_broadcast_api.py @@ -30,6 +30,14 @@ class TestCollectiveBroadcastAPI(TestDistBase): "collective_broadcast_api.py", "broadcast", "nccl" ) + def test_broadcast_nccl_with_comm_context(self): + self.check_with_place( + "collective_broadcast_api.py", + "broadcast", + "nccl", + need_envs={"USE_COMM_CONTEXT": "1"}, + ) + def test_broadcast_gloo(self): self.check_with_place( "collective_broadcast_api.py", "broadcast", "gloo", "0" diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index ecabdb92fcb2a2b83aea1d0d820f66079f8d54bf..ea84008775623da225b0d0d6bdc7c4ee8f0bc21e 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -108,7 +108,10 @@ class TestCollectiveAPIRunnerBase: rank = args["trainerid"] current_endpoint = args["currentendpoint"] nranks = 2 - paddle.distributed.init_parallel_env() + if args["use_comm_context"]: + paddle.distributed.collective._init_parallel_env(args["backend"]) + else: + paddle.distributed.init_parallel_env() if args['backend'] == 'nccl': device_id = int(os.getenv("FLAGS_selected_gpus", "0")) place = fluid.CUDAPlace( @@ -150,6 +153,7 @@ def runtime_main(test_class, col_type): args["path_id"] = int(os.getenv("PATH_ID")) args["static_mode"] = int(os.getenv("STATIC_MODE")) args["dtype"] = os.getenv("DTYPE") + args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0"))) model.run_trainer(args) @@ -162,6 +166,7 @@ class TestDistBase(unittest.TestCase): self._find_free_port(), ) self._python_interp = sys.executable + self._master_endpoints = "127.0.0.1:%s" % (self._find_free_port()) self.temp_dir = tempfile.TemporaryDirectory() @@ -204,6 +209,7 @@ class TestDistBase(unittest.TestCase): "PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, 
"PADDLE_CURRENT_ENDPOINT": w0_ep, + "PADDLE_MASTER": self._master_endpoints, } env1 = { @@ -212,6 +218,7 @@ class TestDistBase(unittest.TestCase): "PADDLE_TRAINERS_NUM": "2", "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, "PADDLE_CURRENT_ENDPOINT": w1_ep, + "PADDLE_MASTER": self._master_endpoints, } elif core.is_compiled_with_xpu(): env0 = {