Unverified commit 2b12a117, authored by TaoTao Li, committed by GitHub

add distributed p_send/p_recv/reduce_scatter operator (#51858)

fix merge conflicts
Parent d6a38532
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class PSendOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
};
class PSendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of send op.");
AddAttr<int>("ring_id", "(int), attribute 0 for send op.").SetDefault(0);
AddAttr<int>("peer", "(int), attribute 1 for send op.").SetDefault(0);
AddAttr<bool>("dynamic_shape", "(bool), attribute 2 for send op.")
.SetDefault(false);
AddComment(R"DOC(
Send operator (p_send): sends the input tensor x to the specified peer rank over the given communication ring.
)DOC");
}
};
class PSendArrayOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
};
class PSendArrayOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor[]), input 0 of p_send_array op.").AsDuplicable();
AddAttr<int>("ring_id", "(int), attribute 0 for p_send_array op.")
.SetDefault(0);
AddAttr<int>("peer", "(int), attribute 1 for p_send_array op.")
.SetDefault(0);
AddComment(R"DOC(
Send operator for tensor arrays (p_send_array): sends each tensor in the input tensor array x to the specified peer rank over the given communication ring.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(p_send,
PSendInferShapeFunctor,
PD_INFER_META(phi::PSendInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(p_send_array,
PSendArrayInferShapeFunctor,
PD_INFER_META(phi::PSendArrayInferMeta));
REGISTER_OPERATOR(
p_send,
ops::PSendOp,
ops::PSendOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
PSendInferShapeFunctor);
REGISTER_OPERATOR(
p_send_array,
ops::PSendArrayOp,
ops::PSendArrayOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
PSendArrayInferShapeFunctor);
......@@ -228,6 +228,27 @@
func : one_hot_raw
data_type : x
- op : p_recv
args : (int ring_id = 0, int peer = 0, DataType dtype = DataType::FLOAT32, bool dynamic_shape = false)
output : Tensor(out)
infer_meta :
func : PRecvInferMeta
param : [peer, dtype]
kernel :
func : p_recv
param : [peer, dtype, dynamic_shape]
data_type : dtype
- op : p_recv_array
args : (int ring_id = 0, int peer = 0, DataType dtype = DataType::FLOAT32, int[] out_shape = {})
output : Tensor(out)
infer_meta :
func : PRecvArrayInferMeta
param : [peer, dtype, out_shape]
kernel :
func : p_recv_array
param : [peer, dtype, out_shape]
- op : randint
args : (int low, int high, IntArray shape = {}, DataType dtype = DataType::INT64, int seed = 0)
output : Tensor(out)
......@@ -249,6 +270,16 @@
func : reduce
param: [x, root_id, reduce_type]
- op : reduce_scatter
args : (Tensor x, int ring_id = 0, int nranks = 1)
output : Tensor(out)
infer_meta :
func : ReduceScatterInferMeta
param: [x, nranks]
kernel :
func : reduce_scatter
param: [x, nranks]
- op : share_buffer
args : (Tensor[] x, bool[] share_dims_and_dtype={})
output : Tensor[](out){x.size()}, Tensor[](xout){x.size()}
......
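For orientation, the phi kernel signatures that these YAML `kernel` entries resolve to are sketched below; this is only a reading aid, and the declarations mirror those added in p_recv_kernel.h and reduce_scatter_kernel.h later in this change.

// Sketch of the phi kernel signatures implied by the YAML entries above
// (they mirror the declarations in p_recv_kernel.h / reduce_scatter_kernel.h below).
template <typename T, typename Context>
void PRecvKernel(const Context& dev_ctx,
                 int peer,
                 DataType dtype,
                 bool dynamic_shape,
                 DenseTensor* out);

template <typename T, typename Context>
void PRecvArrayKernel(const Context& dev_ctx,
                      int peer,
                      DataType dtype,
                      const std::vector<int>& out_shape,
                      TensorArray* out);

template <typename T, typename Context>
void ReduceScatterKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         int nranks,
                         DenseTensor* out);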
......@@ -31,6 +31,8 @@ NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id)
phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_));
}
ncclComm_t NCCLCommContext::GetNcclComm() { return nccl_comm_; }
void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
......@@ -64,6 +66,55 @@ void NCCLCommContext::AllGather(phi::DenseTensor* out_tensor,
nccl_comm_,
stream));
}
void NCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
gpuStream_t stream) {
int64_t out_size = in_tensor.numel() / GetSize();
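// ncclReduceScatter expects the per-rank receive count, hence numel() / GetSize().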
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclReduceScatter(in_tensor.data(),
out_tensor->data(),
out_size,
ToNCCLDataType(in_tensor.type()),
ncclSum,
nccl_comm_,
stream));
}
void NCCLCommContext::Send(const phi::DenseTensor& in_tensor,
const int& peer,
gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
NCCLDynamicCheck::CheckShape(in_tensor, rank_, rank_, nccl_comm_);
}
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclSend(in_tensor.data(),
in_tensor.numel(),
ToNCCLDataType(in_tensor.type()),
peer,
nccl_comm_,
stream));
VLOG(3) << "rank " << GetRank() << " send " << phi::product(in_tensor.dims())
<< " to " << peer;
}
void NCCLCommContext::Recv(phi::DenseTensor* out_tensor,
const int& peer,
gpuStream_t stream) {
if (FLAGS_enable_nccl_dynamic_check) {
NCCLDynamicCheck::CheckShape(*out_tensor, rank_, rank_, nccl_comm_);
}
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::ncclRecv(out_tensor->data(),
out_tensor->numel(),
ToNCCLDataType(out_tensor->type()),
peer,
nccl_comm_,
stream));
VLOG(3) << "rank " << GetRank() << " recv "
<< phi::product(out_tensor->dims()) << " from " << peer;
}
void NCCLCommContext::AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
......
......@@ -31,10 +31,21 @@ class NCCLCommContext final : public CommContext {
public:
NCCLCommContext(int rank, int size, ncclUniqueId nccl_id);
ncclComm_t GetNcclComm();
void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root,
gpuStream_t stream);
void Send(const phi::DenseTensor& in_tensor,
const int& peer,
gpuStream_t stream);
void Recv(phi::DenseTensor* out_tensor, const int& peer, gpuStream_t stream);
void ReduceScatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
gpuStream_t stream);
void AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
......
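A minimal usage sketch of the new NCCLCommContext point-to-point methods, following the call pattern the GPU kernels below use. DemoSendRecv and its parameters are hypothetical; the sketch assumes the framework has already attached a comm context for the current ring to the GPU device context.

// Minimal sketch (assumption: an NCCLCommContext for the current ring has been
// attached to the GPU device context by the framework).
void DemoSendRecv(const phi::GPUContext& dev_ctx,
                  const phi::DenseTensor& x,
                  phi::DenseTensor* out,
                  int peer) {
  auto* comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
      dev_ctx.GetCommContext());
  gpuStream_t stream = dev_ctx.stream();
  if (comm_ctx->GetRank() == 0) {
    comm_ctx->Send(x, peer, stream);    // point-to-point send of x to `peer`
  } else {
    comm_ctx->Recv(out, peer, stream);  // point-to-point receive into `out`
  }
  // Sum-reduce x across ranks and scatter equal shards instead:
  // comm_ctx->ReduceScatter(out, x, stream);
}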
......@@ -122,6 +122,47 @@ void RandintInferMeta(
out->set_dtype(dtype);
}
void PRecvInferMeta(int peer, DataType dtype, MetaTensor* out) {
PADDLE_ENFORCE_GE(
peer,
0,
errors::InvalidArgument(
"The peer (%d) for p_recv op must be non-negative.", peer));
// auto data_type = phi::TransToPhiDataType(dtype);
out->set_dtype(dtype);
}
void PRecvArrayInferMeta(int peer,
DataType dtype,
const std::vector<int>& out_shape,
MetaTensor* out) {
PADDLE_ENFORCE_GE(
peer,
0,
errors::InvalidArgument(
"The peer (%d) for p_recv op must be non-negative.", peer));
PADDLE_ENFORCE_GE(out_shape.size(),
1,
errors::InvalidArgument(
"The size of the output shape must be greater than 0 "
"but the value given is %d.",
out_shape.size()));
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_GE(
out_shape[i],
1,
errors::InvalidArgument("The shape attribute for recv must be set "
"explicitly, but the %dth element is %d which "
"is less than 1. Or dynamic_shape should be "
"set to True for both send_v2 and recv_v2.",
i,
out_shape[i]));
}
out->set_dtype(dtype);
}
void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
float mean,
float std,
......
......@@ -60,6 +60,13 @@ void RandpermInferMeta(int n, DataType dtype, MetaTensor* out);
void RandintInferMeta(
int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out);
void PRecvInferMeta(int peer, DataType dtype, MetaTensor* out);
void PRecvArrayInferMeta(int peer,
DataType dtype,
const std::vector<int>& out_shape,
MetaTensor* out);
void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
float mean,
float std,
......
......@@ -2821,6 +2821,23 @@ void Pool2DInferMeta(const MetaTensor& x,
}
}
void PSendInferMeta(const MetaTensor& x, int peer) {
LOG(INFO) << "SendBaseInferMeta begin";
PADDLE_ENFORCE_GE(
peer,
0,
errors::InvalidArgument(
"The peer (%d) for p_send op must be non-negative.", peer));
}
void PSendArrayInferMeta(const MetaTensor& x, int peer) {
PADDLE_ENFORCE_GE(
peer,
0,
errors::InvalidArgument(
"The peer (%d) for p_send op must be non-negative.", peer));
}
void PoolInferMeta(const MetaTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides,
......@@ -3131,6 +3148,20 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
ReduceIntArrayAxisInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
}
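// Shards dim[0] across ranks when it is statically known: e.g. with nranks = 2,
// an input of shape [10, 1000] infers an output of shape [5, 1000].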
void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
auto dim = x.dims();
if (dim[0] > 0 || dim[0] < -1) {
PADDLE_ENFORCE_EQ(
dim[0] % nranks,
0,
errors::InvalidArgument(
"dim[0] (%d) is not divisible by nranks(%d)", dim[0], nranks));
dim[0] /= nranks;
}
out->set_dims(dim);
out->set_dtype(x.dtype());
}
void RepeatInterleaveInferMeta(const MetaTensor& x,
int repeats,
int dim,
......
......@@ -401,6 +401,10 @@ void Pool2DInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void PSendInferMeta(const MetaTensor& x, int peer);
void PSendArrayInferMeta(const MetaTensor& x, int peer);
void QrInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* q,
......@@ -432,6 +436,8 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out);
void RepeatInterleaveInferMeta(const MetaTensor& x,
int repeats,
int dim,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/p_recv_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
namespace phi {
template <typename T, typename Context>
void PRecvKernel(const Context& dev_ctx,
int peer,
DataType dtype,
bool dynamic_shape,
DenseTensor* out) {
PADDLE_THROW(errors::Unavailable("Do not support recv for cpu kernel now."));
}
template <typename T, typename Context>
void PRecvArrayKernel(const Context& dev_ctx,
int peer,
DataType dtype,
const std::vector<int>& out_shape,
TensorArray* out_array) {
PADDLE_THROW(
errors::Unavailable("Do not support recv array for cpu kernel now."));
}
} // namespace phi
PD_REGISTER_KERNEL(p_recv,
CPU,
ALL_LAYOUT,
phi::PRecvKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(p_recv_array,
CPU,
ALL_LAYOUT,
phi::PRecvArrayKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/p_send_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_GLOO)
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#endif
namespace phi {
template <typename T, typename Context>
void PSendKernel(const Context& dev_ctx,
const DenseTensor& x,
int peer,
bool dynamic_shape) {
PADDLE_THROW(errors::Unavailable("Do not support send for cpu kernel now."));
}
template <typename T, typename Context>
void PSendArrayKernel(const Context& dev_ctx,
const TensorArray& x,
int peer,
bool dynamic_shape,
DenseTensor* out) {
PADDLE_THROW(
errors::Unavailable("Do not support send array for cpu kernel now."));
}
} // namespace phi
PD_REGISTER_KERNEL(p_send,
CPU,
ALL_LAYOUT,
phi::PSendKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(p_send_array,
CPU,
ALL_LAYOUT,
phi::PSendArrayKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_scatter_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ReduceScatterKernel(const Context& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out) {
PADDLE_THROW(
errors::Unimplemented("Unimplemented cpu kernel for CReduceScatterOp."));
}
} // namespace phi
PD_REGISTER_KERNEL(reduce_scatter,
CPU,
ALL_LAYOUT,
phi::ReduceScatterKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/p_recv_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#if (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)) && \
    NCCL_VERSION_CODE >= 2703
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
namespace phi {
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
NCCL_VERSION_CODE >= 2703
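// Receiver side of the dynamic-shape handshake: first receive the number of
// dims as a single int32, then receive the dims themselves and rebuild the
// DDim. Mirrors send_shape_info() on the sender side.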
template <typename Context>
DDim recv_shape_info(const Context& dev_ctx,
phi::DenseTensor* out,
distributed::NCCLCommContext* comm_ctx,
int peer) {
gpuStream_t stream = dev_ctx.stream();
PADDLE_ENFORCE_EQ((stream != nullptr && comm_ctx != nullptr),
true,
errors::InvalidArgument(
"NCCLComm and Stream should be provided if use NCCL "
"to send the shape info."));
paddle::DataType shape_dtype = paddle::DataType::INT32;
ncclDataType_t nccl_dtype = ncclInt;
// phi::DenseTensor gpu_shape_size_tensor(shape_dtype);
phi::DenseTensor* gpu_shape_size_tensor = new phi::DenseTensor(shape_dtype);
gpu_shape_size_tensor->Resize({1});
dev_ctx.Alloc(gpu_shape_size_tensor, shape_dtype);
comm_ctx->Recv(gpu_shape_size_tensor, peer, stream);
// copy the shape size tensor to cpu
phi::DenseTensor* cpu_shape_size_tensor = new phi::DenseTensor(shape_dtype);
cpu_shape_size_tensor->Resize({1});
dev_ctx.HostAlloc(cpu_shape_size_tensor, shape_dtype);
memory_utils::Copy(phi::CPUPlace(),
cpu_shape_size_tensor->data(),
dev_ctx.GetPlace(),
gpu_shape_size_tensor->data(),
gpu_shape_size_tensor->numel() * sizeof(int),
stream);
auto* cpu_data = cpu_shape_size_tensor->data<int>();
int shape_size = cpu_data[0];
VLOG(3) << "recv the shape size: " << shape_size << " from peer: " << peer;
// step2: recv the shape
// phi::DenseTensor gpu_shape_tensor(shape_dtype);
phi::DenseTensor* gpu_shape_tensor = new phi::DenseTensor(shape_dtype);
gpu_shape_tensor->Resize({shape_size});
dev_ctx.Alloc(gpu_shape_tensor, shape_dtype);
comm_ctx->Recv(gpu_shape_tensor, peer, stream);
// copy the shape tensor to cpu
phi::DenseTensor* cpu_shape_tensor = new phi::DenseTensor(shape_dtype);
cpu_shape_tensor->Resize({shape_size});
dev_ctx.HostAlloc(cpu_shape_tensor, shape_dtype);
memory_utils::Copy(phi::CPUPlace(),
cpu_shape_tensor->data(),
dev_ctx.GetPlace(),
gpu_shape_tensor->data(),
gpu_shape_tensor->numel() * sizeof(int),
stream);
auto* cpu_shape_data = cpu_shape_tensor->data<int>();
std::vector<int> all_shape;
for (int i = 0; i < shape_size; ++i) {
all_shape.emplace_back(cpu_shape_data[i]);
}
DDim new_dim;
new_dim = new_dim.reshape(all_shape);
VLOG(3) << "recv the shape: (" << new_dim << ") from peer";
return new_dim;
}
template <typename Context>
distributed::NCCLCommContext* GetCommContext(const Context& dev_ctx, int peer) {
PADDLE_ENFORCE_GE(
peer,
0,
errors::InvalidArgument("The peer (%d) for send op must be non-negative.",
peer));
auto comm_ctx =
static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
errors::Unavailable("NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
PADDLE_ENFORCE_LT(
peer,
comm_ctx->GetSize(),
errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm_ctx->GetSize()));
return comm_ctx;
}
#endif
template <typename T, typename Context>
void PRecvKernel(const Context& dev_ctx,
int peer,
DataType dtype,
bool dynamic_shape,
DenseTensor* out) {
#if (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)) && \
    NCCL_VERSION_CODE >= 2703
auto comm_ctx = GetCommContext(dev_ctx, peer);
gpuStream_t stream = dev_ctx.stream();
// auto data_type = phi::TransToPhiDataType(dtype);
if (dynamic_shape) {
DDim new_dim = recv_shape_info<Context>(dev_ctx, out, comm_ctx, peer);
out->Resize(new_dim);
}
dev_ctx.Alloc(out, dtype);
comm_ctx->Recv(out, peer, stream);
#else
PADDLE_THROW(
errors::PreconditionNotMet("PaddlePaddle should compile with GPU."
"and NCCL version >= 2.7.3 is needed."));
#endif
}
template <typename T, typename Context>
void PRecvArrayKernel(const Context& dev_ctx,
int peer,
DataType dtype,
const std::vector<int>& out_shape,
TensorArray* out_array) {
#if (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)) && \
    NCCL_VERSION_CODE >= 2703
auto comm_ctx = GetCommContext(dev_ctx, peer);
gpuStream_t stream = dev_ctx.stream();
for (size_t idx = 0; idx < out_shape.size(); ++idx) {
VLOG(3) << "LodTensorArray: idx(" << idx << ")";
auto out = out_array->at(idx);
auto out_dims = out.dims();
dev_ctx.Alloc(&out, dtype);
comm_ctx->Recv(&out, peer, stream);
VLOG(3) << "rank " << comm_ctx->GetRank() << " recv "
<< phi::product(out_dims) << " from " << peer;
}
#else
PADDLE_THROW(
errors::PreconditionNotMet("PaddlePaddle should compile with GPU."
"and NCCL version >= 2.7.3 is needed."));
#endif
}
} // namespace phi
#if NCCL_VERSION_CODE >= 21000
PD_REGISTER_KERNEL(p_recv,
GPU,
ALL_LAYOUT,
phi::PRecvKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(p_recv_array,
GPU,
ALL_LAYOUT,
phi::PRecvArrayKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(p_recv,
GPU,
ALL_LAYOUT,
phi::PRecvKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(p_recv_array,
GPU,
ALL_LAYOUT,
phi::PRecvArrayKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/p_send_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#if (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)) && \
    NCCL_VERSION_CODE >= 2703
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
namespace phi {
#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
NCCL_VERSION_CODE >= 2703
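// Sender side of the dynamic-shape handshake: send the number of dims as a
// single int32, then send the dims themselves. The receiver rebuilds the DDim
// in recv_shape_info().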
template <typename Context>
void send_shape_info(const Context& dev_ctx,
const DenseTensor& x,
distributed::NCCLCommContext* comm_ctx,
int peer,
gpuStream_t stream) {
PADDLE_ENFORCE_EQ((stream != nullptr && comm_ctx != nullptr),
true,
errors::InvalidArgument(
"NCCLComm and Stream should be provided if use NCCL "
"to send the shape info."));
paddle::DataType shape_dtype = paddle::DataType::INT32;
ncclDataType_t nccl_dtype = ncclInt;
auto dims = x.dims();
int shape_size = dims.size();
// step1: send the shape size
phi::DenseTensor cpu_shape_size_tensor(shape_dtype);
cpu_shape_size_tensor.Resize({1});
dev_ctx.HostAlloc(&cpu_shape_size_tensor, shape_dtype);
auto* cpu_data = cpu_shape_size_tensor.data<int>();
cpu_data[0] = shape_size;
// copy the shape size tensor to gpu and send
phi::DenseTensor* gpu_shape_size_tensor = new phi::DenseTensor(shape_dtype);
gpu_shape_size_tensor->Resize({1});
dev_ctx.Alloc(gpu_shape_size_tensor, shape_dtype);
const auto& cpu_place = phi::CPUPlace();
memory_utils::Copy(dev_ctx.GetPlace(),
gpu_shape_size_tensor->data(),
cpu_place,
cpu_shape_size_tensor.data(),
cpu_shape_size_tensor.numel() * sizeof(int),
stream);
comm_ctx->Send(*gpu_shape_size_tensor, peer, stream);
VLOG(3) << "send the shape size: " << shape_size << " to peer";
// step2: send the shape
phi::DenseTensor cpu_shape_tensor(shape_dtype);
cpu_shape_tensor.Resize({shape_size});
dev_ctx.HostAlloc(&cpu_shape_tensor, shape_dtype);
auto* cpu_shape_data = cpu_shape_tensor.data<int>();
for (int i = 0; i < shape_size; ++i) {
cpu_shape_data[i] = dims[i];
}
// copy the shape tensor to gpu and send
phi::DenseTensor* gpu_shape_tensor = new phi::DenseTensor(shape_dtype);
gpu_shape_tensor->Resize({shape_size});
dev_ctx.Alloc(gpu_shape_tensor, shape_dtype);
memory_utils::Copy(dev_ctx.GetPlace(),
gpu_shape_tensor->data(),
cpu_place,
cpu_shape_tensor.data(),
cpu_shape_tensor.numel() * sizeof(int),
stream);
comm_ctx->Send(*gpu_shape_tensor, peer, stream);
VLOG(3) << "send the shape: (" << dims << ") to peer";
}
template <typename Context>
distributed::NCCLCommContext* GetCommContext(const Context& dev_ctx, int peer) {
PADDLE_ENFORCE_GE(
peer,
0,
errors::InvalidArgument("The peer (%d) for send op must be non-negative.",
peer));
auto comm_ctx =
static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_ctx,
nullptr,
errors::Unavailable("NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
PADDLE_ENFORCE_LT(
peer,
comm_ctx->GetSize(),
errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm_ctx->GetSize()));
return comm_ctx;
}
#endif
template <typename T, typename Context>
void PSendKernel(const Context& dev_ctx,
const DenseTensor& x,
int peer,
bool dynamic_shape) {
#if (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)) && \
    NCCL_VERSION_CODE >= 2703
auto comm_ctx = GetCommContext(dev_ctx, peer);
gpuStream_t stream = dev_ctx.stream();
if (dynamic_shape) {
send_shape_info<Context>(dev_ctx, x, comm_ctx, peer, stream);
}
comm_ctx->Send(x, peer, stream);
#else
PADDLE_THROW(
errors::PreconditionNotMet("PaddlePaddle should compile with GPU."
"and NCCL version >= 2.7.3 is needed."));
#endif
}
template <typename T, typename Context>
void PSendArrayKernel(const Context& dev_ctx,
const TensorArray& x_array,
int peer) {
#if (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)) && \
    NCCL_VERSION_CODE >= 2703
auto comm_ctx = GetCommContext(dev_ctx, peer);
gpuStream_t stream = dev_ctx.stream();
for (size_t idx = 0; idx < x_array.size(); idx++) {
VLOG(3) << "LodTensorArray: idx(" << idx << ")";
auto x = x_array.at(idx);
int numel = x.numel();
ncclDataType_t dtype = ToNCCLDataType(x.type());
comm_ctx->Send(x, peer, stream);
VLOG(3) << "rank " << comm_ctx->GetRank() << " send "
<< phi::product(x.dims()) << " to " << peer;
}
#else
PADDLE_THROW(
errors::PreconditionNotMet("PaddlePaddle should compile with GPU."
"and NCCL version >= 2.7.3 is needed."));
#endif
}
} // namespace phi
#if NCCL_VERSION_CODE >= 21000
PD_REGISTER_KERNEL(p_send,
GPU,
ALL_LAYOUT,
phi::PSendKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(p_send_array,
GPU,
ALL_LAYOUT,
phi::PSendArrayKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(p_send,
GPU,
ALL_LAYOUT,
phi::PSendKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(p_send_array,
GPU,
ALL_LAYOUT,
phi::PSendArrayKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/reduce_scatter_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#endif
namespace phi {
template <typename T, typename Context>
void ReduceScatterKernel(const Context& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
gpuStream_t stream = dev_ctx.stream();
auto comm_context =
static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
PADDLE_ENFORCE_NE(
comm_context,
nullptr,
errors::Unavailable("NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
auto out_dims = x.dims();
PADDLE_ENFORCE_EQ(
out_dims[0] % nranks,
0,
errors::InvalidArgument("The input tensor X's "
"dim[0] (%d) should be divisible by nranks(%d)",
out_dims[0],
nranks));
out_dims[0] = out_dims[0] / nranks;
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
comm_context->ReduceScatter(out, x, stream);
#else
PADDLE_THROW(
errors::PreconditionNotMet("PaddlePaddle should compile with GPU."));
#endif
}
} // namespace phi
#if NCCL_VERSION_CODE >= 21000
PD_REGISTER_KERNEL(reduce_scatter,
GPU,
ALL_LAYOUT,
phi::ReduceScatterKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(reduce_scatter,
GPU,
ALL_LAYOUT,
phi::ReduceScatterKernel,
float,
double,
int,
bool,
int8_t,
uint8_t,
int64_t,
phi::dtype::float16) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_array.h"
namespace phi {
template <typename T, typename Context>
void PRecvKernel(const Context& dev_ctx,
int peer,
DataType dtype,
bool dynamic_shape,
DenseTensor* out);
template <typename T, typename Context>
void PRecvArrayKernel(const Context& dev_ctx,
int peer,
DataType dtype,
const std::vector<int>& out_shape,
TensorArray* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_array.h"
namespace phi {
template <typename T, typename Context>
void PSendKernel(const Context& dev_ctx,
const DenseTensor& x,
int peer,
bool dynamic_shape);
template <typename T, typename Context>
void PSendArrayKernel(const Context& dev_ctx, const TensorArray& x, int peer);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ReduceScatterKernel(const Context& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature PSendOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("p_send", {"x"}, {"peer", "dynamic_shape"}, {});
}
KernelSignature PSendArrayOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("p_send_array", {"x"}, {"peer"}, {});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(p_send, phi::PSendOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(p_send_array, phi::PSendArrayOpArgumentMapping);
......@@ -13,12 +13,14 @@
# limitations under the License.
import paddle
import paddle.distributed as dist
from paddle import framework
from paddle.distributed.communication.group import (
_get_global_group,
_warn_cur_rank_not_in_group,
)
from paddle.distributed.communication.reduce import ReduceOp, _get_reduce_op
from paddle.fluid import data_feeder
def _reduce_scatter_tensor_in_dygraph(
......@@ -65,6 +67,40 @@ def _reduce_scatter_in_dygraph(
return task
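# Static-graph path: appends the new reduce_scatter op through LayerHelper;
# nranks is taken from the global world size, since per-group reduce_scatter is
# not supported in static mode yet (the caller asserts that group is None).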
def _reduce_scatter_in_static_mode(tensor, tensor_or_tensor_list, group):
op_type = 'reduce_scatter'
data_feeder.check_variable_and_dtype(
tensor,
'tensor',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'int8',
'uint8',
'bool',
],
op_type,
)
helper = framework.LayerHelper(op_type, **locals())
ring_id = 0 if group is None else group.id
nranks = dist.get_world_size()
helper.append_op(
type=op_type,
inputs={'x': [tensor_or_tensor_list]},
outputs={'out': [tensor]},
attrs={
'ring_id': ring_id,
'nranks': nranks,
},
)
return None
def reduce_scatter(
tensor,
tensor_or_tensor_list,
......@@ -141,10 +177,13 @@ def reduce_scatter(
sync_op,
use_calc_stream,
)
else:
assert (
group is None
), "Group can not be used in static graph mode for now."
return _reduce_scatter_in_static_mode(
tensor, tensor_or_tensor_list, group
)
def _reduce_scatter_base(
......
......@@ -215,7 +215,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_reduce_scatter_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_reduce_scatter_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "150" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
import paddle
from paddle import fluid
paddle.enable_static()
class TestCollectiveReduceScatterAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
pass
def get_model_new(
self,
main_prog,
startup_program,
rank,
dtype='float32',
reduce_type=None,
):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata", shape=[10, 1000], dtype=dtype
)
tindata.desc.set_need_check_feed(False)
# toutdata = layers.fill_constant(shape=[5, 1000], dtype=dtype, value=1.0)
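# With 2 ranks (as this test harness launches), reduce_scatter over a
# [10, 1000] input produces a [5, 1000] shard per rank, matching toutdata.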
toutdata = paddle.static.data(
name="toutdata", shape=[5, 1000], dtype=dtype
)
paddle.distributed.reduce_scatter(toutdata, tindata)
return [toutdata]
if __name__ == "__main__":
runtime_main(TestCollectiveReduceScatterAPI, "reduce_scatter")
......@@ -15,11 +15,77 @@
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
import paddle
from paddle import fluid, framework
from paddle.fluid import data_feeder
paddle.enable_static()
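# Static-graph helpers that append the new p_send / p_recv ops directly through
# LayerHelper; dynamic_shape=True exercises the shape handshake added in this
# change.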
def send_new(tensor, dst, group=None, sync_op=True):
op_type = 'p_send'
data_feeder.check_variable_and_dtype(
tensor,
'tensor',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'int8',
'uint8',
'bool',
],
op_type,
)
ring_id = 0 if group is None else group.id
helper = framework.LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'x': [tensor]},
attrs={
'ring_id': ring_id,
'peer': dst,
'dynamic_shape': True,
},
)
return None
def recv_new(tensor, src, group=None, sync_op=True, dtype='float32'):
op_type = 'p_recv'
data_feeder.check_variable_and_dtype(
tensor,
'tensor',
[
'float16',
'float32',
'float64',
'int32',
'int64',
'int8',
'uint8',
'bool',
],
op_type,
)
ring_id = 0 if group is None else group.id
helper = framework.LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
outputs={'out': [tensor]},
attrs={
'ring_id': ring_id,
'peer': src,
'dynamic_shape': True,
'out_shape': tensor.shape,
'dtype': fluid.framework.convert_np_dtype_to_dtype_(dtype),
},
)
return None
class TestCollectiveSendRecvAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
......@@ -37,6 +103,26 @@ class TestCollectiveSendRecvAPI(TestCollectiveAPIRunnerBase):
paddle.distributed.recv(tindata, src=0)
return [tindata]
def get_model_new(
self,
main_prog,
startup_program,
rank,
dtype='float32',
reduce_type=None,
):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.static.data(
name="tindata",
shape=[10, 1000],
dtype=dtype,
)
if rank == 0:
send_new(tindata, dst=1)
else:
recv_new(tindata, src=0, dtype=dtype)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveSendRecvAPI, "sendrecv")
......@@ -21,6 +21,28 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
def _setup_config(self):
pass
def test_reduce_scatter_nccl_with_comm_context(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
"int8",
"uint8",
"bool",
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place(
"collective_reduce_scatter_api.py",
"reduce_scatter",
"nccl",
dtype=dtype,
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_reduce_scatter_nccl_dygraph(self):
dtypes_to_test = [
"float16",
......
......@@ -30,6 +30,29 @@ class TestCollectiveSendRecvAPI(TestDistBase):
# self.check_with_place("collective_sendrecv_api.py", "sendrecv",
# "nccl")
def test_sendrecv_nccl_with_comm_context(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
"int8",
"uint8",
"bool",
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
if paddle.fluid.core.is_compiled_with_cuda():
self.check_with_place(
"collective_sendrecv_api.py",
"sendrecv",
"nccl",
dtype=dtype,
need_envs={"USE_COMM_CONTEXT": "1"},
)
def test_sendrecv_nccl_dygraph(self):
dtypes_to_test = [
"float16",
......