未验证 提交 2a47416c 编写于 作者: L LiYuRio 提交者: GitHub

add new map instance (#48145)

上级 403d58bb
......@@ -35,17 +35,6 @@ void ProcessGroup::Task::Synchronize() {}
// Default no-op; presumably overridden by device-specific Task subclasses to
// record `ctx` in their wait chain — confirm against the class declaration.
void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {}
// Place-aware constructor. Besides initializing the members, it registers
// this group in the global gid -> ProcessGroup map, unless the caller passed
// IGNORE_ID to opt out of registration.
ProcessGroup::ProcessGroup(int rank,
                           int size,
                           const platform::Place& place,
                           int gid)
    : rank_(rank), size_(size), place_(place), gid_(gid) {
  if (gid == IGNORE_ID) {
    return;  // anonymous group: do not expose it via the gid map
  }
  ProcessGroupMapFromGid::getInstance()->insert(gid_, this);
}
ProcessGroup::ProcessGroup(int rank, int size, int gid)
: rank_(rank), size_(size), gid_(gid) {
if (gid != IGNORE_ID) {
......@@ -66,5 +55,10 @@ ProcessGroup::Task::Task(int rank,
bool sync_op)
: rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
// Accessor for the process-wide gid -> ProcessGroup registry.
// Function-local static: constructed on first use, thread-safe since C++11.
ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
  static ProcessGroupIdMap s_instance;
  return s_instance;
}
} // namespace distributed
} // namespace paddle
......@@ -82,13 +82,8 @@ class ProcessGroup {
};
public:
explicit ProcessGroup(int rank, int size, int gid);
ProcessGroup(int rank, int size, int gid);
virtual ~ProcessGroup() = default;
// TODO(dev): This constructor will be removed later.
explicit ProcessGroup(int rank,
int size,
const platform::Place& place,
int gid);
int GetRank() const { return rank_; }
......@@ -290,12 +285,18 @@ class ProcessGroup {
}
protected:
const int rank_;
const int size_;
const platform::Place place_;
const int gid_;
int rank_;
int size_;
int gid_;
};
// Singleton registry mapping group id -> ProcessGroup instance.
// NOTE(review): publicly inheriting from std::unordered_map exposes the whole
// container interface and the base has no virtual destructor, so this type
// must never be deleted through a base pointer — presumably only ever used
// via GetInstance(); confirm before extending.
class ProcessGroupIdMap
    : public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
 public:
  static ProcessGroupIdMap& GetInstance();
};
// TODO(dev): The following method will be removed soon.
class ProcessGroupMapFromGid {
public:
bool has(int gid) {
......
......@@ -531,5 +531,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroupBKCL> ProcessGroupBKCL::CreateProcessGroupBKCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid) {
auto process_group =
std::make_shared<ProcessGroupBKCL>(store, rank, size, gid);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
return process_group;
}
} // namespace distributed
} // namespace paddle
......@@ -73,6 +73,9 @@ class ProcessGroupBKCL : public ProcessGroupStream {
int size,
int gid);
static std::shared_ptr<ProcessGroupBKCL> CreateProcessGroupBKCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid);
std::string GetBackendName() const override {
return std::string(BKCL_BACKEND_NAME);
}
......
......@@ -433,5 +433,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
CommType::BROADCAST);
}
// Factory: build an XCCL (custom-device) process group for `device_type`
// and register it under `gid` in the global id map.
std::shared_ptr<ProcessGroupCustom>
ProcessGroupCustom::CreateProcessGroupCustom(
    const std::shared_ptr<Store>& store,
    const std::string& device_type,
    int rank,
    int size,
    int gid) {
  auto group =
      std::make_shared<ProcessGroupCustom>(store, device_type, rank, size, gid);
  auto& id_map = ProcessGroupIdMap::GetInstance();
  id_map.emplace(gid, group);
  return group;
}
} // namespace distributed
} // namespace paddle
......@@ -69,6 +69,13 @@ class ProcessGroupCustom : public ProcessGroup {
int size,
int gid);
static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom(
const std::shared_ptr<Store>& store,
const std::string& device_type,
int rank,
int size,
int gid);
std::string GetBackendName() const override { return "XCCL_" + device_type_; }
std::shared_ptr<ProcessGroup::Task> AllGather(
......
......@@ -617,5 +617,22 @@ ProcessGroupGloo::createDefaultDevice() {
return createDeviceForHostname("127.0.0.1");
}
// Factory: build a Gloo process group and register it under `gid` in the
// global id map.
//
// The transport device is chosen from the GLOO_SOCKET_IFNAME environment
// variable when it names a plausible interface (length > 1); otherwise the
// default loopback-based device is used.
std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo(
    const std::shared_ptr<Store>& store, int rank, int size, int gid) {
  auto opts = GlooOptions::create();
  // Query the environment directly; no need to materialize a std::string
  // just to call .c_str() on it.
  const char* ifname = std::getenv("GLOO_SOCKET_IFNAME");
  if (ifname != nullptr && std::strlen(ifname) > 1) {
    opts->device =
        ProcessGroupGloo::createDeviceForInterface(std::string(ifname));
  } else {
    opts->device = ProcessGroupGloo::createDefaultDevice();
  }
  auto process_group =
      std::make_shared<ProcessGroupGloo>(store, rank, size, gid, opts);
  ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
  return process_group;
}
} // namespace distributed
} // namespace paddle
......@@ -15,6 +15,7 @@
#pragma once
#include <future>
#include <memory>
#include <mutex>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
......@@ -98,12 +99,17 @@ class ProcessGroupGloo : public ProcessGroup {
std::shared_ptr<::gloo::transport::Device> device;
};
explicit ProcessGroupGloo(
ProcessGroupGloo(const std::shared_ptr<paddle::distributed::Store>& store,
int rank,
int world_size,
int gid,
std::shared_ptr<GlooOptions> options);
static std::shared_ptr<ProcessGroupGloo> CreateProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store,
int rank,
int world_size,
int gid,
std::shared_ptr<GlooOptions> options);
int gid);
~ProcessGroupGloo() = default;
......@@ -191,7 +197,7 @@ class ProcessGroupGloo : public ProcessGroup {
const std::string& ifname);
static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();
protected:
private:
uint32_t _tag;
std::shared_ptr<gloo::rendezvous::Context> _context;
std::shared_ptr<::gloo::rendezvous::Store> _store;
......
......@@ -1130,5 +1130,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
CommType::SCATTER);
}
std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid) {
auto process_group =
std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
return process_group;
}
} // namespace distributed
} // namespace paddle
......@@ -76,6 +76,9 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
};
public:
static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid);
ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank,
int size,
......
......@@ -85,8 +85,6 @@ using GlooStore = paddle::distributed::ProcessGroupGloo::GlooStore;
using GlooOptions = paddle::distributed::ProcessGroupGloo::GlooOptions;
#endif
static std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; // NOLINT
static UNUSED void *use_ccl_comm_func =
phi::detail::GetCCLComm(phi::CPUPlace());
......@@ -1221,24 +1219,18 @@ void BindDistributed(py::module *m) {
py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
auto processGroupNCCL =
py::class_<distributed::ProcessGroupNCCL,
std::shared_ptr<distributed::ProcessGroupNCCL>>(
*m, "ProcessGroupNCCL", ProcessGroupStream)
.def(py::init<const std::shared_ptr<distributed::Store> &,
int,
int,
int>(),
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
processGroupNCCL.def_static(
"group_start", []() { distributed::ProcessGroupNCCL::GroupStart(); });
processGroupNCCL.def_static(
"group_end", []() { distributed::ProcessGroupNCCL::GroupEnd(); });
py::class_<distributed::ProcessGroupNCCL,
std::shared_ptr<distributed::ProcessGroupNCCL>>(
*m, "ProcessGroupNCCL", ProcessGroupStream)
.def_static("create",
distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>())
.def_static("group_start", distributed::ProcessGroupNCCL::GroupStart)
.def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd);
#endif
......@@ -1265,17 +1257,14 @@ void BindDistributed(py::module *m) {
py::class_<distributed::ProcessGroupCustom,
std::shared_ptr<distributed::ProcessGroupCustom>>(
*m, "ProcessGroupCustom", ProcessGroup)
.def(py::init<const std::shared_ptr<distributed::Store> &,
const std::string &,
int,
int,
int>(),
py::arg("store"),
py::arg("device_type"),
py::arg("rank"),
py::arg("world_size"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
.def_static("create",
distributed::ProcessGroupCustom::CreateProcessGroupCustom,
py::arg("store"),
py::arg("device_type"),
py::arg("rank"),
py::arg("world_size"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
#endif
......@@ -1284,15 +1273,13 @@ void BindDistributed(py::module *m) {
py::class_<distributed::ProcessGroupBKCL,
std::shared_ptr<distributed::ProcessGroupBKCL>>(
*m, "ProcessGroupBKCL", ProcessGroupStream)
.def(py::init<const std::shared_ptr<distributed::Store> &,
int,
int,
int>(),
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
.def_static("create",
distributed::ProcessGroupBKCL::CreateProcessGroupBKCL,
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
#endif
py::class_<distributed::ProcessGroup::Task,
......@@ -1310,32 +1297,13 @@ void BindDistributed(py::module *m) {
#if defined(PADDLE_WITH_GLOO)
py::class_<ProcessGroupGloo, std::shared_ptr<ProcessGroupGloo>>(
*m, "ProcessGroupGloo", ProcessGroup)
.def(py::init<const std::shared_ptr<paddle::distributed::Store> &,
int,
int,
int,
std::shared_ptr<GlooOptions> &>(),
py::call_guard<py::gil_scoped_release>())
.def(py::init([](const std::shared_ptr<paddle::distributed::Store> &store,
int rank,
int world_size,
int gid) {
auto opts = GlooOptions::create();
char *ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
if (ifname && strlen(ifname) > 1) {
opts->device = ProcessGroupGloo::createDeviceForInterface(
std::string(ifname));
} else {
opts->device = ProcessGroupGloo::createDefaultDevice();
}
return std::make_shared<ProcessGroupGloo>(
store, rank, world_size, gid, opts);
}),
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>())
.def_static("create",
distributed::ProcessGroupGloo::CreateProcessGroupGloo,
py::arg("store"),
py::arg("rank"),
py::arg("world_size"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>())
.def_static("create_default_device",
&ProcessGroupGloo::createDefaultDevice);
#endif
......
......@@ -252,7 +252,13 @@ struct GPUContext::Impl {
phi::DestroyDnnHandle(dnn_handle_);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if (nccl_comm_) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
// NOTE(liyurui): It is not recommend calling CUDA runtime API
// in destructor. Since we can not ensure the release order of
// static object, calling ncclCommDestroy in static object destructor
// is a undefined behavior, CUDA driver may be already unloaded
// from process.
// If you really need to release the resource of nccl_comm,
// try to get the nccl_comm out and use ncclCommDestroy outside.
}
#endif
phi::DestroyBlasHandle(blas_handle_);
......
......@@ -152,15 +152,15 @@ def _new_process_group_impl(
genv = _get_global_env()
assert backend in _valid_backend_list, "Unsupported backend: %s." % backend
if backend == "gloo":
pg = core.ProcessGroupGloo(store, rank, world_size, group_id)
pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
elif backend == "nccl":
pg = core.ProcessGroupNCCL(store, rank, world_size, group_id)
pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id)
elif backend == "xccl":
pg = core.ProcessGroupCustom(
pg = core.ProcessGroupCustom.create(
store, genv.device_type, rank, world_size, group_id
)
elif backend == "bkcl":
pg = core.ProcessGroupBKCL(store, rank, world_size, group_id)
pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id)
return pg
......
......@@ -28,7 +28,7 @@ def init_process_group(strategy=None):
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6173, is_master, nranks)
pg_group = core.ProcessGroupCustom(
pg_group = core.ProcessGroupCustom.create(
store,
ParallelEnv().device_type,
rank,
......
......@@ -42,7 +42,7 @@ class TestProcessGroupFp32(unittest.TestCase):
store = paddle.fluid.core.TCPStore(
"127.0.0.1", 6272, is_master, nranks, 30
)
pg = paddle.fluid.core.ProcessGroupGloo(store, rank, nranks)
pg = paddle.fluid.core.ProcessGroupGloo.create(store, rank, nranks)
# test allreduce sum
# rank 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册