Unverified commit 40431e66, authored by Ghost Screaming, committed by GitHub

[Auto Parallel] Upgrade fluid comm operators to be compatible with new comm library (#56088)

Parent 77036fff
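Every fluid collective CUDA kernel touched below follows the same upgrade pattern: keep the legacy `platform::NCCLCommContext` path, and add a branch that resolves a `phi::distributed::NCCLCommContext` from the `CommContextManager` when `FLAGS_dynamic_static_unified_comm` is enabled. The following is a condensed sketch of that pattern assembled from the hunks below; it is a kernel fragment, not a standalone compilable unit, and `ring_id`, `ctx`, `place`, `out`, `in`, and the raw buffer variables are assumed to come from the surrounding kernel.

```cpp
// Sketch: per-kernel dispatch between the new unified comm library and the
// legacy fluid NCCL helper. Exactly one of comm_ctx / comm ends up non-null.
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
    phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
  // New path: the manager owns one NCCLCommContext per ring_id (string key).
  comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
      comm_context_manager.Get(std::to_string(ring_id)));
  stream = comm_ctx->GetStream();
} else {
  // Old path: the fluid singleton keyed by (ring_id, place).
  comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
  stream = comm->stream();
}
if (ctx.Attr<bool>("use_calc_stream")) {
  stream = ctx.cuda_device_context().stream();  // run on the compute stream
}
// Issue the collective: typed helpers on the new context, raw ncclXxx calls
// (guarded by PADDLE_ENFORCE_GPU_SUCCESS) on the old one.
if (comm_ctx) {
  comm_ctx->AllReduce(out, *in, ncclSum, stream);
} else {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
      sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
}
```

The same structure appears in the alltoall, barrier, c_allgather, c_allreduce, c_concat, c_reduce, c_reducescatter, c_scatter, c_sync_comm_stream, c_wait_comm, c_wait_compute, recv_v2, and send_v2 hunks below.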
......@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/alltoall_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
......@@ -41,15 +46,44 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"The ring_id (%d) for alltoall op must be non-negative.", ring_id));
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
int nranks = comm->nranks();
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
int nranks = 0;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
nranks = comm_ctx->GetSize();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
stream = comm->stream();
nranks = comm->nranks();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
framework::DDim x_dims = x->dims();
......@@ -66,6 +100,18 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
auto recv_buf = out->mutable_data<T>(out_dims, place);
size_t offset = 0;
send_numel /= nranks;
if (comm_ctx) {
comm_ctx->GroupStart();
for (auto i = 0; i < nranks; ++i) {
auto send_buf = distributed::GetPartialTensor(*x, offset, send_numel);
comm_ctx->Send(send_buf, send_numel, i, stream);
auto recv_buf = distributed::GetPartialTensor(*out, offset, send_numel);
comm_ctx->Recv(&recv_buf, send_numel, i, stream);
offset += send_numel;
}
comm_ctx->GroupEnd();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < nranks; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
......@@ -75,6 +121,8 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
offset += send_numel;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
#else
PADDLE_THROW(
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
......
......@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/barrier_op.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
......@@ -38,13 +42,45 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
void* recvbuff = out->mutable_data<T>(place);
int rid = ctx.Attr<int>("ring_id");
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
auto comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
auto stream = comm_ctx->GetStream();
ncclRedOp_t nccl_red_type = ncclSum;
comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
platform::GpuStreamSync(stream);
VLOG(3) << "new NCCLCommContext has rid " << rid;
} else {
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
// should ExecutionContext for calc stream.
auto stream = ctx.cuda_device_context().stream();
ncclRedOp_t nccl_red_type = ncclSum;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(sendbuff,
recvbuff,
numel,
dtype,
nccl_red_type,
comm->comm(),
stream));
platform::GpuStreamSync(stream);
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should compile with NCCL."));
......
......@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/framework/convert_utils.h"
......@@ -50,25 +55,54 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
return;
}
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->mutable_data<T>(place);
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
if (comm_ctx) {
comm_ctx->AllGather(out, *in, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(send_buff,
recv_buff,
......@@ -76,6 +110,8 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
static_cast<ncclDataType_t>(dtype),
comm->comm(),
stream));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
......
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
......@@ -31,6 +32,9 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
......@@ -293,16 +297,41 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
return;
}
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
comm = platform::NCCLCommContext::Instance().Get(rid, place);
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should not use global ctx for calc stream.
// auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
// stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
VLOG(10) << "all reduce buffer:" << sendbuff << ", numel:" << numel
<< ", redtype:" << static_cast<int>(red_type)
......@@ -332,8 +361,17 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
"Invalid reduce type: %d", red_type));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
if (comm_ctx) {
comm_ctx->AllReduce(out, *in, nccl_red_type, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(sendbuff,
recvbuff,
numel,
dtype,
nccl_red_type,
comm->comm(),
stream));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
......
......@@ -15,6 +15,8 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_concat_op.h"
#include <vector>
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/api/include/tensor.h"
......@@ -23,6 +25,9 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
......@@ -68,6 +73,12 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
temp_out.mutable_data<T>(temp_out_dims, place);
auto map = distributed::ProcessGroupMapFromGid::getInstance();
int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
gpuStream_t stream = nullptr;
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
......@@ -78,20 +89,47 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
auto task = pg->AllGather(in_tensor, out_tensor);
task->Wait();
} else {
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm "
"True. But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(
nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"nranks: %s should equal to %s", nranks, comm->nranks()));
int64_t send_numel = x->numel();
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
gpuStream_t stream = nullptr;
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
if (comm_ctx) {
comm_ctx->AllGather(&temp_out, *x, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(send_buff,
recv_buff,
......@@ -100,6 +138,7 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
comm->comm(),
stream));
}
}
std::vector<phi::DenseTensor> inputs;
int axis = x->dims().size() - 1;
......
......@@ -19,11 +19,13 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
......@@ -32,6 +34,9 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
......@@ -220,14 +225,40 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
int rid = ctx.Attr<int>("ring_id");
int root = ctx.Attr<int>("root_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
ncclRedOp_t nccl_red_type = ncclSum;
......@@ -256,6 +287,9 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
"kRedMax, kRedMin, kRedProd."));
}
if (comm_ctx) {
comm_ctx->Reduce(out, *in, nccl_red_type, root, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce(sendbuff,
recvbuff,
numel,
......@@ -264,6 +298,7 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
root,
comm->comm(),
stream));
}
#else
PADDLE_ENFORCE_EQ(true,
false,
......
......@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
......@@ -32,10 +37,58 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
int rid = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
int nranks = comm->nranks();
auto out_dims = in->dims();
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
PADDLE_ENFORCE_EQ(out_dims[0] % comm_ctx->GetSize(),
0,
platform::errors::InvalidArgument(
"The input tensor X's "
"dim[0] (%d) should be divisible by nranks(%d)",
out_dims[0],
comm_ctx->GetSize()));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has ring_id " << rid;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_EQ(out_dims[0] % comm->nranks(),
0,
platform::errors::InvalidArgument(
"The input tensor X's "
"dim[0] (%d) should be divisible by nranks(%d)",
out_dims[0],
comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has ring_id " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
int nranks = comm_ctx ? comm_ctx->GetSize() : comm->nranks();
PADDLE_ENFORCE_EQ(out_dims[0] % nranks,
0,
platform::errors::InvalidArgument(
......@@ -52,22 +105,18 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
int dtype =
platform::ToNCCLDataType(framework::TransToProtoVarType(in->dtype()));
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
if (comm_ctx) {
comm_ctx->ReduceScatter(out, *in, ncclSum, stream);
} else {
stream = comm->stream();
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclReduceScatter(send_buff,
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
send_buff,
recv_buff,
recv_numel,
static_cast<ncclDataType_t>(dtype),
ncclSum,
comm->comm(),
stream));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
......
......@@ -13,10 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_scatter_op.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
......@@ -37,14 +42,9 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
int root_id = ctx.Attr<int>("root");
int ring_id = ctx.Attr<int>("ring_id");
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
PADDLE_ENFORCE_EQ(nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"The number of ranks (%d) you set of must "
"be equal to comm->nranks (%d).",
nranks,
comm->nranks()));
gpuStream_t stream = nullptr;
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
PADDLE_ENFORCE_GE(
root_id,
0,
......@@ -58,18 +58,71 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
"The ring_id (%d) for c_scatter_op must be non-negative.",
ring_id));
gpuStream_t stream = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
PADDLE_ENFORCE_EQ(nranks,
comm_ctx->GetSize(),
platform::errors::InvalidArgument(
"The number of ranks (%d) you set of must "
"be equal to comm_ctx->GetSize() (%d).",
nranks,
comm_ctx->GetSize()));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has ring_id " << ring_id;
} else { // old comm_context
comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
PADDLE_ENFORCE_EQ(nranks,
comm->nranks(),
platform::errors::InvalidArgument(
"The number of ranks (%d) you set of must "
"be equal to comm->nranks (%d).",
nranks,
comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has ring_id " << ring_id;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
framework::DDim x_dims = x->dims();
framework::DDim out_dims(x_dims);
phi::DenseTensor temp;
auto out_ptr = temp.mutable_data<T>(out_dims, place);
if (FLAGS_dynamic_static_unified_comm) {
if (root_id == comm_ctx->GetRank()) {
comm_ctx->Broadcast(
const_cast<phi::DenseTensor*>(x), *x, root_id, stream);
framework::TensorCopy(
*static_cast<const phi::DenseTensor*>(x),
place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<phi::DenseTensor*>(&temp));
} else {
comm_ctx->Broadcast(&temp, temp, root_id, stream);
}
} else {
if (root_id == comm->rank()) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
reinterpret_cast<void*>(const_cast<T*>(x->data<T>())),
......@@ -79,7 +132,8 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
comm->comm(),
stream));
framework::TensorCopy(*static_cast<const phi::DenseTensor*>(x),
framework::TensorCopy(
*static_cast<const phi::DenseTensor*>(x),
place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<phi::DenseTensor*>(&temp));
......@@ -87,9 +141,12 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
out_ptr, numel, dtype, root_id, comm->comm(), stream));
}
}
out_dims[0] = out_dims[0] / nranks;
auto start_index = out_dims[0] * comm->rank();
auto start_index = FLAGS_dynamic_static_unified_comm
? out_dims[0] * comm_ctx->GetRank()
: out_dims[0] * comm->rank();
auto end_index = start_index + out_dims[0];
temp = temp.Slice(start_index, end_index);
temp.Resize(out_dims);
......
......@@ -20,6 +20,10 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#if defined(PADDLE_WITH_XPU_BKCL)
......@@ -36,8 +40,29 @@ class CSyncCommStreamKernel : public framework::OpKernel<T> {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto place = ctx.GetPlace();
int ring_id = ctx.Attr<int>("ring_id");
auto stream =
gpuStream_t stream = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
auto comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
platform::GpuStreamSync(stream);
......
......@@ -21,6 +21,10 @@ class Scope;
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
......@@ -46,15 +50,40 @@ class CWaitCommOp : public framework::OperatorBase {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int ring_id = Attr<int>("ring_id");
auto compute_stream =
gpuStream_t compute_stream =
static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto comm_stream =
gpuStream_t comm_stream = nullptr;
gpuEvent_t event = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
phi::distributed::NCCLCommContext* comm_ctx =
static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
comm_stream = comm_ctx->GetStream();
event = comm_ctx->GetComputeEvent();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
auto event =
platform::NCCLCommContext::Instance().Get(ring_id, place)->comm_event();
event = platform::NCCLCommContext::Instance()
.Get(ring_id, place)
->comm_event();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
// comm_stream-->event-->compute_stream
#ifdef PADDLE_WITH_HIP
......
......@@ -21,6 +21,10 @@ class Scope;
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
namespace paddle {
......@@ -46,16 +50,40 @@ class CWaitComputeOp : public framework::OperatorBase {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int ring_id = Attr<int>("ring_id");
auto compute_stream =
gpuStream_t compute_stream =
static_cast<phi::GPUContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto comm_stream =
gpuStream_t comm_stream = nullptr;
gpuEvent_t event = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
phi::distributed::NCCLCommContext* comm_ctx =
static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
comm_stream = comm_ctx->GetStream();
event = comm_ctx->GetComputeEvent();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
auto event = platform::NCCLCommContext::Instance()
event = platform::NCCLCommContext::Instance()
.Get(ring_id, place)
->compute_event();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}
// compute_stream-->event-->comm_stream
#ifdef PADDLE_WITH_HIP
......
......@@ -17,6 +17,10 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/distributed/collective/process_group.h"
......@@ -30,10 +34,12 @@ namespace operators {
framework::DDim recv_shape_info(const platform::Place &place,
const gpuStream_t &stream,
platform::NCCLComm *comm,
phi::distributed::NCCLCommContext *comm_ctx,
const int &peer,
distributed::ProcessGroup *group) {
if (!group) {
PADDLE_ENFORCE_EQ((stream != nullptr && comm != nullptr),
PADDLE_ENFORCE_EQ(
((stream != nullptr && comm != nullptr) || comm_ctx != nullptr),
true,
platform::errors::InvalidArgument(
"NCCLComm and Stream should be provided if use NCCL "
......@@ -50,9 +56,14 @@ framework::DDim recv_shape_info(const platform::Place &place,
gpu_shape_size_tensor.Resize({1});
gpu_shape_size_tensor.mutable_data(place, shape_dtype);
auto *gpu_data = gpu_shape_size_tensor.data<int>();
if (comm_ctx) {
comm_ctx->Recv(&gpu_shape_size_tensor, 1, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
gpu_data, 1, nccl_dtype, peer, comm->comm(), stream));
}
}
// copy the shape size tensor to cpu
phi::DenseTensor *cpu_shape_size_tensor = new phi::DenseTensor(shape_dtype);
......@@ -76,9 +87,13 @@ framework::DDim recv_shape_info(const platform::Place &place,
gpu_shape_tensor.Resize({shape_size});
gpu_shape_tensor.mutable_data(place, shape_dtype);
auto *gpu_shape_data = gpu_shape_tensor.data<int>();
if (comm_ctx) {
comm_ctx->Recv(&gpu_shape_tensor, shape_size, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
gpu_shape_data, shape_size, nccl_dtype, peer, comm->comm(), stream));
}
}
// copy the shape tensor to cpu
phi::DenseTensor *cpu_shape_tensor = new phi::DenseTensor(shape_dtype);
......@@ -139,9 +154,11 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
if (dynamic_shape) {
VLOG(3) << "recv_v2 will use dynamic shape with send_v2 for switch";
framework::DDim new_dim = recv_shape_info(ctx.GetPlace(),
framework::DDim new_dim =
recv_shape_info(ctx.GetPlace(),
/* gpuStream_t */ nullptr,
/* NCCLComm* */ nullptr,
/* NCCLCommContext* */ nullptr,
peer,
pg);
out->Resize(new_dim);
......@@ -154,21 +171,48 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
auto task = pg->Recv(out_tensor, peer);
return;
}
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
platform::NCCLComm *comm = nullptr;
phi::distributed::NCCLCommContext *comm_ctx = nullptr;
const auto &comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
stream = comm->stream();
}
PADDLE_ENFORCE_LT(
peer,
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_LT(peer,
comm->nranks(),
platform::errors::InvalidArgument("The value of peer (%d) you set must "
platform::errors::InvalidArgument(
"The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
int data_type = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type =
framework::proto::VarType::Type(data_type);
......@@ -188,10 +232,14 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
auto out_dims = out->dims();
out->mutable_data<T>(out_dims, place, 0);
auto numel = out->numel();
if (comm_ctx) {
comm_ctx->Recv(out, numel, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
out->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " recv " << phi::product(out_dims)
<< " from " << peer;
VLOG(3) << "rank " << comm->rank() << " recv "
<< phi::product(out_dims) << " from " << peer;
}
}
return;
}
......@@ -206,6 +254,7 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
framework::DDim new_dim = recv_shape_info(place,
stream,
comm,
comm_ctx,
peer,
/* ProcessGroup* */ nullptr);
out->Resize(new_dim);
......@@ -214,10 +263,22 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
} else {
out->mutable_data<T>(out_dims, place);
}
if (comm_ctx) {
comm_ctx->Recv(out, numel, peer, stream);
} else {
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_LT(peer,
comm->nranks(),
platform::errors::InvalidArgument(
"The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm->nranks()));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
out->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " recv " << phi::product(out->dims())
<< " from " << peer;
VLOG(3) << "rank " << comm->rank() << " recv "
<< phi::product(out->dims()) << " from " << peer;
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should be compiled with NCCL and "
......
......@@ -17,6 +17,10 @@ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);
#endif
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/phi/api/include/tensor.h"
......@@ -30,10 +34,12 @@ void send_shape_info(const phi::DenseTensor& x,
const platform::Place& place,
const gpuStream_t& stream,
platform::NCCLComm* comm,
phi::distributed::NCCLCommContext* comm_ctx,
const int& peer,
distributed::ProcessGroup* group) {
if (!group) {
PADDLE_ENFORCE_EQ((stream != nullptr && comm != nullptr),
PADDLE_ENFORCE_EQ(
((stream != nullptr && comm != nullptr) || comm_ctx != nullptr),
true,
platform::errors::InvalidArgument(
"NCCLComm and Stream should be provided if use NCCL "
......@@ -63,6 +69,9 @@ void send_shape_info(const phi::DenseTensor& x,
gpu_shape_size_tensor->mutable_data(place, shape_dtype);
framework::TensorCopySync(
cpu_shape_size_tensor, place, gpu_shape_size_tensor);
if (comm_ctx) {
comm_ctx->Send(*gpu_shape_size_tensor, 1, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclSend(gpu_shape_size_tensor->data<int>(),
1,
......@@ -71,6 +80,7 @@ void send_shape_info(const phi::DenseTensor& x,
comm->comm(),
stream));
}
}
VLOG(3) << "send the shape size: " << shape_size << " to peer";
// step2: send the shape
......@@ -92,6 +102,9 @@ void send_shape_info(const phi::DenseTensor& x,
gpu_shape_tensor->Resize({shape_size});
gpu_shape_tensor->mutable_data(place, shape_dtype);
framework::TensorCopySync(cpu_shape_tensor, place, gpu_shape_tensor);
if (comm_ctx) {
comm_ctx->Send(*gpu_shape_tensor, shape_size, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclSend(gpu_shape_tensor->data<int>(),
shape_size,
......@@ -100,6 +113,7 @@ void send_shape_info(const phi::DenseTensor& x,
comm->comm(),
stream));
}
}
VLOG(3) << "send the shape: (" << dims << ") to peer";
}
#endif
......@@ -137,6 +151,7 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
ctx.GetPlace(),
/* gpuStream_t */ nullptr,
/* NCCLComm* */ nullptr,
/* NCCLCommContext * */ nullptr,
peer,
pg);
}
......@@ -148,20 +163,47 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
}
gpuStream_t stream = nullptr;
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
platform::NCCLComm* comm = nullptr;
phi::distributed::NCCLCommContext* comm_ctx = nullptr;
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
platform::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
platform::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
stream = comm->stream();
}
PADDLE_ENFORCE_LT(
peer,
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_LT(peer,
comm->nranks(),
platform::errors::InvalidArgument("The value of peer (%d) you set must "
platform::errors::InvalidArgument(
"The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
}
auto* x_var = ctx.InputVar("X");
if (x_var->IsType<framework::LoDTensorArray>()) {
......@@ -177,8 +219,12 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
int numel = x.numel();
ncclDataType_t dtype =
platform::ToNCCLDataType(framework::TransToProtoVarType(x.dtype()));
if (comm_ctx) {
comm_ctx->Send(x, numel, peer, stream);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
x.data<T>(), numel, dtype, peer, comm->comm(), stream));
}
VLOG(3) << "rank " << comm->rank() << " send " << phi::product(x.dims())
<< " to " << peer;
}
......@@ -193,16 +239,21 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
place,
stream,
comm,
comm_ctx,
peer,
/* ProcessGroup* */ nullptr);
}
if (comm_ctx) {
comm_ctx->Send(*x, numel, peer, stream);
} else {
ncclDataType_t dtype =
platform::ToNCCLDataType(framework::TransToProtoVarType(x->dtype()));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
x->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " send " << phi::product(x->dims())
<< " to " << peer;
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"PaddlePaddle should be compiled with NCCL "
......
......@@ -30,6 +30,7 @@
#endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
......@@ -54,7 +55,8 @@ void CommContextManager::CreateNCCLCommContext(
const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
int rank,
int size) {
int size,
const std::string& hash_key) {
auto& comm_context_manager = CommContextManager::GetInstance();
if (comm_context_manager.Has(unique_comm_key)) {
return;
......@@ -64,7 +66,7 @@ void CommContextManager::CreateNCCLCommContext(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id));
}
std::string unique_key = "NCCLCommContext/" + unique_comm_key;
std::string unique_key = "NCCLCommContext/" + unique_comm_key + hash_key;
if (rank == 0) {
std::vector<uint8_t> nccl_id_wrapper(
reinterpret_cast<uint8_t*>(&nccl_id),
......@@ -77,7 +79,6 @@ void CommContextManager::CreateNCCLCommContext(
auto nccl_comm_context =
std::make_unique<NCCLCommContext>(rank, size, nccl_id);
if (CommContextManager::device_id != -1) {
std::unique_ptr<phi::GPUContext> dev_ctx(
new phi::GPUContext(phi::GPUPlace(CommContextManager::device_id)));
......
......@@ -52,7 +52,8 @@ class CommContextManager {
static void CreateNCCLCommContext(const std::shared_ptr<Store>& store,
const std::string& unique_comm_key,
int rank,
int size);
int size,
const std::string& hash_key = "");
#endif
#if defined(PADDLE_WITH_GLOO)
......
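`CreateNCCLCommContext` gains an optional `hash_key` that is appended to the bootstrap store key (see the `.cc` hunk above); the Python callers further down pass an MD5 of the trainer endpoints plus the ring id. A minimal sketch of the resulting key, assuming only what the hunks show (`BuildStoreKey` is an illustrative helper, not a real Paddle function):

```cpp
#include <string>

// Illustrative only: how the TCP-store key is composed after this change.
std::string BuildStoreKey(const std::string& unique_comm_key,
                          const std::string& hash_key = "") {
  // Before: "NCCLCommContext/" + unique_comm_key
  // After:  an optional hash_key (e.g. md5(endpoints + "ring_id:<id>") built
  //         on the Python side) is appended, presumably so that groups that
  //         reuse a ring id across different endpoint sets get distinct keys.
  return "NCCLCommContext/" + unique_comm_key + hash_key;
}
```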
......@@ -1328,3 +1328,19 @@ PHI_DEFINE_EXPORTED_int64(host_trace_level,
1,
"RecordEvent will works "
"if host_trace_level >= level.");
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
/**
* Communication library related FLAG
* Name: FLAGS_dynamic_static_unified_comm
* Since Version: 2.5
* Value Range: bool, default=false
* Example:
* Note: Whether to use new communication library in auto parallel and static
* mode. If true, it will use unified CommContextManager for communication.
*/
PHI_DEFINE_EXPORTED_bool(dynamic_static_unified_comm,
false,
"Whether to use new communication library in auto "
"parallel and static mode.");
#endif // FLAGS_dynamic_static_unified_comm
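The flag is defined once here and pulled into each operator file, as the hunks above show. A minimal sketch of that plumbing, keeping only declarations taken from this diff (everything else, including how the flag would be branched on, is covered by the kernel sketch near the top of the change):

```cpp
// paddle/phi/core/flags.cc (this change): define and export the switch,
// default off.
PHI_DEFINE_EXPORTED_bool(dynamic_static_unified_comm,
                         false,
                         "Whether to use new communication library in auto "
                         "parallel and static mode.");

// Any operator .cc/.cu file that consumes it: include the flags header and
// declare the flag, then test FLAGS_dynamic_static_unified_comm at runtime.
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(dynamic_static_unified_comm);

// The tests in this PR enable the flag by exporting
// FLAGS_dynamic_static_unified_comm=1 in the trainer environment.
```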
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
import hashlib
import os
from collections import OrderedDict
......@@ -162,12 +163,21 @@ class ProcessGroup:
)
if use_new_comm in ["1", "True", "true"]:
store = core.create_or_get_global_tcp_store()
endpoints_str = ""
for endpoint in strategy.trainer_endpoints:
endpoints_str += endpoint
endpoints_str += f"ring_id:{ring_id}"
endpoints_str_hash = hashlib.md5(
endpoints_str.encode(encoding='UTF-8')
).hexdigest()
core.CommContextManager.set_device_id(genv.device_id)
core.CommContextManager.create_nccl_comm_context(
store,
str(ring_id),
strategy.local_rank,
strategy.nranks,
endpoints_str_hash,
)
else:
core.NCCLParallelContext(strategy, place).init_with_ring_id(
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import datetime
import hashlib
import paddle
......@@ -330,9 +331,16 @@ def _init_parallel_env(backend):
store, "0", rank, world_size
)
elif backend == "nccl":
endpoints_str = ""
for endpoint in global_env.trainer_endpoints:
endpoints_str += endpoint
endpoints_str += "ring_id:{}".format("0")
endpoints_str_hash = hashlib.md5(
endpoints_str.encode(encoding='UTF-8')
).hexdigest()
core.CommContextManager.set_device_id(dev_id)
core.CommContextManager.create_nccl_comm_context(
store, "0", rank, world_size
store, "0", rank, world_size, endpoints_str_hash
)
elif backend == "xccl":
dev_type = global_env.device_type
......
......@@ -81,10 +81,10 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
if((NOT WITH_ROCM) AND ((${CUDA_ARCH_NAME}) STREQUAL "Ampere"))
set_tests_properties(test_collective_alltoall_api
PROPERTIES TIMEOUT "160" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "280" LABELS "RUN_TYPE=DIST")
else()
set_tests_properties(test_collective_alltoall_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "240" LABELS "RUN_TYPE=DIST")
endif()
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
......@@ -286,7 +286,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_sendrecv MODULES test_collective_sendrecv ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_sendrecv PROPERTIES TIMEOUT "300" LABELS
set_tests_properties(test_collective_sendrecv PROPERTIES TIMEOUT "500" LABELS
"RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
......@@ -294,7 +294,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_sendrecv_api
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "500" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......
......@@ -109,7 +109,7 @@ class TestCollectiveAllgatherAPI(test_base.TestCollectiveAPIRunnerBase):
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
if args["use_comm_context"]:
if args["use_comm_context"] or args["dynamic_static_unified_comm"]:
paddle.distributed.collective._init_parallel_env(args["backend"])
else:
paddle.distributed.init_parallel_env()
......
......@@ -93,6 +93,20 @@ class TestCollectiveAllreduceAPI(TestCollectiveAPIRunnerBase):
all_reduce_new(tindata, reduce_type)
return [tindata]
def get_model_new_comm(
self,
main_prog,
startup_program,
rank,
dtype='float32',
):
with base.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[10, 1000], dtype=dtype
)
paddle.distributed.all_reduce(tindata)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveAllreduceAPI, "allreduce")
......@@ -105,6 +105,9 @@ class TestCollectiveAllreduce(TestCollectiveRunnerBase):
return toutdata
def get_model_new_comm(self, main_prog, startup_program, dtype="float32"):
return self.get_model(main_prog, startup_program)
if __name__ == "__main__":
runtime_main(TestCollectiveAllreduce, "allreduce", 0)
......@@ -126,6 +126,19 @@ class TestCollectiveAllToAllAPI(TestCollectiveAPIRunnerBase):
alltoall_new(tindata, tout_data)
return tout_data
def get_model_new_comm(
self, main_prog, startup_program, rank, dtype='float32'
):
with base.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[-1, 10, 1000], dtype=dtype
)
tindata.desc.set_need_check_feed(False)
tindata = paddle.split(tindata, 2, axis=0)
tout_data = []
paddle.distributed.alltoall(tindata, tout_data)
return tout_data
if __name__ == "__main__":
runtime_main(TestCollectiveAllToAllAPI, "alltoall")
......@@ -33,6 +33,13 @@ class TestCollectiveBarrierAPI(TestCollectiveAPIRunnerBase):
paddle.distributed.barrier()
return []
def get_model_new_comm(
self, main_prog, startup_program, rank, dtype="float32"
):
with base.program_guard(main_prog, startup_program):
paddle.distributed.barrier()
return []
if __name__ == "__main__":
runtime_main(TestCollectiveBarrierAPI, "barrier")
......@@ -60,6 +60,39 @@ def concat_new(tensor, group=None):
return out
def concat_new_comm(tensor, group=None, rank=0):
op_type = 'c_concat'
data_feeder.check_variable_and_dtype(
tensor,
'tensor',
[
'float16',
'float32',
'float64',
'int32',
'int64',
],
op_type,
)
helper = framework.LayerHelper(op_type, **locals())
ring_id = 0 if group is None else group.id
nranks = 2
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [out]},
attrs={
'ring_id': ring_id,
'nranks': nranks,
'rank': rank,
},
)
return out
class TestCollectiveConcatAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
......@@ -78,6 +111,18 @@ class TestCollectiveConcatAPI(TestCollectiveAPIRunnerBase):
toutdata = concat_new(tindata)
return [toutdata]
def get_model_new_comm(
self, main_prog, startup_program, rank, dtype="float32"
):
with base.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[10, 1000], dtype=dtype
)
tindata.desc.set_need_check_feed(False)
toutdata = concat_new_comm(tindata, rank=rank)
return [toutdata]
if __name__ == "__main__":
runtime_main(TestCollectiveConcatAPI, "concat")
......@@ -92,6 +92,17 @@ class TestCollectiveReduceAPI(TestCollectiveAPIRunnerBase):
reduce_new(tindata, dst=0, reduce_type=reduce_type)
return [tindata]
def get_model_new_comm(
self, main_prog, startup_program, rank, dtype='float32'
):
with base.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[-1, 10, 1000], dtype=dtype
)
tindata.desc.set_need_check_feed(False)
paddle.distributed.reduce(tindata, dst=0)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveReduceAPI, "reduce")
......@@ -50,6 +50,20 @@ class TestCollectiveReduceScatterAPI(TestCollectiveAPIRunnerBase):
paddle.distributed.reduce_scatter(toutdata, tindata)
return [toutdata]
def get_model_new_comm(
self, main_prog, startup_program, rank, dtype="float32"
):
with base.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[10, 1000], dtype=dtype
)
tindata.desc.set_need_check_feed(False)
toutdata = paddle.static.data(
name="toutdata", shape=[5, 1000], dtype=dtype
)
paddle.distributed.reduce_scatter(toutdata, tindata)
return [toutdata]
if __name__ == "__main__":
runtime_main(TestCollectiveReduceScatterAPI, "reduce_scatter")
......@@ -44,6 +44,24 @@ class TestCollectiveScatterAPI(TestCollectiveAPIRunnerBase):
paddle.distributed.scatter(toutdata, tensor_list, src=1)
return [toutdata]
def get_model_new_comm(
self, main_prog, startup_program, rank, dtype="float32"
):
with base.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata",
shape=[10, 1000],
dtype=dtype,
)
toutdata = paddle.tensor.fill_constant(
shape=[5, 1000], dtype=dtype, value=1.0
)
tensor_list = None
if rank == 1:
tensor_list = paddle.split(tindata, 2, axis=0)
paddle.distributed.scatter(toutdata, tensor_list, src=1)
return [toutdata]
if __name__ == "__main__":
runtime_main(TestCollectiveScatterAPI, "scatter")
......@@ -66,6 +66,26 @@ class TestCollectiveAllgatherAPI(TestDistBase):
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_allgather_nccl_with_new_comm(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
"int8",
"uint8",
"bool",
]
for dtype in dtypes_to_test:
self.check_with_place(
"collective_allgather_api.py",
"allgather",
"nccl",
dtype=dtype,
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)
def test_allgather_gloo(self):
dtypes_to_test = [
"float16",
......
......@@ -59,6 +59,30 @@ class TestCollectiveAllreduceAPI(TestDistBase):
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_allreduce_nccl_with_new_comm(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
]
red_types_to_test = [
dist.ReduceOp.SUM,
]
if self._nccl_version >= 21000:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
for red_type in red_types_to_test:
self.check_with_place(
"collective_allreduce_api.py",
"allreduce",
"nccl",
dtype=dtype,
reduce_type=red_type,
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)
def test_allreduce_bkcl(self):
if paddle.base.core.is_compiled_with_xpu():
self.check_with_place(
......
......@@ -43,6 +43,23 @@ class TestCollectiveAllToAllAPI(TestDistBase):
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_alltoall_nccl_with_new_comm(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
]
for dtype in dtypes_to_test:
self.check_with_place(
"collective_alltoall_api.py",
"alltoall",
"nccl",
dtype=dtype,
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)
def test_alltoall_nccl_dygraph(self):
dtypes_to_test = [
"float16",
......
......@@ -28,6 +28,14 @@ class TestCollectiveBarrierAPI(TestDistBase):
def test_barrier_nccl(self):
self.check_with_place("collective_barrier_api.py", "barrier", "nccl")
def test_barrier_nccl_with_new_comm(self):
self.check_with_place(
"collective_barrier_api.py",
"barrier",
"nccl",
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)
def test_barrier_gloo(self):
self.check_with_place(
"collective_barrier_api.py", "barrier", "gloo", "5"
......
......@@ -47,6 +47,23 @@ class TestCollectiveConcatAPI(TestDistBase):
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_concat_with_new_comm(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
]
for dtype in dtypes_to_test:
self.check_with_place(
"collective_concat_api.py",
"dist_concat",
"nccl",
dtype=dtype,
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)
if __name__ == '__main__':
unittest.main()
......@@ -58,6 +58,29 @@ class TestCollectiveReduceAPI(TestDistBase):
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_reduce_nccl_with_new_comm(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
]
red_types_to_test = [
dist.ReduceOp.SUM,
]
for dtype in dtypes_to_test:
if paddle.base.core.is_compiled_with_cuda():
for red_type in red_types_to_test:
self.check_with_place(
"collective_reduce_api.py",
"reduce",
"nccl",
dtype=dtype,
reduce_type=red_type,
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)
def test_reduce_bkcl(self):
if paddle.base.core.is_compiled_with_xpu():
self.check_with_place("collective_reduce_api.py", "reduce", "bkcl")
......
......@@ -43,6 +43,25 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_reduce_scatter_nccl_with_new_comm(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
]
if self._nccl_version >= 21000:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place(
"collective_reduce_scatter_api.py",
"reduce_scatter",
"nccl",
dtype=dtype,
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)
def test_reduce_scatter_nccl_dygraph(self):
dtypes_to_test = [
"float16",
......
......@@ -33,6 +33,23 @@ class TestCollectiveScatterAPI(TestDistBase):
def test_scatter_nccl(self):
self.check_with_place("collective_scatter_api.py", "scatter", "nccl")
def test_scatter_nccl_with_new_comm(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
]
for dtype in dtypes_to_test:
self.check_with_place(
"collective_scatter_api.py",
"scatter",
"nccl",
dtype=dtype,
need_envs={"FLAGS_dynamic_static_unified_comm": "1"},
)
def test_scatter_nccl_dygraph(self):
dtypes_to_test = [
"float16",
......
......@@ -125,7 +125,7 @@ class TestCollectiveAPIRunnerBase:
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
if args["use_comm_context"]:
if args["use_comm_context"] or args["dynamic_static_unified_comm"]:
paddle.distributed.collective._init_parallel_env(args["backend"])
else:
paddle.distributed.init_parallel_env()
......@@ -152,8 +152,14 @@ class TestCollectiveAPIRunnerBase:
reduce_type=args['reduce_type'],
)
if args["use_comm_context"]
else (
self.get_model_new_comm(
train_prog, startup_prog, rank, dtype=args['dtype']
)
if args["dynamic_static_unified_comm"]
else self.get_model(train_prog, startup_prog, rank)
)
)
exe = base.Executor(place)
exe.run(startup_prog)
fetch_list = []
......@@ -182,6 +188,9 @@ def runtime_main(test_class, col_type):
args["dtype"] = os.getenv("DTYPE")
args["reduce_type"] = os.getenv("REDUCE_TYPE")
args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0")))
args["dynamic_static_unified_comm"] = bool(
int(os.getenv("FLAGS_dynamic_static_unified_comm", "0"))
)
model.run_trainer(args)
......
......@@ -27,6 +27,7 @@ import numpy as np
import paddle.base.unique_name as nameGen
from paddle import base
from paddle.base import core
from paddle.distributed.collective import _init_parallel_env
class TestCollectiveRunnerBase:
......@@ -110,9 +111,13 @@ class TestCollectiveRunnerBase:
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
if args["dynamic_static_unified_comm"]:
_init_parallel_env("nccl")
else:
self.initCommunicator(
startup_prog, rank, nranks, True, current_endpoint, endpoints
)
self.rank = rank
result = self.get_model(train_prog, startup_prog)
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
......@@ -140,6 +145,10 @@ def runtime_main(test_class, col_type, sub_type):
args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS')
args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT")
args["col_type"] = col_type
args["dtype"] = os.getenv("DTYPE")
args["dynamic_static_unified_comm"] = bool(
int(os.getenv("FLAGS_dynamic_static_unified_comm", "0"))
)
model.run_trainer(args)
......@@ -257,6 +266,8 @@ class TestDistBase(unittest.TestCase):
"LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
"GLOG_v": "3",
"NCCL_P2P_DISABLE": "1",
"Flags_dynamic_static_unified_comm": "False",
"DTYPE": "float32",
}
required_envs.update(need_envs)
if check_error_log:
......