Unverified commit 40431e66, authored by Ghost Screaming, committed by GitHub

[Auto Parallel] Upgrade fluid comm operators to be compatible with new comm library (#56088)

Parent commit 77036fff
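Every kernel touched by this change follows the same pattern: when FLAGS_dynamic_static_unified_comm is set, the kernel looks up a phi::distributed::NCCLCommContext in the CommContextManager (keyed by the ring id as a string) and issues the collective through that context; otherwise it falls back to the legacy platform::NCCLCommContext singleton and the raw platform::dynload::ncclXxx calls. The flag is read via PHI_DECLARE_bool and is normally switched on through the environment variable of the same name, as the error messages below note. The sketch that follows only illustrates that selection step and is not code from the patch; the helper name GetCommContextOrNull is hypothetical.

// Minimal sketch of the ring-id lookup used throughout this diff (assumption:
// simplified; the real kernels raise an error instead of returning nullptr
// when the ring id is missing). GetCommContextOrNull is a hypothetical helper.
#include <string>

#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"

PHI_DECLARE_bool(dynamic_static_unified_comm);

phi::distributed::NCCLCommContext* GetCommContextOrNull(int ring_id) {
  if (!FLAGS_dynamic_static_unified_comm) {
    return nullptr;  // caller falls back to platform::NCCLCommContext
  }
  const auto& mgr = phi::distributed::CommContextManager::GetInstance();
  const std::string key = std::to_string(ring_id);
  if (!mgr.Has(key)) {
    return nullptr;
  }
  return static_cast<phi::distributed::NCCLCommContext*>(mgr.Get(key));
}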
@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/alltoall_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
@@ -41,15 +46,44 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"The ring_id (%d) for alltoall op must be non-negative.", ring_id));
auto place = ctx.GetPlace();
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
int nranks = 0;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
nranks = comm_ctx->GetSize();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
stream = comm->stream();
nranks = comm->nranks();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
framework::DDim x_dims = x->dims();
@@ -66,15 +100,29 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
auto recv_buf = out->mutable_data<T>(out_dims, place);
size_t offset = 0;
send_numel /= nranks;
if (comm_ctx) {
comm_ctx->GroupStart();
for (auto i = 0; i < nranks; ++i) {
auto send_buf = distributed::GetPartialTensor(*x, offset, send_numel);
comm_ctx->Send(send_buf, send_numel, i, stream);
auto recv_buf = distributed::GetPartialTensor(*out, offset, send_numel);
comm_ctx->Recv(&recv_buf, send_numel, i, stream);
offset += send_numel;
}
comm_ctx->GroupEnd();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < nranks; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
send_buf + offset, send_numel, dtype, i, comm->comm(), stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
recv_buf + offset, send_numel, dtype, i, comm->comm(), stream));
offset += send_numel;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
#else
PADDLE_THROW(
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
......
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/barrier_op.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
@@ -38,13 +42,45 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
void* recvbuff = out->mutable_data<T>(place);
int rid = ctx.Attr<int>("ring_id");
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
auto comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
auto stream = comm_ctx->GetStream();
ncclRedOp_t nccl_red_type = ncclSum;
comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
platform::GpuStreamSync(stream);
VLOG(3) << "new NCCLCommContext has rid " << rid;
} else {
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
// should ExecutionContext for calc stream.
auto stream = ctx.cuda_device_context().stream();
ncclRedOp_t nccl_red_type = ncclSum;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(sendbuff,
recvbuff,
numel,
dtype,
nccl_red_type,
comm->comm(),
stream));
platform::GpuStreamSync(stream);
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with NCCL."));
......
@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/framework/convert_utils.h"
@@ -50,32 +55,63 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
return;
}
auto place = ctx.GetPlace();
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->mutable_data<T>(place);
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
if (comm_ctx) {
comm_ctx->AllGather(out, *in, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(send_buff,
recv_buff,
send_numel,
static_cast<ncclDataType_t>(dtype),
comm->comm(),
stream));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
......
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL)
@@ -31,6 +32,9 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
@@ -293,16 +297,41 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
return;
}
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
comm = platform::NCCLCommContext::Instance().Get(rid, place);
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should not use global ctx for calc stream.
// auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
// stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
stream = ctx.cuda_device_context().stream();
}
VLOG(10) << "all reduce buffer:" << sendbuff << ", numel:" << numel
<< ", redtype:" << static_cast<int>(red_type)
@@ -332,8 +361,17 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
"Invalid reduce type: %d", red_type));
}
if (comm_ctx) {
comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(sendbuff,
recvbuff,
numel,
dtype,
nccl_red_type,
comm->comm(),
stream));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
......
@@ -15,6 +15,8 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_concat_op.h"
#include <vector>
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
...@@ -23,6 +25,9 @@ limitations under the License. */ ...@@ -23,6 +25,9 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif #endif
namespace paddle { namespace paddle {
...@@ -68,6 +73,12 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> { ...@@ -68,6 +73,12 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
temp_out.mutable_data<T>(temp_out_dims, place); temp_out.mutable_data<T>(temp_out_dims, place);
auto map = distributed::ProcessGroupMapFromGid::getInstance(); auto map = distributed::ProcessGroupMapFromGid::getInstance();
int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
gpuStream_t stream = nullptr;
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
@@ -78,27 +89,55 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
auto task = pg->AllGather(in_tensor, out_tensor);
task->Wait();
} else {
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm "
"True. But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
if (comm_ctx) {
comm_ctx->AllGather(&temp_out, *x, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(send_buff,
recv_buff,
send_numel,
static_cast<ncclDataType_t>(dtype),
comm->comm(),
stream));
}
}
std::vector<phi::DenseTensor> inputs;
......
@@ -19,11 +19,13 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL)
@@ -32,6 +34,9 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
@@ -220,14 +225,40 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
int rid = ctx.Attr<int>("ring_id");
int root = ctx.Attr<int>("root_id");
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
ncclRedOp_t nccl_red_type = ncclSum;
@@ -256,14 +287,18 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
"kRedMax, kRedMin, kRedProd."));
}
if (comm_ctx) {
comm_ctx->Reduce(out, *in, nccl_red_type, root, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(sendbuff,
recvbuff,
numel,
dtype,
nccl_red_type,
root,
comm->comm(),
stream));
}
#else
PADDLE_ENFORCE_EQ(true,
false,
......
@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
@@ -32,10 +37,58 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto out_dims = in->dims();
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
PADDLE_ENFORCE_EQ(out_dims[0] % comm_ctx->GetSize(),
0,
platform::errors::InvalidArgument(
"The input tensor X's "
"dim[0] (%d) should be divisible by nranks(%d)",
out_dims[0],
comm_ctx->GetSize()));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has ring_id " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(out_dims[0] % comm->nranks(),
0,
platform::errors::InvalidArgument(
"The input tensor X's "
"dim[0] (%d) should be divisible by nranks(%d)",
out_dims[0],
comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has ring_id " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
int nranks = comm_ctx ? comm_ctx->GetSize() : comm->nranks();
PADDLE_ENFORCE_EQ(out_dims[0] % nranks,
0,
platform::errors::InvalidArgument(
@@ -52,22 +105,18 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
int dtype =
platform::ToNCCLDataType(framework::TransToProtoVarType(in->dtype()));
if (comm_ctx) {
comm_ctx->ReduceScatter(out, *in, ncclSum, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
send_buff,
recv_buff,
recv_numel,
static_cast<ncclDataType_t>(dtype),
ncclSum,
comm->comm(),
stream));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
......
@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_scatter_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
@@ -37,14 +42,9 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
int root_id = ctx.Attr<int>("root");
int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
PADDLE_ENFORCE_GE(
root_id,
0,
@@ -58,38 +58,95 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
"The ring_id (%d) for c_scatter_op must be non-negative.",
ring_id));
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
PADDLE_ENFORCE_EQ(nranks,
comm_ctx->GetSize(),
platform::errors::InvalidArgument(
"The number of ranks (%d) you set of must "
"be equal to comm_ctx->GetSize() (%d).",
nranks,
comm_ctx->GetSize()));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has ring_id " << ring_id;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
PADDLE_ENFORCE_EQ(nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"The number of ranks (%d) you set of must "
"be equal to comm->nranks (%d).",
nranks,
comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has ring_id " << ring_id;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
framework::DDim x_dims = x->dims();
framework::DDim out_dims(x_dims);
phi::DenseTensor temp;
auto out_ptr = temp.mutable_data<T>(out_dims, place);
if (FLAGS_dynamic_static_unified_comm) {
if (root_id == comm_ctx->GetRank()) {
comm_ctx->Broadcast(
const_cast<phi::DenseTensor*>(x), *x, root_id, stream);
framework::TensorCopy(
*static_cast<const phi::DenseTensor*>(x),
place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<phi::DenseTensor*>(&temp));
} else {
comm_ctx->Broadcast(&temp, temp, root_id, stream);
}
} else {
if (root_id == comm->rank()) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
reinterpret_cast<void*>(const_cast<T*>(x->data<T>())),
numel,
dtype,
root_id,
comm->comm(),
stream));
framework::TensorCopy(
*static_cast<const phi::DenseTensor*>(x),
place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<phi::DenseTensor*>(&temp));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
out_ptr, numel, dtype, root_id, comm->comm(), stream));
}
}
out_dims[0] = out_dims[0] / nranks;
auto start_index = FLAGS_dynamic_static_unified_comm
? out_dims[0] * comm_ctx->GetRank()
: out_dims[0] * comm->rank();
auto end_index = start_index + out_dims[0];
temp = temp.Slice(start_index, end_index);
temp.Resize(out_dims);
......
@@ -20,6 +20,10 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
@@ -36,8 +40,29 @@ class CSyncCommStreamKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto place = ctx.GetPlace();
int ring_id = ctx.Attr<int>("ring_id");
gpuStream_t stream = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
auto comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
platform::GpuStreamSync(stream);
......
@@ -21,6 +21,10 @@ class Scope;
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
@@ -46,15 +50,40 @@ class CWaitCommOp : public framework::OperatorBase {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int ring_id = Attr<int>("ring_id");
gpuStream_t compute_stream =
static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
gpuStream_t comm_stream = nullptr;
gpuEvent_t event = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
phi::distributed::NCCLCommContext* comm_ctx =
static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
comm_stream = comm_ctx->GetStream();
event = comm_ctx->GetComputeEvent();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
event = platform::NCCLCommContext::Instance()
.Get(ring_id, place)
->comm_event();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
// comm_stream-->event-->compute_stream
#ifdef PADDLE_WITH_HIP
......
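c_wait_comm above and c_wait_compute in the next file differ only in direction: the first makes the compute stream wait for the communication stream (comm_stream-->event-->compute_stream), the second is the mirror image. The mechanism behind both is ordinary CUDA event synchronization. Below is a minimal standalone sketch, not taken from the patch; the real ops reuse the comm context's pre-created event instead of creating one per call.

// comm_stream --> event --> compute_stream: everything already queued on
// comm_stream must finish before work queued on compute_stream after the wait.
// c_wait_compute simply swaps the roles of the two streams.
#include <cuda_runtime.h>

void WaitCommOnCompute(cudaStream_t comm_stream, cudaStream_t compute_stream) {
  cudaEvent_t event;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, comm_stream);            // marks the end of comm work
  cudaStreamWaitEvent(compute_stream, event, 0);  // compute stream waits on it
  cudaEventDestroy(event);  // safe: freed once the recorded event completes
}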
@@ -21,6 +21,10 @@ class Scope;
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
@@ -46,16 +50,40 @@ class CWaitComputeOp : public framework::OperatorBase {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int ring_id = Attr<int>("ring_id");
gpuStream_t compute_stream =
static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
gpuStream_t comm_stream = nullptr;
gpuEvent_t event = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
phi::distributed::NCCLCommContext* comm_ctx =
static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
comm_stream = comm_ctx->GetStream();
event = comm_ctx->GetComputeEvent();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
event = platform::NCCLCommContext::Instance()
.Get(ring_id, place)
->compute_event();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
// compute_stream-->event-->comm_stream
#ifdef PADDLE_WITH_HIP
......
@@ -17,6 +17,10 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/distributed/collective/process_group.h"
@@ -30,14 +34,16 @@ namespace operators {
framework::DDim recv_shape_info(const platform::Place &place,
const gpuStream_t &stream,
platform::NCCLComm *comm,
phi::distributed::NCCLCommContext *comm_ctx,
const int &peer,
distributed::ProcessGroup *group) {
if (!group) {
PADDLE_ENFORCE_EQ(
((stream != nullptr && comm != nullptr) || comm_ctx != nullptr),
true,
platform::errors::InvalidArgument(
"NCCLComm and Stream should be provided if use NCCL "
"to send the shape info."));
}
phi::DataType shape_dtype = phi::DataType::INT32;
@@ -50,8 +56,13 @@ framework::DDim recv_shape_info(const platform::Place &place,
gpu_shape_size_tensor.Resize({1});
gpu_shape_size_tensor.mutable_data(place, shape_dtype);
auto *gpu_data = gpu_shape_size_tensor.data<int>();
if (comm_ctx) {
comm_ctx->Recv(&gpu_shape_size_tensor, 1, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
gpu_data, 1, nccl_dtype, peer, comm->comm(), stream));
}
}
// copy the shape size tensor to cpu
@@ -76,8 +87,12 @@ framework::DDim recv_shape_info(const platform::Place &place,
gpu_shape_tensor.Resize({shape_size});
gpu_shape_tensor.mutable_data(place, shape_dtype);
auto *gpu_shape_data = gpu_shape_tensor.data<int>();
if (comm_ctx) {
comm_ctx->Recv(&gpu_shape_tensor, shape_size, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
gpu_shape_data, shape_size, nccl_dtype, peer, comm->comm(), stream));
}
}
// copy the shape tensor to cpu
@@ -139,11 +154,13 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
if (dynamic_shape) {
VLOG(3) << "recv_v2 will use dynamic shape with send_v2 for switch";
framework::DDim new_dim =
recv_shape_info(ctx.GetPlace(),
/* gpuStream_t */ nullptr,
/* NCCLComm* */ nullptr,
/* NCCLCommContext* */ nullptr,
peer,
pg);
out->Resize(new_dim);
out->mutable_data<T>(new_dim, place);
} else {
@@ -154,21 +171,48 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
auto task = pg->Recv(out_tensor, peer);
return;
}
platform::NCCLComm *comm = nullptr;
phi::distributed::NCCLCommContext *comm_ctx = nullptr;
const auto &comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_LT(peer,
comm->nranks(),
platform::errors::InvalidArgument(
"The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
int data_type = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type =
framework::proto::VarType::Type(data_type);
@@ -188,10 +232,14 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
auto out_dims = out->dims();
out->mutable_data<T>(out_dims, place, 0);
auto numel = out->numel();
if (comm_ctx) {
comm_ctx->Recv(out, numel, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
out->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " recv "
<< phi::product(out_dims) << " from " << peer;
}
}
return;
}
@@ -206,6 +254,7 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
framework::DDim new_dim = recv_shape_info(place,
stream,
comm,
comm_ctx,
peer,
/* ProcessGroup* */ nullptr);
out->Resize(new_dim);
@@ -214,10 +263,22 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
} else {
out->mutable_data<T>(out_dims, place);
}
if (comm_ctx) {
comm_ctx->Recv(out, numel, peer, stream);
} else {
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_LT(peer,
comm->nranks(),
platform::errors::InvalidArgument(
"The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm->nranks()));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
out->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " recv "
<< phi::product(out->dims()) << " from " << peer;
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should be compiled with NCCL and "
......
@@ -17,6 +17,10 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/phi/api/include/tensor.h"
@@ -30,14 +34,16 @@ void send_shape_info(const phi::DenseTensor& x,
const platform::Place& place,
const gpuStream_t& stream,
platform::NCCLComm* comm,
phi::distributed::NCCLCommContext* comm_ctx,
const int& peer,
distributed::ProcessGroup* group) {
if (!group) {
PADDLE_ENFORCE_EQ(
((stream != nullptr && comm != nullptr) || comm_ctx != nullptr),
true,
platform::errors::InvalidArgument(
"NCCLComm and Stream should be provided if use NCCL "
"to send the shape info."));
}
phi::DataType shape_dtype = phi::DataType::INT32;
ncclDataType_t nccl_dtype =
@@ -63,13 +69,17 @@ void send_shape_info(const phi::DenseTensor& x,
gpu_shape_size_tensor->mutable_data(place, shape_dtype);
framework::TensorCopySync(
cpu_shape_size_tensor, place, gpu_shape_size_tensor);
if (comm_ctx) {
comm_ctx->Send(*gpu_shape_size_tensor, 1, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclSend(gpu_shape_size_tensor->data<int>(),
1,
nccl_dtype,
peer,
comm->comm(),
stream));
}
}
VLOG(3) << "send the shape size: " << shape_size << " to peer";
@@ -92,13 +102,17 @@ void send_shape_info(const phi::DenseTensor& x,
gpu_shape_tensor->Resize({shape_size});
gpu_shape_tensor->mutable_data(place, shape_dtype);
framework::TensorCopySync(cpu_shape_tensor, place, gpu_shape_tensor);
if (comm_ctx) {
comm_ctx->Send(*gpu_shape_tensor, shape_size, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclSend(gpu_shape_tensor->data<int>(),
shape_size,
nccl_dtype,
peer,
comm->comm(),
stream));
}
}
VLOG(3) << "send the shape: (" << dims << ") to peer";
}
@@ -137,6 +151,7 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
ctx.GetPlace(),
/* gpuStream_t */ nullptr,
/* NCCLComm* */ nullptr,
/* NCCLCommContext * */ nullptr,
peer,
pg);
}
@@ -148,20 +163,47 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
}
gpuStream_t stream = nullptr;
auto place = ctx.GetPlace();
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_LT(peer,
comm->nranks(),
platform::errors::InvalidArgument(
"The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) { if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream. // should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream(); stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
} }
PADDLE_ENFORCE_LT(
peer,
comm->nranks(),
platform::errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm->nranks()));
auto* x_var = ctx.InputVar("X"); auto* x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensorArray>()) { if (x_var->IsType<framework::LoDTensorArray>()) {
@@ -177,8 +219,12 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
        int numel = x.numel();
        ncclDataType_t dtype =
            platform::ToNCCLDataType(framework::TransToProtoVarType(x.dtype()));
        if (comm_ctx) {
          comm_ctx->Send(x, numel, peer, stream);
        } else {
          PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
              x.data<T>(), numel, dtype, peer, comm->comm(), stream));
        }
        VLOG(3) << "rank " << comm->rank() << " send " << phi::product(x.dims())
                << " to " << peer;
      }
@@ -193,16 +239,21 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
          place,
          stream,
          comm,
          comm_ctx,
          peer,
          /* ProcessGroup* */ nullptr);
    }
    if (comm_ctx) {
      comm_ctx->Send(*x, numel, peer, stream);
    } else {
      ncclDataType_t dtype =
          platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
          x->data<T>(), numel, dtype, peer, comm->comm(), stream));
      VLOG(3) << "rank " << comm->rank() << " send " << phi::product(x->dims())
              << " to " << peer;
    }
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "PaddlePaddle should be compiled with NCCL "
......
@@ -30,6 +30,7 @@
#endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
@@ -54,7 +55,8 @@ void CommContextManager::CreateNCCLCommContext(
    const std::shared_ptr<Store>& store,
    const std::string& unique_comm_key,
    int rank,
    int size,
    const std::string& hash_key) {
  auto& comm_context_manager = CommContextManager::GetInstance();
  if (comm_context_manager.Has(unique_comm_key)) {
    return;
@@ -64,7 +66,7 @@ void CommContextManager::CreateNCCLCommContext(
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id));
  }
  std::string unique_key = "NCCLCommContext/" + unique_comm_key + hash_key;
  if (rank == 0) {
    std::vector<uint8_t> nccl_id_wrapper(
        reinterpret_cast<uint8_t*>(&nccl_id),
@@ -77,7 +79,6 @@ void CommContextManager::CreateNCCLCommContext(
  auto nccl_comm_context =
      std::make_unique<NCCLCommContext>(rank, size, nccl_id);
  if (CommContextManager::device_id != -1) {
    std::unique_ptr<phi::GPUContext> dev_ctx(
        new phi::GPUContext(phi::GPUPlace(CommContextManager::device_id)));
......
@@ -52,7 +52,8 @@ class CommContextManager {
  static void CreateNCCLCommContext(const std::shared_ptr<Store>& store,
                                    const std::string& unique_comm_key,
                                    int rank,
                                    int size,
                                    const std::string& hash_key = "");
#endif
#if defined(PADDLE_WITH_GLOO)
......
@@ -1328,3 +1328,19 @@ PHI_DEFINE_EXPORTED_int64(host_trace_level,
                          1,
                          "RecordEvent will works "
                          "if host_trace_level >= level.");
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
/**
* Communication library related FLAG
* Name: FLAGS_dynamic_static_unified_comm
* Since Version: 2.5
* Value Range: bool, default=false
* Example:
* Note: Whether to use new communication library in auto parallel and static
* mode. If true, it will use unified CommContextManager for communication.
*/
PHI_DEFINE_EXPORTED_bool(dynamic_static_unified_comm,
false,
"Whether to use new communication library in auto "
"parallel and static mode.");
#endif // FLAGS_dynamic_static_unified_comm
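As a quick reference for how this flag is consumed, here is a minimal sketch of opting a static-graph process into the unified communication path; it assumes the flag is injected through the environment before paddle is imported, which is how the tests further down pass it via need_envs:

```python
import os

# Assumption: FLAGS_dynamic_static_unified_comm is picked up from the
# environment at startup, mirroring how the tests below inject it.
os.environ["FLAGS_dynamic_static_unified_comm"] = "1"

import paddle

paddle.enable_static()
# With the flag on, the upgraded collective ops above resolve their ring_id
# through the unified CommContextManager instead of the old NCCLCommContext.
```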
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
import hashlib
import os
from collections import OrderedDict
@@ -162,12 +163,21 @@ class ProcessGroup:
                )
                if use_new_comm in ["1", "True", "true"]:
                    store = core.create_or_get_global_tcp_store()
                    endpoints_str = ""
                    for endpoint in strategy.trainer_endpoints:
                        endpoints_str += endpoint
                    endpoints_str += f"ring_id:{ring_id}"
                    endpoints_str_hash = hashlib.md5(
                        endpoints_str.encode(encoding='UTF-8')
                    ).hexdigest()
                    core.CommContextManager.set_device_id(genv.device_id)
                    core.CommContextManager.create_nccl_comm_context(
                        store,
                        str(ring_id),
                        strategy.local_rank,
                        strategy.nranks,
                        endpoints_str_hash,
                    )
                else:
                    core.NCCLParallelContext(strategy, place).init_with_ring_id(
......
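The hash appended to the store key is simply the MD5 digest of the concatenated trainer endpoints plus the ring id, so two groups that reuse the same ring_id with different endpoint sets no longer collide in the TCP store. A standalone sketch of the derivation (the helper name and endpoint values are made up for illustration):

```python
import hashlib

def comm_store_hash(trainer_endpoints, ring_id):
    # Same recipe as the ProcessGroup change above: concatenate every
    # endpoint, append the ring id, then take the MD5 hex digest.
    endpoints_str = "".join(trainer_endpoints) + f"ring_id:{ring_id}"
    return hashlib.md5(endpoints_str.encode("UTF-8")).hexdigest()

# Hypothetical two-rank job:
print(comm_store_hash(["127.0.0.1:6170", "127.0.0.1:6171"], 0))
# On the C++ side the digest is appended to "NCCLCommContext/<ring_id>".
```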
@@ -13,6 +13,7 @@
# limitations under the License.
import datetime
import hashlib
import paddle
@@ -330,9 +331,16 @@ def _init_parallel_env(backend):
            store, "0", rank, world_size
        )
    elif backend == "nccl":
        endpoints_str = ""
        for endpoint in global_env.trainer_endpoints:
            endpoints_str += endpoint
        endpoints_str += "ring_id:{}".format("0")
        endpoints_str_hash = hashlib.md5(
            endpoints_str.encode(encoding='UTF-8')
        ).hexdigest()
        core.CommContextManager.set_device_id(dev_id)
        core.CommContextManager.create_nccl_comm_context(
            store, "0", rank, world_size, endpoints_str_hash
        )
    elif backend == "xccl":
        dev_type = global_env.device_type
......
@@ -81,10 +81,10 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
    "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
  if((NOT WITH_ROCM) AND ((${CUDA_ARCH_NAME}) STREQUAL "Ampere"))
    set_tests_properties(test_collective_alltoall_api
                         PROPERTIES TIMEOUT "280" LABELS "RUN_TYPE=DIST")
  else()
    set_tests_properties(test_collective_alltoall_api
                         PROPERTIES TIMEOUT "240" LABELS "RUN_TYPE=DIST")
  endif()
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
@@ -286,7 +286,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
    test_collective_sendrecv MODULES test_collective_sendrecv ENVS
    "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
  set_tests_properties(test_collective_sendrecv PROPERTIES TIMEOUT "500" LABELS
                                                           "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
@@ -294,7 +294,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
    test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS
    "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
  set_tests_properties(test_collective_sendrecv_api
                       PROPERTIES TIMEOUT "500" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
  py_test_modules(
......
@@ -109,7 +109,7 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
        if args["use_comm_context"] or args["dynamic_static_unified_comm"]:
            paddle.distributed.collective._init_parallel_env(args["backend"])
        else:
            paddle.distributed.init_parallel_env()
......
@@ -93,6 +93,20 @@ class TestCollectiveAllreduceAPI(TestCollectiveAPIRunnerBase):
            all_reduce_new(tindata, reduce_type)
            return [tindata]

    def get_model_new_comm(
        self,
        main_prog,
        startup_program,
        rank,
        dtype='float32',
    ):
        with base.program_guard(main_prog, startup_program):
            tindata = paddle.static.data(
                name="tindata", shape=[10, 1000], dtype=dtype
            )
            paddle.distributed.all_reduce(tindata)
            return [tindata]


if __name__ == "__main__":
    runtime_main(TestCollectiveAllreduceAPI, "allreduce")
@@ -105,6 +105,9 @@ class TestCollectiveAllreduce(TestCollectiveRunnerBase):
        return toutdata

    def get_model_new_comm(self, main_prog, startup_program, dtype="float32"):
        return self.get_model(main_prog, startup_program)


if __name__ == "__main__":
    runtime_main(TestCollectiveAllreduce, "allreduce", 0)
@@ -126,6 +126,19 @@ class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
            alltoall_new(tindata, tout_data)
            return tout_data

    def get_model_new_comm(
        self, main_prog, startup_program, rank, dtype='float32'
    ):
        with base.program_guard(main_prog, startup_program):
            tindata = paddle.static.data(
                name="tindata", shape=[-1, 10, 1000], dtype=dtype
            )
            tindata.desc.set_need_check_feed(False)
            tindata = paddle.split(tindata, 2, axis=0)
            tout_data = []
            paddle.distributed.alltoall(tindata, tout_data)
            return tout_data


if __name__ == "__main__":
    runtime_main(TestCollectiveAllToAllAPI, "alltoall")
@@ -33,6 +33,13 @@ class TestCollectiveBarrierAPI(TestCollectiveAPIRunnerBase):
            paddle.distributed.barrier()
            return []

    def get_model_new_comm(
        self, main_prog, startup_program, rank, dtype="float32"
    ):
        with base.program_guard(main_prog, startup_program):
            paddle.distributed.barrier()
            return []


if __name__ == "__main__":
    runtime_main(TestCollectiveBarrierAPI, "barrier")
@@ -60,6 +60,39 @@ def concat_new(tensor, group=None):
    return out


def concat_new_comm(tensor, group=None, rank=0):
    op_type = 'c_concat'
    data_feeder.check_variable_and_dtype(
        tensor,
        'tensor',
        [
            'float16',
            'float32',
            'float64',
            'int32',
            'int64',
        ],
        op_type,
    )
    helper = framework.LayerHelper(op_type, **locals())
    ring_id = 0 if group is None else group.id
    nranks = 2
    out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
    helper.append_op(
        type=op_type,
        inputs={'X': [tensor]},
        outputs={'Out': [out]},
        attrs={
            'ring_id': ring_id,
            'nranks': nranks,
            'rank': rank,
        },
    )
    return out


class TestCollectiveConcatAPI(TestCollectiveAPIRunnerBase):
    def __init__(self):
        self.global_ring_id = 0
@@ -78,6 +111,18 @@ class TestCollectiveConcatAPI(TestCollectiveAPIRunnerBase):
            toutdata = concat_new(tindata)
            return [toutdata]

    def get_model_new_comm(
        self, main_prog, startup_program, rank, dtype="float32"
    ):
        with base.program_guard(main_prog, startup_program):
            tindata = paddle.static.data(
                name="tindata", shape=[10, 1000], dtype=dtype
            )
            tindata.desc.set_need_check_feed(False)
            toutdata = concat_new_comm(tindata, rank=rank)
            return [toutdata]


if __name__ == "__main__":
    runtime_main(TestCollectiveConcatAPI, "concat")
@@ -92,6 +92,17 @@ class TestCollectiveReduceAPI(TestCollectiveAPIRunnerBase):
            reduce_new(tindata, dst=0, reduce_type=reduce_type)
            return [tindata]

    def get_model_new_comm(
        self, main_prog, startup_program, rank, dtype='float32'
    ):
        with base.program_guard(main_prog, startup_program):
            tindata = paddle.static.data(
                name="tindata", shape=[-1, 10, 1000], dtype=dtype
            )
            tindata.desc.set_need_check_feed(False)
            paddle.distributed.reduce(tindata, dst=0)
            return [tindata]


if __name__ == "__main__":
    runtime_main(TestCollectiveReduceAPI, "reduce")
@@ -50,6 +50,20 @@ class TestCollectiveReduceScatterAPI(TestCollectiveAPIRunnerBase):
            paddle.distributed.reduce_scatter(toutdata, tindata)
            return [toutdata]

    def get_model_new_comm(
        self, main_prog, startup_program, rank, dtype="float32"
    ):
        with base.program_guard(main_prog, startup_program):
            tindata = paddle.static.data(
                name="tindata", shape=[10, 1000], dtype=dtype
            )
            tindata.desc.set_need_check_feed(False)
            toutdata = paddle.static.data(
                name="toutdata", shape=[5, 1000], dtype=dtype
            )
            paddle.distributed.reduce_scatter(toutdata, tindata)
            return [toutdata]


if __name__ == "__main__":
    runtime_main(TestCollectiveReduceScatterAPI, "reduce_scatter")
@@ -44,6 +44,24 @@ class TestCollectiveScatterAPI(TestCollectiveAPIRunnerBase):
            paddle.distributed.scatter(toutdata, tensor_list, src=1)
            return [toutdata]

    def get_model_new_comm(
        self, main_prog, startup_program, rank, dtype="float32"
    ):
        with base.program_guard(main_prog, startup_program):
            tindata = paddle.static.data(
                name="tindata",
                shape=[10, 1000],
                dtype=dtype,
            )
            toutdata = paddle.tensor.fill_constant(
                shape=[5, 1000], dtype=dtype, value=1.0
            )
            tensor_list = None
            if rank == 1:
                tensor_list = paddle.split(tindata, 2, axis=0)
            paddle.distributed.scatter(toutdata, tensor_list, src=1)
            return [toutdata]


if __name__ == "__main__":
    runtime_main(TestCollectiveScatterAPI, "scatter")
@@ -66,6 +66,26 @@ class TestCollectiveAllgatherAPI(TestDistBase):
                need_envs={"USE_COMM_CONTEXT": "1"},
            )

    def test_allgather_nccl_with_new_comm(self):
        dtypes_to_test = [
            "float16",
            "float32",
            "float64",
            "int32",
            "int64",
            "int8",
            "uint8",
            "bool",
        ]
        for dtype in dtypes_to_test:
            self.check_with_place(
                "collective_allgather_api.py",
                "allgather",
                "nccl",
                dtype=dtype,
                need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
            )

    def test_allgather_gloo(self):
        dtypes_to_test = [
            "float16",
......
@@ -59,6 +59,30 @@ class TestCollectiveAllreduceAPI(TestDistBase):
                need_envs={"USE_COMM_CONTEXT": "1"},
            )

    def test_allreduce_nccl_with_new_comm(self):
        dtypes_to_test = [
            "float16",
            "float32",
            "float64",
            "int32",
            "int64",
        ]
        red_types_to_test = [
            dist.ReduceOp.SUM,
        ]
        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            for red_type in red_types_to_test:
                self.check_with_place(
                    "collective_allreduce_api.py",
                    "allreduce",
                    "nccl",
                    dtype=dtype,
                    reduce_type=red_type,
                    need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
                )

    def test_allreduce_bkcl(self):
        if paddle.base.core.is_compiled_with_xpu():
            self.check_with_place(
......
@@ -43,6 +43,23 @@ class TestCollectiveAllToAllAPI(TestDistBase):
                need_envs={"USE_COMM_CONTEXT": "1"},
            )

    def test_alltoall_nccl_with_new_comm(self):
        dtypes_to_test = [
            "float16",
            "float32",
            "float64",
            "int32",
            "int64",
        ]
        for dtype in dtypes_to_test:
            self.check_with_place(
                "collective_alltoall_api.py",
                "alltoall",
                "nccl",
                dtype=dtype,
                need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
            )

    def test_alltoall_nccl_dygraph(self):
        dtypes_to_test = [
            "float16",
......
@@ -28,6 +28,14 @@ class TestCollectiveBarrierAPI(TestDistBase):
    def test_barrier_nccl(self):
        self.check_with_place("collective_barrier_api.py", "barrier", "nccl")

    def test_barrier_nccl_with_new_comm(self):
        self.check_with_place(
            "collective_barrier_api.py",
            "barrier",
            "nccl",
            need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
        )

    def test_barrier_gloo(self):
        self.check_with_place(
            "collective_barrier_api.py", "barrier", "gloo", "5"
......
@@ -47,6 +47,23 @@ class TestCollectiveConcatAPI(TestDistBase):
                need_envs={"USE_COMM_CONTEXT": "1"},
            )

    def test_concat_with_new_comm(self):
        dtypes_to_test = [
            "float16",
            "float32",
            "float64",
            "int32",
            "int64",
        ]
        for dtype in dtypes_to_test:
            self.check_with_place(
                "collective_concat_api.py",
                "dist_concat",
                "nccl",
                dtype=dtype,
                need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
            )


if __name__ == '__main__':
    unittest.main()
@@ -58,6 +58,29 @@ class TestCollectiveReduceAPI(TestDistBase):
                need_envs={"USE_COMM_CONTEXT": "1"},
            )

    def test_reduce_nccl_with_new_comm(self):
        dtypes_to_test = [
            "float16",
            "float32",
            "float64",
            "int32",
            "int64",
        ]
        red_types_to_test = [
            dist.ReduceOp.SUM,
        ]
        for dtype in dtypes_to_test:
            if paddle.base.core.is_compiled_with_cuda():
                for red_type in red_types_to_test:
                    self.check_with_place(
                        "collective_reduce_api.py",
                        "reduce",
                        "nccl",
                        dtype=dtype,
                        reduce_type=red_type,
                        need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
                    )

    def test_reduce_bkcl(self):
        if paddle.base.core.is_compiled_with_xpu():
            self.check_with_place("collective_reduce_api.py", "reduce", "bkcl")
......
@@ -43,6 +43,25 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
                need_envs={"USE_COMM_CONTEXT": "1"},
            )

    def test_reduce_scatter_nccl_with_new_comm(self):
        dtypes_to_test = [
            "float16",
            "float32",
            "float64",
            "int32",
            "int64",
        ]
        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
                "collective_reduce_scatter_api.py",
                "reduce_scatter",
                "nccl",
                dtype=dtype,
                need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
            )

    def test_reduce_scatter_nccl_dygraph(self):
        dtypes_to_test = [
            "float16",
......
@@ -33,6 +33,23 @@ class TestCollectiveScatterAPI(TestDistBase):
    def test_scatter_nccl(self):
        self.check_with_place("collective_scatter_api.py", "scatter", "nccl")

    def test_scatter_nccl_with_new_comm(self):
        dtypes_to_test = [
            "float16",
            "float32",
            "float64",
            "int32",
            "int64",
        ]
        for dtype in dtypes_to_test:
            self.check_with_place(
                "collective_scatter_api.py",
                "scatter",
                "nccl",
                dtype=dtype,
                need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
            )

    def test_scatter_nccl_dygraph(self):
        dtypes_to_test = [
            "float16",
......
@@ -125,7 +125,7 @@ class TestCollectiveAPIRunnerBase:
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
        if args["use_comm_context"] or args["dynamic_static_unified_comm"]:
            paddle.distributed.collective._init_parallel_env(args["backend"])
        else:
            paddle.distributed.init_parallel_env()
@@ -152,7 +152,13 @@ class TestCollectiveAPIRunnerBase:
                reduce_type=args['reduce_type'],
            )
            if args["use_comm_context"]
            else (
                self.get_model_new_comm(
                    train_prog, startup_prog, rank, dtype=args['dtype']
                )
                if args["dynamic_static_unified_comm"]
                else self.get_model(train_prog, startup_prog, rank)
            )
        )
        exe = base.Executor(place)
        exe.run(startup_prog)
@@ -182,6 +188,9 @@ def runtime_main(test_class, col_type):
    args["dtype"] = os.getenv("DTYPE")
    args["reduce_type"] = os.getenv("REDUCE_TYPE")
    args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0")))
    args["dynamic_static_unified_comm"] = bool(
        int(os.getenv("FLAGS_dynamic_static_unified_comm", "0"))
    )
    model.run_trainer(args)
......
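To summarize the runner's dispatch, here is a condensed sketch of how the environment variables select a graph builder. `pick_model_builder` is purely illustrative, and the builder used on the USE_COMM_CONTEXT path is assumed to be `get_model_new`, since its call site is collapsed out of the hunk above:

```python
import os

def pick_model_builder(runner, args):
    # USE_COMM_CONTEXT takes priority, then the unified-comm flag,
    # then the legacy get_model path.
    if args["use_comm_context"]:
        return runner.get_model_new        # assumed name; call site collapsed above
    if args["dynamic_static_unified_comm"]:
        return runner.get_model_new_comm   # added by this change
    return runner.get_model

# The flags are derived from the environment exactly as runtime_main does:
args = {
    "use_comm_context": bool(int(os.getenv("USE_COMM_CONTEXT", "0"))),
    "dynamic_static_unified_comm": bool(
        int(os.getenv("FLAGS_dynamic_static_unified_comm", "0"))
    ),
}
```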
@@ -27,6 +27,7 @@ import numpy as np
import paddle.base.unique_name as nameGen
from paddle import base
from paddle.base import core
from paddle.distributed.collective import _init_parallel_env


class TestCollectiveRunnerBase:
@@ -110,9 +111,13 @@ class TestCollectiveRunnerBase:
        rank = args["trainerid"]
        current_endpoint = args["currentendpoint"]
        nranks = 2
        if args["dynamic_static_unified_comm"]:
            _init_parallel_env("nccl")
        else:
            self.initCommunicator(
                startup_prog, rank, nranks, True, current_endpoint, endpoints
            )

        self.rank = rank
        result = self.get_model(train_prog, startup_prog)
        device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
@@ -140,6 +145,10 @@ def runtime_main(test_class, col_type, sub_type):
    args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS')
    args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT")
    args["col_type"] = col_type
    args["dtype"] = os.getenv("DTYPE")
    args["dynamic_static_unified_comm"] = bool(
        int(os.getenv("FLAGS_dynamic_static_unified_comm", "0"))
    )
    model.run_trainer(args)
@@ -257,6 +266,8 @@ class TestDistBase(unittest.TestCase):
            "LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
            "GLOG_v": "3",
            "NCCL_P2P_DISABLE": "1",
            "Flags_dynamic_static_unified_comm": "False",
            "DTYPE": "float32",
        }
        required_envs.update(need_envs)
        if check_error_log:
......