Unverified commit 04e24e58, authored by LiYuRio, committed by GitHub

Create comm_context and modified static init (#49536)

* comm_context and static init

* refactor: move to phi/core/distributed

* refactor: avoid mutable_data usage

* fix: windows sock

* fix: device without nccl
Co-authored-by: Wen Sun <syl1887415157@126.com>
Parent 67fc8e93
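At a glance, this change moves the Store/TCPStore utilities from paddle/fluid/distributed/store to paddle/phi/core/distributed/store, adds CommContext, NCCLCommContext and a singleton CommContextManager under phi/core/distributed, and exposes CommContextManager to Python so NCCL communicators can be created statically from a TCP store. A minimal sketch of that Python-side flow, mirroring the _init_parallel_env helper added later in this diff (the import path for core and the endpoint values are placeholders, not part of the commit):

import os
from paddle.framework import core  # assumed import path for the pybind core module

# placeholder endpoint; in the tests it is passed through the PADDLE_MASTER env var
master_addr, master_port = "127.0.0.1", 6170
rank, world_size, dev_id = 0, 2, 0

# rank 0 hosts the TCP store; the other ranks connect to it
store = core.TCPStore(master_addr, master_port, rank == 0, world_size)

# statically create the NCCL communicator for ring_id 0 (NCCL/RCCL builds only)
core.CommContextManager.create_nccl_comm_context(store, dev_id, 0, rank, world_size)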
add_subdirectory(auto_parallel)
add_subdirectory(collective)
add_subdirectory(store)
if(WITH_PYTHON)
py_proto_compile(ps_py_proto SRCS the_one_ps.proto)
add_custom_target(
......
......@@ -12,7 +12,7 @@ if(WITH_DISTRIBUTE)
cc_library(
process_group_gloo
SRCS process_group_gloo.cc
DEPS phi_api eager_api gloo_wrapper)
DEPS phi_api eager_api gloo_wrapper tcp_store)
endif()
if(WITH_NCCL OR WITH_RCCL)
......@@ -20,6 +20,7 @@ if(WITH_NCCL OR WITH_RCCL)
process_group_nccl
SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc
DEPS process_group
tcp_store
place
enforce
collective_helper
......@@ -32,7 +33,12 @@ if(WITH_XPU_BKCL)
cc_library(
process_group_bkcl
SRCS process_group_bkcl.cc bkcl_tools.cc common.cc
DEPS process_group place enforce collective_helper device_context
DEPS process_group
tcp_store
place
enforce
collective_helper
device_context
dense_tensor)
endif()
......@@ -47,6 +53,11 @@ if(WITH_CUSTOM_DEVICE)
cc_library(
process_group_custom
SRCS process_group_custom.cc custom_ccl_tools.cc common.cc
DEPS process_group phi_backends place enforce collective_helper
DEPS process_group
tcp_store
phi_backends
place
enforce
collective_helper
device_context)
endif()
......@@ -72,10 +72,11 @@ bool ProcessGroupBKCL::BKCLTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupBKCL::BKCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr<Store>& store,
int rank,
int size,
int gid)
ProcessGroupBKCL::ProcessGroupBKCL(
const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid)
: ProcessGroupWithStream(rank, size, gid), store_(store) {}
void ProcessGroupBKCL::GroupStart() {
......@@ -606,7 +607,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::AllGather(
}
std::shared_ptr<ProcessGroupBKCL> ProcessGroupBKCL::CreateProcessGroupBKCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid) {
const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid) {
auto process_group =
std::make_shared<ProcessGroupBKCL>(store, rank, size, gid);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
......
......@@ -21,11 +21,11 @@
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/store/store.h"
#if defined(PADDLE_WITH_XPU)
#include "paddle/fluid/distributed/collective/bkcl_tools.h"
......@@ -67,13 +67,16 @@ class ProcessGroupBKCL : public ProcessGroupWithStream {
};
public:
ProcessGroupBKCL(const std::shared_ptr<Store>& store,
ProcessGroupBKCL(const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid);
static std::shared_ptr<ProcessGroupBKCL> CreateProcessGroupBKCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid);
const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid);
std::string GetBackendName() const override {
return std::string(BKCL_BACKEND_NAME);
......@@ -176,7 +179,7 @@ class ProcessGroupBKCL : public ProcessGroupWithStream {
void SyncCalcStream(const Place& place);
private:
std::shared_ptr<Store> store_;
std::shared_ptr<phi::distributed::Store> store_;
std::mutex mutex_;
std::shared_ptr<XPUEventManager> calc_event_; // event on calc stream
std::unordered_map<std::string, phi::XPUContext*> place_to_calc_ctx_;
......
......@@ -98,11 +98,12 @@ bool ProcessGroupCustom::CustomTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupCustom::CustomTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupCustom::ProcessGroupCustom(const std::shared_ptr<Store>& store,
const std::string& device_type,
int rank,
int size,
int gid)
ProcessGroupCustom::ProcessGroupCustom(
const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
int rank,
int size,
int gid)
: ProcessGroupWithoutStream(rank, size, gid),
store_(store),
device_type_(device_type) {}
......@@ -438,7 +439,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Broadcast(
std::shared_ptr<ProcessGroupCustom>
ProcessGroupCustom::CreateProcessGroupCustom(
const std::shared_ptr<Store>& store,
const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
int rank,
int size,
......
......@@ -24,17 +24,18 @@
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/distributed/store/store.h"
namespace paddle {
namespace distributed {
using Place = paddle::platform::Place;
using CustomDeviceContext = paddle::platform::CustomDeviceContext;
class ProcessGroupCustom : public ProcessGroupWithoutStream {
public:
class CustomTask : public ProcessGroup::Task,
......@@ -64,14 +65,14 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream {
const std::string device_type_;
};
ProcessGroupCustom(const std::shared_ptr<Store>& store,
ProcessGroupCustom(const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
int rank,
int size,
int gid);
static std::shared_ptr<ProcessGroupCustom> CreateProcessGroupCustom(
const std::shared_ptr<Store>& store,
const std::shared_ptr<phi::distributed::Store>& store,
const std::string& device_type,
int rank,
int size,
......@@ -127,7 +128,7 @@ class ProcessGroupCustom : public ProcessGroupWithoutStream {
CommType opType,
const std::vector<phi::DenseTensor>& inputs);
std::shared_ptr<Store> store_;
std::shared_ptr<phi::distributed::Store> store_;
std::shared_ptr<CustomCCLCommManager> custom_comm_;
std::mutex mutex_;
std::unordered_map<std::string,
......
......@@ -177,7 +177,7 @@ ProcessGroupGloo::GlooTask::GlooTask(
: ProcessGroup::Task(rank, inputs, comm_type) {}
ProcessGroupGloo::ProcessGroupGloo(
const std::shared_ptr<distributed::Store>& store,
const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int world_size,
int gid,
......@@ -601,10 +601,11 @@ ProcessGroupGloo::createDefaultDevice() {
0,
platform::errors::Fatal("Get hostname error for createDefaultDevice."));
::addrinfo* result;
result = tcputils::get_addr_info(hostname.data(), "", 0, AF_UNSPEC);
result = phi::distributed::tcputils::get_addr_info(
hostname.data(), "", 0, AF_UNSPEC);
::addrinfo* cur;
for (cur = result; cur != nullptr; cur = cur->ai_next) {
SocketType socket =
phi::distributed::SocketType socket =
::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);
if (socket == -1) {
continue;
......@@ -628,7 +629,10 @@ ProcessGroupGloo::createDefaultDevice() {
}
std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo(
const std::shared_ptr<Store>& store, int rank, int size, int gid) {
const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid) {
std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
auto opts = GlooOptions::create();
char* ifname = getenv(GLOO_SOCKET_IFNAME_ENV.c_str());
......
......@@ -20,8 +20,8 @@
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
#ifdef PADDLE_WITH_GLOO
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
......@@ -52,7 +52,7 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
class GlooStore : public ::gloo::rendezvous::Store {
public:
explicit GlooStore(const std::shared_ptr<paddle::distributed::Store>& store)
explicit GlooStore(const std::shared_ptr<phi::distributed::Store>& store)
: _store(store) {}
~GlooStore() = default;
......@@ -86,7 +86,7 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
}
protected:
std::shared_ptr<paddle::distributed::Store> _store;
std::shared_ptr<phi::distributed::Store> _store;
};
class GlooOptions {
......@@ -99,14 +99,14 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
std::shared_ptr<::gloo::transport::Device> device;
};
ProcessGroupGloo(const std::shared_ptr<paddle::distributed::Store>& store,
ProcessGroupGloo(const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int world_size,
int gid,
std::shared_ptr<GlooOptions> options);
static std::shared_ptr<ProcessGroupGloo> CreateProcessGroupGloo(
const std::shared_ptr<paddle::distributed::Store>& store,
const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int world_size,
int gid);
......
......@@ -86,10 +86,11 @@ bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
// Same as Wait
void ProcessGroupNCCL::NCCLTask::Synchronize() { Wait(kWaitTimeout); }
ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank,
int size,
int gid)
ProcessGroupNCCL::ProcessGroupNCCL(
const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid)
: ProcessGroupWithStream(rank, size, gid), store_(store) {}
void ProcessGroupNCCL::GroupStart() {
......@@ -1151,7 +1152,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
}
std::shared_ptr<ProcessGroupNCCL> ProcessGroupNCCL::CreateProcessGroupNCCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid) {
const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid) {
auto process_group =
std::make_shared<ProcessGroupNCCL>(store, rank, size, gid);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
......
......@@ -22,10 +22,10 @@
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/store/store.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/nccl_tools.h"
......@@ -33,7 +33,7 @@
#ifdef PADDLE_WITH_RCCL
#include "paddle/phi/backends/dynload/rccl.h"
#elif PADDLE_WITH_NCCL
#else
#include "paddle/phi/backends/dynload/nccl.h"
#endif
......@@ -76,9 +76,12 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
public:
static std::shared_ptr<ProcessGroupNCCL> CreateProcessGroupNCCL(
const std::shared_ptr<Store>& store, int rank, int size, int gid);
const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid);
ProcessGroupNCCL(const std::shared_ptr<Store>& store,
ProcessGroupNCCL(const std::shared_ptr<phi::distributed::Store>& store,
int rank,
int size,
int gid);
......@@ -243,7 +246,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
const std::vector<Place>& places);
private:
std::shared_ptr<Store> store_;
std::shared_ptr<phi::distributed::Store> store_;
std::unordered_map<std::string, platform::DeviceEvent>
place_to_calc_event_; // event on calc stream
......
......@@ -34,7 +34,8 @@ register_operators(
${COLLECTIVE_DEPS})
if(WITH_NCCL OR WITH_RCCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper
comm_context_manager nccl_comm_context)
op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
endif()
......
......@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/phi/api/include/tensor.h"
namespace paddle {
......@@ -31,66 +32,52 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto x = ctx.Input<phi::DenseTensor>("X");
auto out = ctx.Output<phi::DenseTensor>("Out");
int numel = x->numel();
ncclDataType_t dtype =
platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*x);
out_tensor.push_back(*out);
auto task = pg->Broadcast(in_tensor, out_tensor);
task->Wait();
return;
}
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
const auto& place = ctx.GetPlace();
ctx.device_context().Alloc<T>(out);
int root = ctx.Attr<int>("root");
if (root == comm->rank()) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
reinterpret_cast<void*>(const_cast<T*>(x->data<T>())),
numel,
dtype,
root,
comm->comm(),
stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
<< x->numel();
if (out != x) {
framework::TensorCopy(
*static_cast<const phi::DenseTensor*>(x),
place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<phi::DenseTensor*>(out));
}
gpuStream_t stream = ctx.cuda_device_context().stream();
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (comm_context_manager.Has(rid)) {
auto* comm_context = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(rid));
comm_context->Broadcast(out, *x, root, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclBcast(out->mutable_data<T>(place),
numel,
dtype,
root,
comm->comm(),
stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received "
<< phi::product(out->dims());
// NOTE(liyurui): This will be removed after moving this operator to phi.
int numel = x->numel();
ncclDataType_t dtype =
platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
if (root == comm->rank()) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
reinterpret_cast<void*>(const_cast<T*>(x->data<T>())),
numel,
dtype,
root,
comm->comm(),
stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
<< x->numel();
if (out != x) {
framework::TensorCopy(
*static_cast<const phi::DenseTensor*>(x),
place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<phi::DenseTensor*>(out));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
out->data<T>(), numel, dtype, root, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received "
<< phi::product(out->dims());
}
}
out->Resize(x->dims());
out->set_lod(x->lod());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
......
......@@ -37,6 +37,7 @@ set(PYBIND_DEPS
global_utils
phi_utils
tcp_store
comm_context_manager
new_profiler
auto_parallel
jit_layer
......
......@@ -21,49 +21,64 @@ limitations under the License. */
#include <pybind11/stl.h>
#include <chrono>
#include <memory>
#include <string>
#include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
using TCPStore = paddle::distributed::TCPStore;
void BindCommContextManager(py::module *m) {
auto CommContextManager =
py::class_<phi::distributed::CommContextManager,
std::shared_ptr<phi::distributed::CommContextManager>>(
*m, "CommContextManager")
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
.def_static(
"create_nccl_comm_context",
&phi::distributed::CommContextManager::CreateNCCLCommContext,
py::call_guard<py::gil_scoped_release>())
#endif
.def("set_store", &phi::distributed::CommContextManager::SetStore);
}
using TCPStore = phi::distributed::TCPStore;
void BindTCPStore(py::module *m) {
auto Store =
py::class_<distributed::Store, std::shared_ptr<distributed::Store>>(
*m, "Store")
.def(py::init<>())
.def(
"set",
[](distributed::Store &self,
const std::string &key,
const std::string &value) {
std::vector<uint8_t> data(value.begin(), value.end());
self.set(key, data);
},
py::arg("key"),
py::arg("value"),
py::call_guard<py::gil_scoped_release>())
.def(
"get",
[](distributed::Store &self,
const std::string &key) -> py::bytes {
auto data = self.get(key);
return py::bytes(reinterpret_cast<char *>(data.data()),
data.size());
},
py::arg("key"),
py::call_guard<py::gil_scoped_release>())
.def("add",
&distributed::Store::add,
py::call_guard<py::gil_scoped_release>())
.def("wait",
&distributed::Store::wait,
py::call_guard<py::gil_scoped_release>());
auto Store = py::class_<phi::distributed::Store,
std::shared_ptr<phi::distributed::Store>>(*m, "Store")
.def(py::init<>())
.def(
"set",
[](phi::distributed::Store &self,
const std::string &key,
const std::string &value) {
std::vector<uint8_t> data(value.begin(), value.end());
self.set(key, data);
},
py::arg("key"),
py::arg("value"),
py::call_guard<py::gil_scoped_release>())
.def(
"get",
[](phi::distributed::Store &self,
const std::string &key) -> py::bytes {
auto data = self.get(key);
return py::bytes(reinterpret_cast<char *>(data.data()),
data.size());
},
py::arg("key"),
py::call_guard<py::gil_scoped_release>())
.def("add",
&phi::distributed::Store::add,
py::call_guard<py::gil_scoped_release>())
.def("wait",
&phi::distributed::Store::wait,
py::call_guard<py::gil_scoped_release>());
py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store)
.def(py::init([](std::string hostname,
......
......@@ -26,6 +26,7 @@ namespace paddle {
namespace pybind {
void BindTCPStore(pybind11::module* m);
void BindCommContextManager(pybind11::module* m);
} // namespace pybind
} // namespace paddle
......@@ -46,7 +46,6 @@ limitations under the License. */
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/distributed/collective/process_group_gloo.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
......
......@@ -1871,6 +1871,7 @@ All parameter, weight, gradient are variables in Paddle.
BindGlobalValueGetterSetter(&m);
BindFleetExecutor(&m);
BindTCPStore(&m);
BindCommContextManager(&m);
BindAutoParallel(&m);
BindJitProperty(&m);
......
# compatible utils used for fluid op system
add_subdirectory(compat)
add_subdirectory(distributed)
if(WITH_GPU)
proto_library(external_error_proto SRCS external_error.proto)
......
add_subdirectory(store)
set(COMM_CONTEXT_MANAGER_DEPS tcp_store)
if(WITH_NCCL OR WITH_RCCL)
cc_library(
nccl_comm_context
SRCS nccl_comm_context.cc
DEPS dense_tensor)
list(APPEND COMM_CONTEXT_MANAGER_DEPS nccl_comm_context)
endif()
cc_library(
comm_context_manager
SRCS comm_context_manager.cc
DEPS ${COMM_CONTEXT_MANAGER_DEPS})
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/macros.h"
namespace phi {
namespace distributed {
class CommContext {
public:
CommContext(int rank, int size) : rank_(rank), size_(size) {}
virtual ~CommContext() = default;
protected:
int rank_;
int size_;
private:
DISABLE_COPY_AND_ASSIGN(CommContext);
};
} // namespace distributed
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include <memory>
#include <string>
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/enforce.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
namespace phi {
namespace distributed {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
void CommContextManager::CreateNCCLCommContext(
const std::shared_ptr<Store>& store,
int dev_id,
int ring_id,
int rank,
int size) {
phi::backends::gpu::SetDeviceId(dev_id);
ncclUniqueId nccl_id;
if (rank == 0) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id));
}
std::string unique_key = "NCCLCommContext/" + std::to_string(ring_id);
if (rank == 0) {
std::vector<uint8_t> nccl_id_wrapper(
reinterpret_cast<uint8_t*>(&nccl_id),
reinterpret_cast<uint8_t*>(&nccl_id) + NCCL_UNIQUE_ID_BYTES);
store->set(unique_key, nccl_id_wrapper);
} else {
const auto& nccl_id_wrapper = store->get(unique_key);
std::memcpy(&nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
}
auto nccl_comm_context =
std::make_unique<NCCLCommContext>(rank, size, nccl_id);
auto& comm_context_manager = CommContextManager::GetInstance();
comm_context_manager.SetStore(store);
comm_context_manager.Emplace(ring_id, std::move(nccl_comm_context));
}
#endif
CommContext* CommContextManager::Emplace(
int ring_id, std::unique_ptr<CommContext> comm_context) {
PADDLE_ENFORCE_EQ(
id_to_comm_context_.find(ring_id),
id_to_comm_context_.end(),
errors::AlreadyExists("Ring id %d already exists in the map.", ring_id));
id_to_comm_context_.emplace(ring_id, std::move(comm_context));
return id_to_comm_context_.at(ring_id).get();
}
CommContext* CommContextManager::Get(int ring_id) const {
PADDLE_ENFORCE_NE(
id_to_comm_context_.find(ring_id),
id_to_comm_context_.end(),
errors::NotFound("Can not find ring id %d in map.", ring_id));
return id_to_comm_context_.at(ring_id).get();
}
bool CommContextManager::Has(int ring_id) const {
return id_to_comm_context_.find(ring_id) != id_to_comm_context_.end();
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <memory>
#include <unordered_map>
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/macros.h"
namespace phi {
namespace distributed {
class Store;
class CommContextManager {
public:
CommContextManager() = default;
~CommContextManager() = default;
static CommContextManager& GetInstance() {
static CommContextManager instance;
return instance;
}
void SetStore(const std::shared_ptr<Store>& store) { store_ = store; }
CommContext* Emplace(int ring_id, std::unique_ptr<CommContext> comm_context);
CommContext* Get(int ring_id) const;
bool Has(int ring_id) const;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
static void CreateNCCLCommContext(const std::shared_ptr<Store>& store,
int dev_id,
int ring_id,
int rank,
int size);
#endif
private:
DISABLE_COPY_AND_ASSIGN(CommContextManager);
std::unordered_map<int, std::unique_ptr<CommContext>> id_to_comm_context_;
std::shared_ptr<Store> store_;
};
} // namespace distributed
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
namespace distributed {
NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id)
: CommContext(rank, size) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_));
}
void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
gpuStream_t stream) {
phi::dynload::ncclBroadcast(in_tensor.data(),
out_tensor->data(),
in_tensor.numel(),
ToNCCLDataType(in_tensor.type()),
root,
nccl_comm_,
stream);
}
} // namespace distributed
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/core/distributed/comm_context.h"
#include "paddle/phi/core/macros.h"
#if defined(PADDLE_WITH_RCCL)
#include "paddle/phi/backends/dynload/rccl.h"
#else
#include "paddle/phi/backends/dynload/nccl.h"
#endif
namespace phi {
class DenseTensor;
namespace distributed {
class NCCLCommContext final : public CommContext {
public:
NCCLCommContext(int rank, int size, ncclUniqueId nccl_id);
void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
gpuStream_t stream);
private:
DISABLE_COPY_AND_ASSIGN(NCCLCommContext);
ncclComm_t nccl_comm_;
};
} // namespace distributed
} // namespace phi
cc_library(
tcp_store
SRCS tcp_store.cc tcp_utils.cc socket.cpp
SRCS tcp_store.cc tcp_utils.cc socket.cpp store.cc
DEPS enforce glog)
if(NOT WIN32)
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/store/socket.h"
#include "paddle/phi/core/distributed/store/socket.h"
#ifndef _WIN32
#include <arpa/inet.h>
......@@ -23,7 +23,7 @@
#include <errno.h>
#include <stdio.h>
namespace paddle {
namespace phi {
namespace distributed {
#ifdef _WIN32
......@@ -75,5 +75,5 @@ std::string GetSockName(int fd) {
return std::string(out);
}
}; // namespace distributed
}; // namespace paddle
} // namespace distributed
} // namespace phi
......@@ -16,11 +16,11 @@
#include <string>
namespace paddle {
namespace phi {
namespace distributed {
int GetSockName(int fd, char* out, int out_len);
std::string GetSockName(int fd);
}; // namespace distributed
}; // namespace paddle
} // namespace distributed
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
namespace distributed {
int64_t Store::add(const std::string& key, int64_t value) {
PADDLE_THROW(
errors::InvalidArgument("Implement the add method in the subclass."));
}
std::vector<uint8_t> Store::get(const std::string& key) {
PADDLE_THROW(
errors::InvalidArgument("Implement the get method in the subclass."));
}
void Store::wait(const std::string& key) {
PADDLE_THROW(
errors::InvalidArgument("Implement the wait method in the subclass."));
}
void Store::set(const std::string& key, const std::vector<uint8_t>& value) {
PADDLE_THROW(
errors::InvalidArgument("Implement the set method in the subclass."));
}
} // namespace distributed
} // namespace phi
......@@ -18,9 +18,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/distributed/store/tcp_utils.h"
namespace paddle {
namespace phi {
namespace distributed {
class Store {
......@@ -29,22 +27,10 @@ class Store {
explicit Store(const int timeout) : _timeout(timeout) {}
virtual ~Store() = default;
virtual int64_t add(const std::string& key, int64_t value) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the add method in the subclass."));
}
virtual std::vector<uint8_t> get(const std::string& key) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the add method in the subclass."));
}
virtual void wait(const std::string& key) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the add method in the subclass."));
}
virtual void set(const std::string& key, const std::vector<uint8_t>& value) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Implement the add method in the subclass."));
}
virtual int64_t add(const std::string& key, int64_t value);
virtual std::vector<uint8_t> get(const std::string& key);
virtual void wait(const std::string& key);
virtual void set(const std::string& key, const std::vector<uint8_t>& value);
virtual int timeout() { return _timeout; }
......@@ -53,4 +39,4 @@ class Store {
};
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -12,17 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
#include <chrono>
#include <iostream>
#include <thread>
#include "paddle/fluid/distributed/store/tcp_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/store/tcp_utils.h"
#include "paddle/phi/core/flags.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace detail {
......@@ -90,7 +89,7 @@ void MasterDaemon::_do_get(SocketType socket) {
PADDLE_ENFORCE_NE(
iter,
_store.end(),
platform::errors::InvalidArgument("Key %s not found in TCPStore.", key));
phi::errors::InvalidArgument("Key %s not found in TCPStore.", key));
std::vector<uint8_t> value = iter->second;
tcputils::send_vector<uint8_t>(socket, value);
}
......@@ -100,7 +99,7 @@ void MasterDaemon::InitControlFd() {
PADDLE_ENFORCE_NE(
pipe(_control_fd.data()),
-1,
platform::errors::Fatal("failed to cread control pipe errno:%d", errno));
phi::errors::Fatal("failed to cread control pipe errno:%d", errno));
}
void MasterDaemon::CloseControlFd() {
for (int fd : _control_fd) {
......@@ -112,10 +111,10 @@ void MasterDaemon::CloseControlFd() {
void MasterDaemon::StopByControlFd() {
VLOG(4) << ("begin to run StopByControlFd");
if (_control_fd[1] != -1) {
PADDLE_ENFORCE_NE(::write(_control_fd[1], "\0", 1),
-1,
platform::errors::Fatal(
"failed to write control pipe errno:%d", errno));
PADDLE_ENFORCE_NE(
::write(_control_fd[1], "\0", 1),
-1,
phi::errors::Fatal("failed to write control pipe errno:%d", errno));
// close the write end of the pipe
::close(_control_fd[1]);
_control_fd[1] = -1;
......@@ -125,7 +124,7 @@ void MasterDaemon::StopByControlFd() {
void MasterDaemon::InitControlFd() {
ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL);
PADDLE_ENFORCE(ghStopEvent_,
platform::errors::Fatal("failed to cread control pipe"));
phi::errors::Fatal("failed to cread control pipe"));
}
void MasterDaemon::CloseControlFd() { CloseHandle(ghStopEvent_); }
void MasterDaemon::StopByControlFd() { SetEvent(ghStopEvent_); }
......@@ -231,8 +230,8 @@ void MasterDaemon::run() {
// The control pipe receive shutdown event, and begin to close it.
if (fds[1].revents != 0) {
if (fds[1].revents & ~(POLLIN | POLLHUP)) {
PADDLE_THROW(paddle::platform::errors::Fatal("Undefined event type:%d",
fds[1].revents));
PADDLE_THROW(
phi::errors::Fatal("Undefined event type:%d", fds[1].revents));
}
VLOG(0)
<< "receive shutdown event and so quit from MasterDaemon run loop";
......@@ -312,9 +311,7 @@ TCPStore::TCPStore(std::string host,
: Store(timeout), _is_master(is_master), _num_workers(num_workers) {
_timeout = timeout;
PADDLE_ENFORCE_GT(
timeout,
0,
platform::errors::InvalidArgument("timeout must >= %d", timeout));
timeout, 0, phi::errors::InvalidArgument("timeout must >= %d", timeout));
VLOG(3) << "input timeout" << timeout << ", member timeout:" << _timeout;
if (_is_master) {
......@@ -355,7 +352,7 @@ void TCPStore::waitWorkers() {
PADDLE_ENFORCE_EQ(
completed,
_num_workers,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"TCPStore timeouted and not all workers got ready."));
}
} while (true);
......@@ -398,4 +395,4 @@ void TCPStore::wait(const std::string& key) {
TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; }
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -30,11 +30,11 @@
#include <thread>
#include <unordered_map>
#include "paddle/fluid/distributed/store/socket.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h"
#include "paddle/phi/core/distributed/store/socket.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/distributed/store/tcp_utils.h"
namespace paddle {
namespace phi {
namespace distributed {
enum class ReplyType { WAITING, STOP_WAIT };
......@@ -143,4 +143,4 @@ class TCPStore : public Store {
};
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -12,15 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/store/tcp_utils.h"
#include "paddle/phi/core/distributed/store/tcp_utils.h"
#include <cerrno>
#include <cstring>
#include <thread>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace phi {
namespace distributed {
namespace tcputils {
......@@ -60,7 +58,7 @@ void close_socket(SocketType socket) {
: "");
PADDLE_ENFORCE_EQ(n,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"%s network %s:%s cannot be obtained. Details: %s.",
proto,
host,
......@@ -73,7 +71,7 @@ void close_socket(SocketType socket) {
void free_addr_info(::addrinfo* hint) {
PADDLE_ENFORCE_NOT_NULL(
hint,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The parameter for free_addr_info cannot be null."));
::freeaddrinfo(hint);
}
......@@ -91,14 +89,14 @@ SocketType tcp_connect(const std::string host,
do {
for (::addrinfo* cur = res; cur != nullptr; cur = cur->ai_next) {
sockfd = ::socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);
PADDLE_ENFORCE_GT(sockfd,
0,
platform::errors::InvalidArgument(
"Create socket to connect %s:%s failed. "
"Details: %s. ",
host,
port,
socket_error().message()));
PADDLE_ENFORCE_GT(
sockfd,
0,
phi::errors::InvalidArgument("Create socket to connect %s:%s failed. "
"Details: %s. ",
host,
port,
socket_error().message()));
if (::connect(sockfd, cur->ai_addr, cur->ai_addrlen) == 0) {
retry = false;
......@@ -125,7 +123,7 @@ SocketType tcp_connect(const std::string host,
PADDLE_ENFORCE_GT(sockfd,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Network %s:%s cannot be connected.", host, port));
VLOG(0) << "Successfully connected to " << host << ":" << port;
......@@ -173,7 +171,7 @@ SocketType tcp_listen(const std::string host,
PADDLE_ENFORCE_GT(sockfd,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Bind network on %s:%s failedd.", node, port));
::listen(sockfd, LISTENQ);
......@@ -190,7 +188,7 @@ SocketType tcp_accept(SocketType socket) {
PADDLE_ENFORCE_GT(
new_socket,
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The server failed to accept a new connection. Details: %s.",
socket_error().message()));
#ifndef _WIN32
......@@ -225,4 +223,4 @@ std::string receive_string(SocketType socket) {
} // namespace tcputils
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -26,14 +26,15 @@
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <chrono>
#include <iostream>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/enforce.h"
// Utility functions for TCP socket.
namespace paddle {
namespace phi {
namespace distributed {
#ifdef _WIN32
......@@ -82,8 +83,8 @@ void send_bytes(SocketType socket, const T* buffer, size_t len) {
PADDLE_ENFORCE_GT(
byte_sent,
0,
platform::errors::InvalidArgument("TCP send error. Details: %s.",
socket_error().message()));
phi::errors::InvalidArgument("TCP send error. Details: %s.",
socket_error().message()));
to_send -= byte_sent;
ptr += byte_sent;
}
......@@ -102,8 +103,8 @@ void receive_bytes(SocketType socket, T* buffer, size_t len) {
PADDLE_ENFORCE_GT(
byte_received,
0,
platform::errors::InvalidArgument("TCP receive error. Details: %s.",
socket_error().message()));
phi::errors::InvalidArgument("TCP receive error. Details: %s.",
socket_error().message()));
to_recv -= byte_received;
ptr += byte_received;
......@@ -140,4 +141,4 @@ T receive_value(SocketType socket) {
} // namespace tcputils
} // namespace distributed
} // namespace paddle
} // namespace phi
......@@ -13,14 +13,14 @@
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
#include "paddle/phi/core/distributed/store/tcp_utils.h"
#ifdef _WIN32
#include <windows.h>
#endif
namespace paddle {
namespace phi {
namespace distributed {
TEST(MasterDaemon, init) {
......@@ -48,6 +48,5 @@ TEST(TCPStore, init) {
paddle::errors::Fatal("result of add is not right"));
}
*/
}; // namespace distributed
}; // namespace paddle
} // namespace distributed
} // namespace phi
......@@ -211,4 +211,33 @@ inline int TransToProtoVarType(const DataType& dtype) {
}
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
inline ncclDataType_t ToNCCLDataType(DataType type) {
if (type == DataType::FLOAT32) {
return ncclFloat;
} else if (type == DataType::FLOAT64) {
return ncclDouble;
} else if (type == DataType::INT32) {
return ncclInt;
} else if (type == DataType::INT64) {
return ncclInt64;
} else if (type == DataType::FLOAT16) {
return ncclFloat16;
} else if (type == DataType::UINT8) {
return ncclUint8;
} else if (type == DataType::INT8) {
return ncclInt8;
} else if (type == DataType::BOOL) {
return ncclUint8;
#if NCCL_VERSION_CODE >= 21000
} else if (type == DataType::BFLOAT16) {
return ncclBfloat16;
#endif
} else {
PADDLE_THROW(
errors::Unimplemented("This datatype in nccl is not supported."));
}
}
#endif
} // namespace phi
......@@ -13,6 +13,7 @@
# limitations under the License.
import datetime
import os
import paddle
......@@ -325,3 +326,25 @@ def is_available():
"""
return core.is_compiled_with_dist()
def _init_parallel_env(backend):
master_endpoint = os.getenv("PADDLE_MASTER", None)
if master_endpoint:
master_addr = master_endpoint.split(":")[0]
master_port = int(master_endpoint.split(":")[1])
global_env = _get_global_env()
rank = global_env.rank
world_size = global_env.world_size
dev_id = global_env.device_id
is_master = rank == 0
store = core.TCPStore(
master_addr,
master_port,
is_master,
world_size,
)
if backend == "nccl":
core.CommContextManager.create_nccl_comm_context(
store, dev_id, 0, rank, world_size
)
......@@ -243,6 +243,7 @@ def init_parallel_env():
_set_expected_place(place)
group = None
if backend in _valid_backend_list and in_dygraph_mode():
if _default_group_name in _get_group_map_by_name():
return _get_group_map_by_name()[_default_group_name]
......
......@@ -30,6 +30,14 @@ class TestCollectiveBroadcastAPI(TestDistBase):
"collective_broadcast_api.py", "broadcast", "nccl"
)
def test_broadcast_nccl_with_comm_context(self):
self.check_with_place(
"collective_broadcast_api.py",
"broadcast",
"nccl",
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_broadcast_gloo(self):
self.check_with_place(
"collective_broadcast_api.py", "broadcast", "gloo", "0"
......
......@@ -108,7 +108,10 @@ class TestCollectiveAPIRunnerBase:
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
paddle.distributed.init_parallel_env()
if args["use_comm_context"]:
paddle.distributed.collective._init_parallel_env(args["backend"])
else:
paddle.distributed.init_parallel_env()
if args['backend'] == 'nccl':
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(
......@@ -150,6 +153,7 @@ def runtime_main(test_class, col_type):
args["path_id"] = int(os.getenv("PATH_ID"))
args["static_mode"] = int(os.getenv("STATIC_MODE"))
args["dtype"] = os.getenv("DTYPE")
args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0")))
model.run_trainer(args)
......@@ -162,6 +166,7 @@ class TestDistBase(unittest.TestCase):
self._find_free_port(),
)
self._python_interp = sys.executable
self._master_endpoints = "127.0.0.1:%s" % (self._find_free_port())
self.temp_dir = tempfile.TemporaryDirectory()
......@@ -204,6 +209,7 @@ class TestDistBase(unittest.TestCase):
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w0_ep,
"PADDLE_MASTER": self._master_endpoints,
}
env1 = {
......@@ -212,6 +218,7 @@ class TestDistBase(unittest.TestCase):
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w1_ep,
"PADDLE_MASTER": self._master_endpoints,
}
elif core.is_compiled_with_xpu():
env0 = {
......