Unverified commit 2a47416c, authored by LiYuRio, committed by GitHub

add new map instance (#48145)

Parent 403d58bb
@@ -35,17 +35,6 @@ void ProcessGroup::Task::Synchronize() {}
 void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {}
 
-ProcessGroup::ProcessGroup(int rank,
-                           int size,
-                           const platform::Place& place,
-                           int gid)
-    : rank_(rank), size_(size), place_(place), gid_(gid) {
-  if (gid != IGNORE_ID) {
-    auto map = ProcessGroupMapFromGid::getInstance();
-    map->insert(gid_, this);
-  }
-}
-
 ProcessGroup::ProcessGroup(int rank, int size, int gid)
     : rank_(rank), size_(size), gid_(gid) {
   if (gid != IGNORE_ID) {
@@ -66,5 +55,10 @@ ProcessGroup::Task::Task(int rank,
                          bool sync_op)
     : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
 
+ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
+  static ProcessGroupIdMap instance;
+  return instance;
+}
+
 }  // namespace distributed
 }  // namespace paddle
@@ -82,13 +82,8 @@ class ProcessGroup {
   };
 
  public:
-  explicit ProcessGroup(int rank, int size, int gid);
+  ProcessGroup(int rank, int size, int gid);
   virtual ~ProcessGroup() = default;
 
-  // TODO(dev): This constructor will be removed later.
-  explicit ProcessGroup(int rank,
-                        int size,
-                        const platform::Place& place,
-                        int gid);
-
   int GetRank() const { return rank_; }
@@ -290,12 +285,18 @@ class ProcessGroup {
   }
 
  protected:
-  const int rank_;
-  const int size_;
-  const platform::Place place_;
-  const int gid_;
+  int rank_;
+  int size_;
+  int gid_;
 };
 
+class ProcessGroupIdMap
+    : public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
+ public:
+  static ProcessGroupIdMap& GetInstance();
+};
+
+// TODO(dev): The following method will be removed soon.
 class ProcessGroupMapFromGid {
  public:
   bool has(int gid) {
......
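ProcessGroupIdMap is the "new map instance" the commit title refers to: a process-wide registry keyed by group id. GetInstance() returns a function-local static (a Meyers singleton), so the map is constructed on first use, and since the class derives from std::unordered_map, callers get the full map API. As a minimal sketch of how a group could later be recovered from its gid (GetGroupById is a hypothetical helper, not part of this commit):

    // Hypothetical lookup helper built on the map added above: returns the
    // registered group for gid, or nullptr if no group was created with it.
    std::shared_ptr<ProcessGroup> GetGroupById(int gid) {
      auto& id_map = ProcessGroupIdMap::GetInstance();
      auto it = id_map.find(gid);
      return it == id_map.end() ? nullptr : it->second;
    }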
@@ -531,5 +531,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
       /*use_calc_stream*/ false);
 }
 
+std::shared_ptr<ProcessGroupBKCL> ProcessGroupBKCL::CreateProcessGroupBKCL(
+    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
+  auto process_group =
+      std::make_shared<ProcessGroupBKCL>(store, rank, size, gid);
+  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
+  return process_group;
+}
+
 }  // namespace distributed
 }  // namespace paddle
@@ -73,6 +73,9 @@ class ProcessGroupBKCL : public ProcessGroupStream {
       int size,
       int gid);
 
+  static std::shared_ptr<ProcessGroupBKCL> CreateProcessGroupBKCL(
+      const std::shared_ptr<Store>& store, int rank, int size, int gid);
+
   std::string GetBackendName() const override {
     return std::string(BKCL_BACKEND_NAME);
   }
......
@@ -433,5 +433,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
       CommType::BROADCAST);
 }
 
+std::shared_ptr<ProcessGroupCustom>
+ProcessGroupCustom::CreateProcessGroupCustom(
+    const std::shared_ptr<Store>& store,
+    const std::string& device_type,
+    int rank,
+    int size,
+    int gid) {
+  auto process_group =
+      std::make_shared<ProcessGroupCustom>(store, device_type, rank, size, gid);
+  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
+  return process_group;
+}
+
 }  // namespace distributed
 }  // namespace paddle
@@ -69,6 +69,13 @@ class ProcessGroupCustom : public ProcessGroup {
       int size,
       int gid);
 
+  static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom(
+      const std::shared_ptr<Store>& store,
+      const std::string& device_type,
+      int rank,
+      int size,
+      int gid);
+
   std::string GetBackendName() const override { return "XCCL_" + device_type_; }
 
   std::shared_ptr<ProcessGroup::Task> AllGather(
......
@@ -617,5 +617,22 @@ ProcessGroupGloo::createDefaultDevice() {
   return createDeviceForHostname("127.0.0.1");
 }
 
+std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo(
+    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
+  std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
+  auto opts = GlooOptions::create();
+  char* ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
+  if (ifname && strlen(ifname) > 1) {
+    opts->device =
+        ProcessGroupGloo::createDeviceForInterface(std::string(ifname));
+  } else {
+    opts->device = ProcessGroupGloo::createDefaultDevice();
+  }
+  auto process_group =
+      std::make_shared<ProcessGroupGloo>(store, rank, size, gid, opts);
+  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
+  return process_group;
+}
+
 }  // namespace distributed
 }  // namespace paddle
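The interface-selection logic that previously lived in the Python binding layer (see the pybind11 diff below) now lives in this factory: if the GLOO_SOCKET_IFNAME environment variable names a network interface, Gloo binds to it; otherwise the default device is used. Note the factory keeps the old strlen(ifname) > 1 check, so a one-character interface name silently falls back to the default device. A minimal usage sketch, assuming a ready Store and using "eth0" as a placeholder interface name:

    #include <cstdlib>  // setenv
    // Pin Gloo to a specific NIC before creating the group (illustrative).
    setenv("GLOO_SOCKET_IFNAME", "eth0", /*overwrite=*/1);
    auto pg = ProcessGroupGloo::CreateProcessGroupGloo(store, rank, size, gid);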
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <future>
+#include <memory>
 #include <mutex>
 
 #include "paddle/fluid/distributed/collective/ProcessGroup.h"
@@ -98,12 +99,17 @@ class ProcessGroupGloo : public ProcessGroup {
     std::shared_ptr<::gloo::transport::Device> device;
   };
 
-  explicit ProcessGroupGloo(
-      const std::shared_ptr<paddle::distributed::Store>& store,
-      int rank,
-      int world_size,
-      int gid,
-      std::shared_ptr<GlooOptions> options);
+  ProcessGroupGloo(const std::shared_ptr<paddle::distributed::Store>& store,
+                   int rank,
+                   int world_size,
+                   int gid,
+                   std::shared_ptr<GlooOptions> options);
+
+  static std::shared_ptr<ProcessGroupGloo> CreateProcessGroupGloo(
+      const std::shared_ptr<paddle::distributed::Store>& store,
+      int rank,
+      int world_size,
+      int gid);
 
   ~ProcessGroupGloo() = default;
@@ -191,7 +197,7 @@ class ProcessGroupGloo : public ProcessGroup {
       const std::string& ifname);
 
   static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();
 
- protected:
+ private:
   uint32_t _tag;
   std::shared_ptr<gloo::rendezvous::Context> _context;
   std::shared_ptr<::gloo::rendezvous::Store> _store;
......
@@ -1130,5 +1130,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
       CommType::SCATTER);
 }
 
+std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
+    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
+  auto process_group =
+      std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
+  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
+  return process_group;
+}
+
 }  // namespace distributed
 }  // namespace paddle
@@ -76,6 +76,9 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
   };
 
  public:
+  static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
+      const std::shared_ptr<Store>& store, int rank, int size, int gid);
+
   ProcessGroupNCCL(const std::shared_ptr<Store>& store,
                    int rank,
                    int size,
......
@@ -85,8 +85,6 @@ using GlooStore = paddle::distributed::ProcessGroupGloo::GlooStore;
 using GlooOptions = paddle::distributed::ProcessGroupGloo::GlooOptions;
 #endif
 
-static std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";  // NOLINT
-
 static UNUSED void *use_ccl_comm_func =
     phi::detail::GetCCLComm(phi::CPUPlace());
@@ -1221,24 +1219,18 @@ void BindDistributed(py::module *m) {
           py::call_guard<py::gil_scoped_release>());
 
 #if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
-  auto processGroupNCCL =
-      py::class_<distributed::ProcessGroupNCCL,
-                 std::shared_ptr<distributed::ProcessGroupNCCL>>(
-          *m, "ProcessGroupNCCL", ProcessGroupStream)
-          .def(py::init<const std::shared_ptr<distributed::Store> &,
-                        int,
-                        int,
-                        int>(),
-               py::arg("store"),
-               py::arg("rank"),
-               py::arg("world_size"),
-               py::arg("group_id") = 0,
-               py::call_guard<py::gil_scoped_release>());
-  processGroupNCCL.def_static(
-      "group_start", []() { distributed::ProcessGroupNCCL::GroupStart(); });
-  processGroupNCCL.def_static(
-      "group_end", []() { distributed::ProcessGroupNCCL::GroupEnd(); });
+  py::class_<distributed::ProcessGroupNCCL,
+             std::shared_ptr<distributed::ProcessGroupNCCL>>(
+      *m, "ProcessGroupNCCL", ProcessGroupStream)
+      .def_static("create",
+                  distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
+                  py::arg("store"),
+                  py::arg("rank"),
+                  py::arg("world_size"),
+                  py::arg("group_id") = 0,
+                  py::call_guard<py::gil_scoped_release>())
+      .def_static("group_start", distributed::ProcessGroupNCCL::GroupStart)
+      .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd);
 
 #endif
@@ -1265,17 +1257,14 @@ void BindDistributed(py::module *m) {
   py::class_<distributed::ProcessGroupCustom,
              std::shared_ptr<distributed::ProcessGroupCustom>>(
       *m, "ProcessGroupCustom", ProcessGroup)
-      .def(py::init<const std::shared_ptr<distributed::Store> &,
-                    const std::string &,
-                    int,
-                    int,
-                    int>(),
-           py::arg("store"),
-           py::arg("device_type"),
-           py::arg("rank"),
-           py::arg("world_size"),
-           py::arg("group_id") = 0,
-           py::call_guard<py::gil_scoped_release>());
+      .def_static("create",
+                  distributed::ProcessGroupCustom::CreateProcessGroupCustom,
+                  py::arg("store"),
+                  py::arg("device_type"),
+                  py::arg("rank"),
+                  py::arg("world_size"),
+                  py::arg("group_id") = 0,
+                  py::call_guard<py::gil_scoped_release>());
 
 #endif
@@ -1284,15 +1273,13 @@ void BindDistributed(py::module *m) {
   py::class_<distributed::ProcessGroupBKCL,
              std::shared_ptr<distributed::ProcessGroupBKCL>>(
       *m, "ProcessGroupBKCL", ProcessGroupStream)
-      .def(py::init<const std::shared_ptr<distributed::Store> &,
-                    int,
-                    int,
-                    int>(),
-           py::arg("store"),
-           py::arg("rank"),
-           py::arg("world_size"),
-           py::arg("group_id") = 0,
-           py::call_guard<py::gil_scoped_release>());
+      .def_static("create",
+                  distributed::ProcessGroupBKCL::CreateProcessGroupBKCL,
+                  py::arg("store"),
+                  py::arg("rank"),
+                  py::arg("world_size"),
+                  py::arg("group_id") = 0,
+                  py::call_guard<py::gil_scoped_release>());
 
 #endif
 
   py::class_<distributed::ProcessGroup::Task,
@@ -1310,32 +1297,13 @@ void BindDistributed(py::module *m) {
 #if defined(PADDLE_WITH_GLOO)
   py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>(
       *m, "ProcessGroupGloo", ProcessGroup)
-      .def(py::init<const std::shared_ptr<paddle::distributed::Store> &,
-                    int,
-                    int,
-                    int,
-                    std::shared_ptr<GlooOptions> &>(),
-           py::call_guard<py::gil_scoped_release>())
-      .def(py::init([](const std::shared_ptr<paddle::distributed::Store> &store,
-                       int rank,
-                       int world_size,
-                       int gid) {
-             auto opts = GlooOptions::create();
-             char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
-             if (ifname && strlen(ifname) > 1) {
-               opts->device = ProcessGroupGloo::createDeviceForInterface(
-                   std::string(ifname));
-             } else {
-               opts->device = ProcessGroupGloo::createDefaultDevice();
-             }
-             return std::make_shared<ProcessGroupGloo>(
-                 store, rank, world_size, gid, opts);
-           }),
-           py::arg("store"),
-           py::arg("rank"),
-           py::arg("world_size"),
-           py::arg("group_id") = 0,
-           py::call_guard<py::gil_scoped_release>())
+      .def_static("create",
+                  distributed::ProcessGroupGloo::CreateProcessGroupGloo,
+                  py::arg("store"),
+                  py::arg("rank"),
+                  py::arg("world_size"),
+                  py::arg("group_id") = 0,
+                  py::call_guard<py::gil_scoped_release>())
       .def_static("create_default_device",
                   &ProcessGroupGloo::createDefaultDevice);
 #endif
......
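Across all four backends the bindings drop py::init in favor of def_static("create", ...) forwarding to the C++ factory, so Python can only obtain a process group through the path that registers it in ProcessGroupIdMap; GroupStart/GroupEnd are likewise bound directly instead of through wrapper lambdas. A self-contained sketch of the same binding pattern with a stand-in type (Widget and its registry are illustrative, not Paddle types):

    #include <pybind11/pybind11.h>
    #include <memory>
    #include <unordered_map>
    namespace py = pybind11;

    class Widget {
     public:
      explicit Widget(int gid) : gid_(gid) {}
      // The only construction path exposed to Python; it registers the
      // instance in a process-wide map before returning it.
      static std::shared_ptr<Widget> Create(int gid) {
        auto w = std::make_shared<Widget>(gid);
        Registry().emplace(gid, w);
        return w;
      }
      static std::unordered_map<int, std::shared_ptr<Widget>>& Registry() {
        static std::unordered_map<int, std::shared_ptr<Widget>> map;
        return map;
      }
     private:
      int gid_;
    };

    PYBIND11_MODULE(example, m) {
      py::class_<Widget, std::shared_ptr<Widget>>(m, "Widget")
          // No py::init: Python must call Widget.create(...), mirroring
          // ProcessGroupNCCL.create in the diff above.
          .def_static("create", &Widget::Create, py::arg("group_id") = 0);
    }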
@@ -252,7 +252,13 @@ struct GPUContext::Impl {
     phi::DestroyDnnHandle(dnn_handle_);
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     if (nccl_comm_) {
-      PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
+      // NOTE(liyurui): Calling CUDA runtime APIs in a destructor is not
+      // recommended. Since the release order of static objects cannot be
+      // guaranteed, calling ncclCommDestroy in a static object's
+      // destructor is undefined behavior: the CUDA driver may already
+      // have been unloaded from the process.
+      // If the nccl_comm resource really must be released, retrieve
+      // nccl_comm and call ncclCommDestroy outside.
     }
 #endif
     phi::DestroyBlasHandle(blas_handle_);
......
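The destructor body is emptied rather than fixed because GPUContext::Impl can outlive the CUDA driver: C++ destroys objects with static storage duration in an order that is unspecified across translation units, so by the time a static context is torn down, the driver may already be unloaded and ncclCommDestroy would be undefined behavior. A sketch of the alternative the note suggests, run while the process is still in a well-defined state (the accessor and setter names are assumptions for illustration, not part of this commit):

    // Illustrative teardown, e.g. near the end of main(): take the comm
    // out of the context and destroy it while the driver is still loaded.
    ncclComm_t comm = ctx->nccl_comm();  // assumed accessor
    ctx->set_nccl_comm(nullptr);         // assumed: detach so ~Impl skips it
    if (comm != nullptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclCommDestroy(comm));
    }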
@@ -152,15 +152,15 @@ def _new_process_group_impl(
     genv = _get_global_env()
     assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
     if backend == "gloo":
-        pg = core.ProcessGroupGloo(store, rank, world_size, group_id)
+        pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
     elif backend == "nccl":
-        pg = core.ProcessGroupNCCL(store, rank, world_size, group_id)
+        pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
     elif backend == "xccl":
-        pg = core.ProcessGroupCustom(
+        pg = core.ProcessGroupCustom.create(
             store, genv.device_type, rank, world_size, group_id
         )
     elif backend == "bkcl":
-        pg = core.ProcessGroupBKCL(store, rank, world_size, group_id)
+        pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id)
 
     return pg
......
@@ -28,7 +28,7 @@ def init_process_group(strategy=None):
     rank = ParallelEnv().local_rank
     is_master = True if rank == 0 else False
     store = paddle.fluid.core.TCPStore("127.0.0.1", 6173, is_master, nranks)
-    pg_group = core.ProcessGroupCustom(
+    pg_group = core.ProcessGroupCustom.create(
         store,
         ParallelEnv().device_type,
         rank,
......
@@ -42,7 +42,7 @@ class TestProcessGroupFp32(unittest.TestCase):
         store = paddle.fluid.core.TCPStore(
             "127.0.0.1", 6272, is_master, nranks, 30
         )
-        pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks)
+        pg = paddle.fluid.core.ProcessGroupGloo.create(store, rank, nranks)
 
         # test allreduce sum
         # rank 0
......