未验证 提交 5a6cd05f 编写于 作者: TaoTao Li's avatar TaoTao Li 提交者: GitHub

update dygraph collective process group (#54863)

* update dygraph collective

fix ut

* remove debug log
上级 bbcaaffd
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "paddle/phi/backends/gpu/forwards.h" #include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/distributed/store/store.h" #include "paddle/phi/core/distributed/store/store.h"
namespace paddle { namespace paddle {
...@@ -68,6 +69,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { ...@@ -68,6 +69,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
public: public:
static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL( static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
const std::shared_ptr<phi::distributed::Store>& store, const std::shared_ptr<phi::distributed::Store>& store,
int device_id,
int rank, int rank,
int size, int size,
int gid); int gid);
...@@ -219,7 +221,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { ...@@ -219,7 +221,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
void SyncCalcStream(const Place& place); void SyncCalcStream(const Place& place);
std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv( std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv(
std::function<void(ncclComm_t, gpuStream_t)> fn, std::function<void(gpuStream_t)> fn,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
CommType comm_type, CommType comm_type,
bool sync_op, bool sync_op,
...@@ -249,6 +251,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { ...@@ -249,6 +251,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
void CreateNCCLManagerCache(const std::string& places_key, void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places); const std::vector<Place>& places);
phi::distributed::NCCLCommContext* GetCommContext();
private: private:
std::shared_ptr<phi::distributed::Store> store_; std::shared_ptr<phi::distributed::Store> store_;
......
...@@ -1238,6 +1238,7 @@ void BindDistributed(py::module *m) { ...@@ -1238,6 +1238,7 @@ void BindDistributed(py::module *m) {
.def_static("create", .def_static("create",
distributed::ProcessGroupNCCL::CreateProcessGroupNCCL, distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
py::arg("store"), py::arg("store"),
py::arg("device_id"),
py::arg("rank"), py::arg("rank"),
py::arg("world_size"), py::arg("world_size"),
py::arg("group_id") = 0, py::arg("group_id") = 0,
......
...@@ -151,7 +151,10 @@ def _new_process_group_impl( ...@@ -151,7 +151,10 @@ def _new_process_group_impl(
if backend == "gloo": if backend == "gloo":
pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id) pg = core.ProcessGroupGloo.create(store, rank, world_size, group_id)
elif backend == "nccl": elif backend == "nccl":
pg = core.ProcessGroupNCCL.create(store, rank, world_size, group_id) pg = core.ProcessGroupNCCL.create(
store, genv.device_id, rank, world_size, group_id
)
elif backend == "xccl": elif backend == "xccl":
pg = core.ProcessGroupCustom.create( pg = core.ProcessGroupCustom.create(
store, genv.device_type, rank, world_size, group_id store, genv.device_type, rank, world_size, group_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册