From 47f87ad3eac25698419f8012ffb5f15304f71cb2 Mon Sep 17 00:00:00 2001 From: yu wentao Date: Fri, 24 Mar 2023 21:13:28 +0800 Subject: [PATCH] add phi operator allreduce/reduce (#51857) * add all_reduce, reduce kernel and api * fix all_reduce reduce ut fix reduce op maker conflict fix merge conflicts * fix conflicts, rename ReduceOp->ReduceBaseOp in reduce_ops rename allreduce op, to remove * fix code format fix comments * modify test_collective_reduce_api ut timeout * fix PR-CI-Build fix comments: format phi operator --- cmake/external/gflags.cmake | 2 +- .../operators/collective/allreduce_op.cc | 8 +- .../operators/reduce_ops/reduce_amax_op.cc | 4 +- .../operators/reduce_ops/reduce_amin_op.cc | 4 +- .../operators/reduce_ops/reduce_any_op.cc | 2 +- .../operators/reduce_ops/reduce_max_op.cc | 4 +- .../operators/reduce_ops/reduce_mean_op.cc | 4 +- .../operators/reduce_ops/reduce_min_op.cc | 4 +- .../fluid/operators/reduce_ops/reduce_op.cu.h | 4 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 32 +++--- .../operators/reduce_ops/reduce_prod_op.cc | 4 +- .../operators/reduce_ops/reduce_sum_op.cc | 4 +- paddle/phi/api/yaml/ops.yaml | 20 ---- paddle/phi/api/yaml/static_ops.yaml | 40 ++++++++ .../phi/core/distributed/gloo_comm_context.cc | 29 +++++- .../phi/core/distributed/gloo_comm_context.h | 7 ++ paddle/phi/core/distributed/gloo_utils.h | 32 ++++++ .../phi/core/distributed/nccl_comm_context.cc | 30 ++++++ .../phi/core/distributed/nccl_comm_context.h | 11 +++ paddle/phi/core/distributed/reduce_helper.h | 21 ++++ paddle/phi/infermeta/unary.cc | 20 +++- paddle/phi/infermeta/unary.h | 8 +- paddle/phi/kernels/all_reduce_kernel.h | 28 ++++++ paddle/phi/kernels/cpu/all_reduce_kernel.cc | 63 ++++++++++++ paddle/phi/kernels/cpu/reduce_kernel.cc | 63 ++++++++++++ paddle/phi/kernels/gpu/all_reduce_kernel.cu | 97 ++++++++++++++++++ paddle/phi/kernels/gpu/reduce_kernel.cu | 98 +++++++++++++++++++ paddle/phi/kernels/reduce_kernel.h | 29 ++++++ .../tests/unittests/collective/CMakeLists.txt | 4 +- .../collective/collective_allgather_api.py | 4 +- .../collective/collective_allreduce_api.py | 60 +++++++++++- .../collective/collective_broadcast_api.py | 11 +-- .../collective/collective_reduce_api.py | 55 +++++++++++ .../test_collective_allreduce_api.py | 54 ++++++++++ .../collective/test_collective_reduce_api.py | 55 +++++++++++ .../unittests/test_collective_api_base.py | 3 +- 36 files changed, 837 insertions(+), 81 deletions(-) create mode 100644 paddle/phi/core/distributed/reduce_helper.h create mode 100644 paddle/phi/kernels/all_reduce_kernel.h create mode 100644 paddle/phi/kernels/cpu/all_reduce_kernel.cc create mode 100644 paddle/phi/kernels/cpu/reduce_kernel.cc create mode 100644 paddle/phi/kernels/gpu/all_reduce_kernel.cu create mode 100644 paddle/phi/kernels/gpu/reduce_kernel.cu create mode 100644 paddle/phi/kernels/reduce_kernel.h diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake index 7d5bb6caed5..606eb7ce1be 100755 --- a/cmake/external/gflags.cmake +++ b/cmake/external/gflags.cmake @@ -14,8 +14,8 @@ include(ExternalProject) -set(GFLAGS_PREFIX_DIR ${THIRD_PARTY_PATH}/gflags) set(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags) +set(GFLAGS_PREFIX_DIR ${THIRD_PARTY_PATH}/gflags) set(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." 
FORCE) diff --git a/paddle/fluid/operators/collective/allreduce_op.cc b/paddle/fluid/operators/collective/allreduce_op.cc index ca9cd1ca529..91ca5105471 100644 --- a/paddle/fluid/operators/collective/allreduce_op.cc +++ b/paddle/fluid/operators/collective/allreduce_op.cc @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace operators { -class AllReduceOp : public framework::OperatorWithKernel { +class AllReduceDelOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -34,7 +34,7 @@ class AllReduceOp : public framework::OperatorWithKernel { } }; -class AllReduceOpMaker : public framework::OpProtoAndCheckerMaker { +class AllReduceDelOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { AddInput("X", "(Tensor), tensor to be allreduced."); @@ -70,8 +70,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_WITHOUT_GRADIENT(allreduce, - ops::AllReduceOp, - ops::AllReduceOpMaker); + ops::AllReduceDelOp, + ops::AllReduceDelOpMaker); PD_REGISTER_STRUCT_KERNEL(allreduce, CPU, diff --git a/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc b/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc index 61f238f19d1..16650043fd3 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_amax_op.cc @@ -18,7 +18,7 @@ namespace ops = paddle::operators; -class ReduceAMaxOpMaker : public ops::ReduceOpMaker { +class ReduceAMaxOpMaker : public ops::ReduceBaseOpMaker { protected: virtual std::string GetName() const { return "reduce_amax"; } virtual std::string GetOpType() const { return "Reduce reduce_amax"; } @@ -30,7 +30,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(reduce_amax, REGISTER_OPERATOR( reduce_amax, - ops::ReduceOp, + ops::ReduceBaseOp, ReduceAMaxOpMaker, paddle::framework::DefaultGradOpMaker, paddle::framework::DefaultGradOpMaker, diff --git a/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc b/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc index aac8414ac19..2450d6b0779 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_amin_op.cc @@ -18,7 +18,7 @@ namespace ops = paddle::operators; -class ReduceAMinOpMaker : public ops::ReduceOpMaker { +class ReduceAMinOpMaker : public ops::ReduceBaseOpMaker { protected: virtual std::string GetName() const { return "reduce_amin"; } virtual std::string GetOpType() const { return "Reduce reduce_amin"; } @@ -30,7 +30,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(reduce_amin, REGISTER_OPERATOR( reduce_amin, - ops::ReduceOp, + ops::ReduceBaseOp, ReduceAMinOpMaker, paddle::framework::DefaultGradOpMaker, paddle::framework::DefaultGradOpMaker, diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc index 6634ccaaa01..ac3ec355876 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc @@ -32,7 +32,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(reduce_any, ReduceAnyInferShapeFunctor, PD_INFER_META(phi::ReduceInferMetaBase)); -class ReduceAnyOpMaker : public ops::ReduceOpMaker { +class ReduceAnyOpMaker : public ops::ReduceBaseOpMaker { protected: virtual std::string GetName() const { return "reduce_any"; } virtual std::string GetOpType() const { return "Reduce reduce_any"; } diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc index 0d5320d5634..1bb84a05469 100644 --- 
a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc
@@ -22,7 +22,7 @@
 
 namespace ops = paddle::operators;
 
-class ReduceMaxOpMaker : public ops::ReduceOpMaker {
+class ReduceMaxOpMaker : public ops::ReduceBaseOpMaker {
  protected:
   virtual std::string GetName() const { return "reduce_max"; }
   virtual std::string GetOpType() const { return "Reduce reduce_max"; }
@@ -61,7 +61,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(
 
 REGISTER_OPERATOR(
     reduce_max,
-    ops::ReduceOp,
+    ops::ReduceBaseOp,
     ReduceMaxOpMaker,
     paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
     paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
index 9e15b347a0b..0048ec1e724 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
@@ -92,7 +92,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReduceMeanGradNoNeedBufferVarInferer, "X");
 }  // namespace operators
 }  // namespace paddle
 
-class __reduce_meanMaker__ : public ops::ReduceOpMaker {
+class __reduce_meanMaker__ : public ops::ReduceBaseOpMaker {
  protected:
   virtual std::string GetName() const { return "reduce_mean"; }
   virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
@@ -104,7 +104,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(
     PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
 
 REGISTER_OPERATOR(reduce_mean,
-                  ops::ReduceOp,
+                  ops::ReduceBaseOp,
                   __reduce_meanMaker__,
                   ops::ReduceMeanOpGradMaker<paddle::framework::OpDesc>,
                   ops::ReduceMeanOpGradMaker<paddle::imperative::OpBase>,
diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_op.cc b/paddle/fluid/operators/reduce_ops/reduce_min_op.cc
index e9fafce1332..d755d460722 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_min_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_min_op.cc
@@ -19,7 +19,7 @@
 
 namespace ops = paddle::operators;
 
-class ReduceMinOpMaker : public ops::ReduceOpMaker {
+class ReduceMinOpMaker : public ops::ReduceBaseOpMaker {
  protected:
   virtual std::string GetName() const { return "reduce_min"; }
   virtual std::string GetOpType() const { return "Reduce reduce_min"; }
@@ -32,7 +32,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(
 
 REGISTER_OPERATOR(
     reduce_min,
-    ops::ReduceOp,
+    ops::ReduceBaseOp,
    ReduceMinOpMaker,
     paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
     paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
index a62bac88ca3..21646d08db3 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
@@ -29,7 +29,7 @@ namespace operators {
 
 template <typename Tx,
           typename Ty,
-          template <typename> class ReduceOp,
+          template <typename> class ReduceBaseOp,
           typename TransformOp>
 void TensorReduceImpl(const phi::GPUContext& dev_ctx,
                       const phi::DenseTensor& x,
@@ -40,7 +40,7 @@ void TensorReduceImpl(const phi::GPUContext& dev_ctx,
                       bool is_mean = false) {
   y->mutable_data<Ty>(x.place());
 
-  phi::funcs::ReduceKernel<Tx, Ty, ReduceOp, TransformOp>(
+  phi::funcs::ReduceKernel<Tx, Ty, ReduceBaseOp, TransformOp>(
       static_cast<const phi::GPUContext&>(dev_ctx),
       x,
       y,
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index 3282954bb79..f9b79010263 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -88,7 +88,7 @@ static inline std::vector<int> GetReduceDim(const std::vector<int>& dims,
     PADDLE_ENFORCE_LT(e,
                       dim_size,
                       paddle::platform::errors::InvalidArgument(
-                          "ReduceOp: invalid axis, when x_dims is %d, "
+                          "ReduceBaseOp: invalid axis, when x_dims is %d, "
                           "axis[i] should be less than x_dims, but got %d.",
                           dim_size,
                           e));
@@ -499,20 +499,20 @@ class ReduceGradKernel : public framework::OpKernel<T> {
   }
 };
 
-class ReduceOp : public framework::OperatorWithKernel {
+class ReduceBaseOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceOp");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ReduceOp");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceBaseOp");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ReduceBaseOp");
     auto x_dims = ctx->GetInputDim("X");
     auto x_rank = x_dims.size();
     auto dims = ctx->Attrs().Get<std::vector<int>>("dim");
     PADDLE_ENFORCE_GT(dims.size(),
                       0,
                       platform::errors::InvalidArgument(
-                          "The input dim dimensions of ReduceOp "
+                          "The input dim dimensions of ReduceBaseOp "
                           "should be greater than 0. But received the dim "
                           "dimensions of Reduce = %d.",
                           dims.size()));
@@ -636,9 +636,9 @@ class ReduceOp : public framework::OperatorWithKernel {
   }
 };
 
-class ReduceOpUseInputPlace : public ReduceOp {
+class ReduceOpUseInputPlace : public ReduceBaseOp {
  public:
-  using ReduceOp::ReduceOp;
+  using ReduceBaseOp::ReduceBaseOp;
 
  protected:
   phi::KernelKey GetExpectedKernelType(
@@ -655,11 +655,11 @@ class ReduceGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceOp");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ReduceBaseOp");
     OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                    "Input",
                    "Out@GRAD",
-                   "ReduceOp");
+                   "ReduceBaseOp");
     auto x_dims = ctx->GetInputDim("X");
     auto x_rank = x_dims.size();
     // TODO(dev): We should delete Infershape and migrate it into
@@ -710,7 +710,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
   }
 };
 
-class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
+class ReduceBaseOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() final {
     AddInput("X",
@@ -763,7 +763,7 @@ If reduce_all is true, just reduce along all dimensions and output a scalar.
#if defined(__HIPCC__) || defined(__NVCC__) || defined(__xpu__) template - class ReduceOp, + class ReduceBaseOp, template class TransformOp> class ReduceCudaKernel : public framework::OpKernel { @@ -790,7 +790,7 @@ class ReduceCudaKernel : public framework::OpKernel { std::vector dims_int64{dims.begin(), dims.end()}; - phi::Reduce( + phi::Reduce( dev_ctx, *input, reduce_all, dims_int64, false, pt_out_dtype, output); } }; @@ -869,14 +869,14 @@ struct DivideFunctor { namespace ops = paddle::operators; #define REGISTER_REDUCE_OP(op_name) \ - class __##op_name##Maker__ : public ops::ReduceOpMaker { \ + class __##op_name##Maker__ : public ops::ReduceBaseOpMaker { \ protected: \ virtual std::string GetName() const { return #op_name; } \ virtual std::string GetOpType() const { return "Reduce " #op_name; } \ }; \ REGISTER_OPERATOR( \ op_name, \ - ops::ReduceOp, \ + ops::ReduceBaseOp, \ __##op_name##Maker__, \ paddle::framework::DefaultGradOpMaker, \ paddle::framework::DefaultGradOpMaker, \ paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc index 0b78e807464..1ba1a1aa628 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc @@ -29,7 +29,7 @@ class OpBase; namespace ops = paddle::operators; -class ReduceProdOpMaker : public ops::ReduceOpMaker { +class ReduceProdOpMaker : public ops::ReduceBaseOpMaker { protected: virtual std::string GetName() const { return "reduce_prod"; } virtual std::string GetOpType() const { return "Reduce reduce_prod"; } @@ -42,7 +42,7 @@ DECLARE_INFER_SHAPE_FUNCTOR( REGISTER_OPERATOR( reduce_prod, - ops::ReduceOp, + ops::ReduceBaseOp, ReduceProdOpMaker, paddle::framework::DefaultGradOpMaker, paddle::framework::DefaultGradOpMaker, diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index 7098432c80e..645fa3bde6b 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -128,7 +128,7 @@ class ReduceSumVarTypeInference : public paddle::framework::VarTypeInference { } // namespace operators } // namespace paddle -class ReduceSumOpMaker : public ops::ReduceOpMaker { +class ReduceSumOpMaker : public ops::ReduceBaseOpMaker { protected: virtual std::string GetName() const { return "reduce_sum"; } virtual std::string GetOpType() const { return "Reduce reduce_sum"; } @@ -139,7 +139,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(reduce_sum, PD_INFER_META(phi::SumRawInferMeta)); REGISTER_OPERATOR(reduce_sum, - ops::ReduceOp, + ops::ReduceBaseOp, ReduceSumOpMaker, ops::ReduceSumVarTypeInference, ops::ReduceSumOpGradMaker, diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 30b8eed0fe7..03e2e3aef0d 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -33,16 +33,6 @@ data_type : x backward : addmm_grad -- op : all_gather - args : (Tensor X, int ring_id = 0, int nranks=0) - output : Tensor(Out) - infer_meta : - func : AllGatherInferMeta - param: [X, nranks] - kernel : - func : all_gather - param: [X, nranks] - - op : allclose args : (Tensor x, Tensor y, Scalar rtol="1e-5", Scalar atol="1e-8", bool equal_nan=false) output : Tensor(out) @@ -213,16 +203,6 @@ func : bmm backward : bmm_grad -- op : broadcast - args : (Tensor X, int ring_id = 0, int root = 0) - output : Tensor(Out) - infer_meta : - func : BroadcastBaseInferMeta - param: [X] - kernel : - func 
: broadcast - param: [X, root] - - op : broadcast_tensors args: (Tensor[] input) output: Tensor[]{input.size()} diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index a315d9c8086..f04a2961fbf 100644 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -6,6 +6,36 @@ kernel : func : all +- op : all_gather + args : (Tensor x, int ring_id = 0, int nranks=0) + output : Tensor(out) + infer_meta : + func : AllGatherInferMeta + param: [x, nranks] + kernel : + func : all_gather + param: [x, nranks] + +- op : all_reduce + args : (Tensor x, int ring_id = 0, int reduce_type = 0) + output : Tensor(out) + infer_meta : + func : AllReduceInferMeta + param: [x] + kernel : + func : all_reduce + param: [x, reduce_type] + +- op : broadcast + args : (Tensor x, int ring_id = 0, int root = 0) + output : Tensor(out) + infer_meta : + func : DistBroadcastInferMeta + param: [x] + kernel : + func : broadcast + param: [x, root] + - op : embedding_with_eltwise_add_xpu args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx) output: Tensor @@ -136,6 +166,16 @@ backend : x force_backend : force_cpu +- op : reduce + args : (Tensor x, int ring_id = 0, int root_id = 0, int reduce_type = 0) + output : Tensor(out) + infer_meta : + func : DistReduceInferMeta + param: [x] + kernel : + func : reduce + param: [x, root_id, reduce_type] + - op : share_buffer args : (Tensor[] x, bool[] share_dims_and_dtype={}) output : Tensor[](out){x.size()}, Tensor[](xout){x.size()} diff --git a/paddle/phi/core/distributed/gloo_comm_context.cc b/paddle/phi/core/distributed/gloo_comm_context.cc index b5a50a39b78..7c956185ef4 100644 --- a/paddle/phi/core/distributed/gloo_comm_context.cc +++ b/paddle/phi/core/distributed/gloo_comm_context.cc @@ -13,15 +13,17 @@ // limitations under the License. 
#include "paddle/phi/core/distributed/gloo_comm_context.h" +#include "paddle/phi/core/distributed/gloo_utils.h" #include +#include #include +#include #include #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/distributed/check/static_check.h" -#include "paddle/phi/core/distributed/gloo_utils.h" #include "paddle/phi/core/enforce.h" namespace phi { @@ -67,5 +69,30 @@ void GlooCommContext::AllGather(phi::DenseTensor* out_tensor, GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor); gloo::allgather(opts); } + +void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int reduce_type) { + gloo::AllreduceOptions opts(gloo_context_); + const auto& dtype = in_tensor.dtype(); + GENERATE_FUNC(dtype, SetInput, &opts, in_tensor); + GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor); + GENERATE_FUNC(dtype, SetReduceFunc, &opts, reduce_type); + gloo::allreduce(opts); +} + +void GlooCommContext::Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int reduce_type, + int root) { + gloo::ReduceOptions opts(gloo_context_); + opts.setRoot(root); + const auto& dtype = in_tensor.dtype(); + GENERATE_FUNC(dtype, SetInput, &opts, in_tensor); + GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor); + GENERATE_FUNC(dtype, SetReduceFunc, &opts, reduce_type); + gloo::reduce(opts); +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/gloo_comm_context.h b/paddle/phi/core/distributed/gloo_comm_context.h index c1d550867cf..b8db0431c25 100644 --- a/paddle/phi/core/distributed/gloo_comm_context.h +++ b/paddle/phi/core/distributed/gloo_comm_context.h @@ -36,6 +36,13 @@ class GlooCommContext final : public CommContext { void Broadcast(phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, int root); + void AllReduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int reduce_type); + void Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int reduce_type, + int root); void AllGather(phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor); diff --git a/paddle/phi/core/distributed/gloo_utils.h b/paddle/phi/core/distributed/gloo_utils.h index 1efdd40efb8..80e5fca49af 100644 --- a/paddle/phi/core/distributed/gloo_utils.h +++ b/paddle/phi/core/distributed/gloo_utils.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include @@ -23,6 +24,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/reduce_helper.h" namespace phi { namespace distributed { @@ -99,6 +101,36 @@ void SetInput(P* opts, const phi::DenseTensor& tensor) { tensor.numel()); } +template +void SetReduceFunc(P* opts, int reduce_type) { + // gloo only support mutable data input + switch (reduce_type) { + case kRedSum: + opts->setReduceFunction( + static_cast( + &gloo::sum)); + break; + case kRedMax: + opts->setReduceFunction( + static_cast( + &gloo::max)); + break; + case kRedMin: + opts->setReduceFunction( + static_cast( + &gloo::min)); + break; + case kRedProd: + opts->setReduceFunction( + static_cast( + &gloo::product)); + break; + default: + PADDLE_THROW( + errors::InvalidArgument("Invalid reduce type: %d.", reduce_type)); + } +} + // env preparation std::shared_ptr CreateGlooDevice(); diff --git a/paddle/phi/core/distributed/nccl_comm_context.cc b/paddle/phi/core/distributed/nccl_comm_context.cc index 8fd3337c2cf..7f3491ea53b 100644 --- 
a/paddle/phi/core/distributed/nccl_comm_context.cc +++ b/paddle/phi/core/distributed/nccl_comm_context.cc @@ -65,5 +65,35 @@ void NCCLCommContext::AllGather(phi::DenseTensor* out_tensor, stream)); } +void NCCLCommContext::AllReduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + ncclRedOp_t reduce_type, + gpuStream_t stream) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::ncclAllReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + reduce_type, + nccl_comm_, + stream)); +} + +void NCCLCommContext::Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + ncclRedOp_t reduce_type, + int root, + gpuStream_t stream) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::ncclReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + reduce_type, + root, + nccl_comm_, + stream)); +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/nccl_comm_context.h b/paddle/phi/core/distributed/nccl_comm_context.h index 53091c96e0b..f260536c02e 100644 --- a/paddle/phi/core/distributed/nccl_comm_context.h +++ b/paddle/phi/core/distributed/nccl_comm_context.h @@ -40,6 +40,17 @@ class NCCLCommContext final : public CommContext { const phi::DenseTensor& in_tensor, gpuStream_t stream); + void AllReduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + ncclRedOp_t reduce_type, + gpuStream_t stream); + + void Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + ncclRedOp_t reduce_type, + int root, + gpuStream_t stream); + private: DISABLE_COPY_AND_ASSIGN(NCCLCommContext); diff --git a/paddle/phi/core/distributed/reduce_helper.h b/paddle/phi/core/distributed/reduce_helper.h new file mode 100644 index 00000000000..19da924fb0d --- /dev/null +++ b/paddle/phi/core/distributed/reduce_helper.h @@ -0,0 +1,21 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +namespace phi { +namespace distributed { +enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd }; +} +} // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 66628e6f40a..497b35fa0ef 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -128,6 +128,11 @@ void AllGatherInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) { out->set_dims(dim); } +void AllReduceInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dtype(x.dtype()); + out->set_dims(x.dims()); +} + void ArgMinMaxInferMeta(const MetaTensor& x, const Scalar& axis, bool keepdims, @@ -368,11 +373,6 @@ void BatchSizeLikeInferMeta(const MetaTensor& x, out->set_dims(output_dim); } -void BroadcastBaseInferMeta(const MetaTensor& x, MetaTensor* out) { - out->set_dtype(x.dtype()); - out->set_dims(x.dims()); -} - void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(out_dtype); @@ -772,6 +772,16 @@ void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out) { out->set_dtype(alpha.dtype()); } +void DistReduceInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dtype(x.dtype()); + out->set_dims(x.dims()); +} + +void DistBroadcastInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dtype(x.dtype()); + out->set_dims(x.dims()); +} + void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) { phi::DDim x_dims = x.dims(); int rank = x_dims.size(); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 013eddcf052..bb7149f9fc6 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -41,6 +41,8 @@ void AffineGridInferMeta(const MetaTensor& input, void AllGatherInferMeta(const MetaTensor& x, int nranks, MetaTensor* out); +void AllReduceInferMeta(const MetaTensor& x, MetaTensor* out); + void ArgMinMaxInferMeta(const MetaTensor& x, const Scalar& axis, bool keepdims, @@ -65,8 +67,6 @@ void BatchSizeLikeInferMeta(const MetaTensor& x, int out_batch_size_dim, MetaTensor* out); -void BroadcastBaseInferMeta(const MetaTensor& x, MetaTensor* out); - void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); void ChannelShuffleInferMeta(const MetaTensor& x, @@ -128,6 +128,10 @@ void DiagonalInferMeta( void DirichletInferMeta(const MetaTensor& alpha, MetaTensor* out); +void DistBroadcastInferMeta(const MetaTensor& x, MetaTensor* out); + +void DistReduceInferMeta(const MetaTensor& x, MetaTensor* out); + void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v); void EighInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/all_reduce_kernel.h b/paddle/phi/kernels/all_reduce_kernel.h new file mode 100644 index 00000000000..e750af7a620 --- /dev/null +++ b/paddle/phi/kernels/all_reduce_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
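Note: reduce_helper.h above pins down the integer encoding (kRedSum=0, kRedMax=1, kRedMin=2, kRedProd=3) that both the Gloo and NCCL paths dispatch on, and that the static ops carry as the plain int attr reduce_type. The tests below pass str(dist.ReduceOp.SUM) and convert it back with int(), which assumes the Python-side enum uses the same values; a sketch of that assumed correspondence:

    import paddle.distributed as dist

    # Assumed to line up with phi::distributed::ReduceType:
    assert int(dist.ReduceOp.SUM) == 0   # kRedSum
    assert int(dist.ReduceOp.MAX) == 1   # kRedMax
    assert int(dist.ReduceOp.MIN) == 2   # kRedMin
    assert int(dist.ReduceOp.PROD) == 3  # kRedProd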
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/distributed/reduce_helper.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void AllReduceKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     int reduce_type,
+                     DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/all_reduce_kernel.cc b/paddle/phi/kernels/cpu/all_reduce_kernel.cc
new file mode 100644
index 00000000000..920f900492e
--- /dev/null
+++ b/paddle/phi/kernels/cpu/all_reduce_kernel.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/all_reduce_kernel.h"
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#if defined(PADDLE_WITH_GLOO)
+#include "paddle/phi/core/distributed/gloo_comm_context.h"
+#endif
+
+namespace phi {
+
+template <typename T, typename Context>
+void AllReduceKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     int reduce_type,
+                     DenseTensor* out) {
+#if defined(PADDLE_WITH_GLOO)
+  out->Resize(x.dims());
+  dev_ctx.template Alloc<T>(out);
+
+  auto comm_ctx =
+      static_cast<distributed::GlooCommContext*>(dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(
+      comm_ctx,
+      nullptr,
+      errors::Unavailable("GlooCommContext is nullptr, collective op should "
+                          "have ring_id attr."));
+  comm_ctx->AllReduce(out, x, reduce_type);
+
+#else
+  PADDLE_THROW(
+      errors::PreconditionNotMet("PaddlePaddle should be compiled with GLOO."));
+#endif
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(all_reduce,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::AllReduceKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/cpu/reduce_kernel.cc b/paddle/phi/kernels/cpu/reduce_kernel.cc
new file mode 100644
index 00000000000..a368e85bff9
--- /dev/null
+++ b/paddle/phi/kernels/cpu/reduce_kernel.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reduce_kernel.h"
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#if defined(PADDLE_WITH_GLOO)
+#include "paddle/phi/core/distributed/gloo_comm_context.h"
+#endif
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int root,
+                  int reduce_type,
+                  DenseTensor* out) {
+#if defined(PADDLE_WITH_GLOO)
+  out->Resize(x.dims());
+  dev_ctx.template Alloc<T>(out);
+
+  auto comm_ctx =
+      static_cast<distributed::GlooCommContext*>(dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(
+      comm_ctx,
+      nullptr,
+      errors::Unavailable("GlooCommContext is nullptr, collective op should "
+                          "have ring_id attr."));
+  comm_ctx->Reduce(out, x, reduce_type, root);
+#else
+  PADDLE_THROW(
+      errors::PreconditionNotMet("PaddlePaddle should be compiled with GLOO."));
+#endif
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(reduce,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ReduceKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/all_reduce_kernel.cu b/paddle/phi/kernels/gpu/all_reduce_kernel.cu
new file mode 100644
index 00000000000..32b6dba2054
--- /dev/null
+++ b/paddle/phi/kernels/gpu/all_reduce_kernel.cu
@@ -0,0 +1,97 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/all_reduce_kernel.h"
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#endif
+
+namespace phi {
+
+template <typename T, typename Context>
+void AllReduceKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     int reduce_type,
+                     DenseTensor* out) {
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  out->Resize(x.dims());
+  dev_ctx.template Alloc<T>(out);
+
+  auto comm_ctx =
+      static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(
+      comm_ctx,
+      nullptr,
+      errors::Unavailable("NCCLCommContext is nullptr, collective op should "
+                          "have ring_id attr."));
+  gpuStream_t stream = dev_ctx.stream();
+  PADDLE_ENFORCE_NOT_NULL(stream,
+                          errors::NotFound("Should initialize NCCL first."));
+
+  ncclRedOp_t red_type = ncclSum;
+  switch (reduce_type) {
+    case distributed::kRedSum:
+      red_type = ncclSum;
+      break;
+    case distributed::kRedMax:
+      red_type = ncclMax;
+      break;
+    case distributed::kRedMin:
+      red_type = ncclMin;
+      break;
+    case distributed::kRedProd:
+      red_type = ncclProd;
+      break;
+  }
+  comm_ctx->AllReduce(out, x, red_type, stream);
+#else
+  PADDLE_THROW(
+      errors::PreconditionNotMet("PaddlePaddle should be compiled with GPU."));
+#endif
+}
+
+}  // namespace phi
+
+#if NCCL_VERSION_CODE >= 21000
+PD_REGISTER_KERNEL(all_reduce,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AllReduceKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
+#else
+PD_REGISTER_KERNEL(all_reduce,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AllReduceKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/gpu/reduce_kernel.cu b/paddle/phi/kernels/gpu/reduce_kernel.cu
new file mode 100644
index 00000000000..87b5e61bda7
--- /dev/null
+++ b/paddle/phi/kernels/gpu/reduce_kernel.cu
@@ -0,0 +1,98 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reduce_kernel.h"
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/phi/core/distributed/nccl_comm_context.h"
+#endif
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReduceKernel(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  int root,
+                  int reduce_type,
+                  DenseTensor* out) {
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  out->Resize(x.dims());
+  dev_ctx.template Alloc<T>(out);
+
+  auto comm_ctx =
+      static_cast<distributed::NCCLCommContext*>(dev_ctx.GetCommContext());
+  PADDLE_ENFORCE_NE(
+      comm_ctx,
+      nullptr,
+      errors::Unavailable("NCCLCommContext is nullptr, collective op should "
+                          "have ring_id attr."));
+  gpuStream_t stream = dev_ctx.stream();
+  PADDLE_ENFORCE_NOT_NULL(stream,
+                          errors::NotFound("Should initialize NCCL first."));
+
+  ncclRedOp_t red_type = ncclSum;
+  switch (reduce_type) {
+    case distributed::kRedSum:
+      red_type = ncclSum;
+      break;
+    case distributed::kRedMax:
+      red_type = ncclMax;
+      break;
+    case distributed::kRedMin:
+      red_type = ncclMin;
+      break;
+    case distributed::kRedProd:
+      red_type = ncclProd;
+      break;
+  }
+  comm_ctx->Reduce(out, x, red_type, root, stream);
+#else
+  PADDLE_THROW(
+      errors::PreconditionNotMet("PaddlePaddle should be compiled with GPU."));
+#endif
+}
+
+}  // namespace phi
+
+#if NCCL_VERSION_CODE >= 21000
+PD_REGISTER_KERNEL(reduce,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReduceKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
+#else
+PD_REGISTER_KERNEL(reduce,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReduceKernel,
+                   float,
+                   double,
+                   int,
+                   bool,
+                   int8_t,
+                   uint8_t,
+                   int64_t,
+                   phi::dtype::float16) {}
+#endif
diff --git a/paddle/phi/kernels/reduce_kernel.h b/paddle/phi/kernels/reduce_kernel.h
new file mode 100644
index 00000000000..57897a7b02b
--- /dev/null
+++ b/paddle/phi/kernels/reduce_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
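Note: both GPU kernels above translate the generic reduce_type into ncclRedOp_t and quietly fall back to ncclSum for an unrecognized value (unlike the Gloo path, whose SetReduceFunc throws), and they register bfloat16 only when NCCL_VERSION_CODE >= 21000; that gate matches the self._nccl_version >= 2100 checks in the tests below. For orientation, a hedged sketch of the dygraph calls that these static-graph kernels mirror:

    import paddle
    import paddle.distributed as dist

    # Launch with e.g.: python -m paddle.distributed.launch --gpus 0,1 demo.py
    dist.init_parallel_env()
    t = paddle.ones([10, 1000], dtype='float32')
    dist.all_reduce(t, op=dist.ReduceOp.SUM)     # every rank receives the elementwise sum
    dist.reduce(t, dst=0, op=dist.ReduceOp.SUM)  # only rank 0 receives the result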
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/reduce_helper.h" + +namespace phi { + +template +void ReduceKernel(const Context& dev_ctx, + const DenseTensor& x, + int root_id, + int reduce_type, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt index a29a4ed5f97..48ddb92213a 100644 --- a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt @@ -71,7 +71,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_collective_allreduce_api MODULES test_collective_allreduce_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_allreduce_api - PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "250" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -195,7 +195,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_collective_reduce_api MODULES test_collective_reduce_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_reduce_api - PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "230" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) bash_test_modules( diff --git a/python/paddle/fluid/tests/unittests/collective/collective_allgather_api.py b/python/paddle/fluid/tests/unittests/collective/collective_allgather_api.py index 2ca80f5b621..40ee69227b1 100644 --- a/python/paddle/fluid/tests/unittests/collective/collective_allgather_api.py +++ b/python/paddle/fluid/tests/unittests/collective/collective_allgather_api.py @@ -67,8 +67,8 @@ def all_gather_new(tensor_list, tensor, group=None): nranks = dist.get_world_size() helper.append_op( type=op_type, - inputs={'X': [tensor]}, - outputs={'Out': [out]}, + inputs={'x': [tensor]}, + outputs={'out': [out]}, attrs={ 'ring_id': ring_id, 'nranks': nranks, diff --git a/python/paddle/fluid/tests/unittests/collective/collective_allreduce_api.py b/python/paddle/fluid/tests/unittests/collective/collective_allreduce_api.py index 291ad384f3e..30a0f951b7b 100644 --- a/python/paddle/fluid/tests/unittests/collective/collective_allreduce_api.py +++ b/python/paddle/fluid/tests/unittests/collective/collective_allreduce_api.py @@ -15,11 +15,54 @@ from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main import paddle +import paddle.distributed as dist import paddle.fluid as fluid +import paddle.fluid.data_feeder as data_feeder +import paddle.framework as framework paddle.enable_static() +def all_reduce_new(tensor, reduce_type=str(dist.ReduceOp.SUM), group=None): + op_type = 'all_reduce' + data_feeder.check_variable_and_dtype( + tensor, + 'tensor', + [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'int8', + 'uint8', + 'bool', + ], + op_type, + ) + + ring_id = 0 if group is None else group.id + + if not isinstance(ring_id, int): + raise ValueError("The type of 'ring_id' for all_reduce should be int.") + + # TODO: Support task and use task.wait in static graph mode + # Use use_calc_stream rather than sync_op + helper = framework.LayerHelper(op_type, **locals()) + if not reduce_type.isdigit(): + raise ValueError( + "The type of 'reduce_type' for all_reduce should be int." 
+ ) + helper.append_op( + type=op_type, + inputs={'x': [tensor]}, + outputs={'out': [tensor]}, + attrs={'ring_id': ring_id, 'reduce_type': int(reduce_type)}, + ) + + return None + + class TestCollectiveAllreduceAPI(TestCollectiveAPIRunnerBase): def __init__(self): self.global_ring_id = 0 @@ -27,11 +70,26 @@ class TestCollectiveAllreduceAPI(TestCollectiveAPIRunnerBase): def get_model(self, main_prog, startup_program, rank): with fluid.program_guard(main_prog, startup_program): tindata = paddle.static.data( - name="tindata", shape=[-1, 10, 1000], dtype='float32' + name="tindata", shape=[10, 1000], dtype='float32' ) paddle.distributed.all_reduce(tindata) return [tindata] + def get_model_new( + self, + main_prog, + startup_program, + rank, + dtype='float32', + reduce_type=str(dist.ReduceOp.SUM), + ): + with fluid.program_guard(main_prog, startup_program): + tindata = paddle.static.data( + name="tindata", shape=[10, 1000], dtype=dtype + ) + all_reduce_new(tindata, reduce_type) + return [tindata] + if __name__ == "__main__": runtime_main(TestCollectiveAllreduceAPI, "allreduce") diff --git a/python/paddle/fluid/tests/unittests/collective/collective_broadcast_api.py b/python/paddle/fluid/tests/unittests/collective/collective_broadcast_api.py index 7cb93a58d71..86df7933368 100644 --- a/python/paddle/fluid/tests/unittests/collective/collective_broadcast_api.py +++ b/python/paddle/fluid/tests/unittests/collective/collective_broadcast_api.py @@ -45,8 +45,8 @@ def broadcast_new(tensor, src, group=None, sync_op=True): helper.append_op( type=op_type, - inputs={'X': [tensor]}, - outputs={'Out': [tensor]}, + inputs={'x': [tensor]}, + outputs={'out': [tensor]}, attrs={ 'root': src, 'ring_id': ring_id, @@ -68,12 +68,7 @@ class TestCollectiveBroadcastAPI(TestCollectiveAPIRunnerBase): return [tindata] def get_model_new( - self, - main_prog, - startup_program, - rank, - dtype='float32', - reduce_type=None, + self, main_prog, startup_program, rank, dtype=None, reduce_type=None ): with fluid.program_guard(main_prog, startup_program): tindata = paddle.static.data( diff --git a/python/paddle/fluid/tests/unittests/collective/collective_reduce_api.py b/python/paddle/fluid/tests/unittests/collective/collective_reduce_api.py index 6f033c7d1fd..50da61bfcb0 100644 --- a/python/paddle/fluid/tests/unittests/collective/collective_reduce_api.py +++ b/python/paddle/fluid/tests/unittests/collective/collective_reduce_api.py @@ -15,11 +15,50 @@ from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main import paddle +import paddle.distributed as dist import paddle.fluid as fluid +import paddle.fluid.data_feeder as data_feeder +import paddle.framework as framework paddle.enable_static() +def reduce_new(tensor, dst, reduce_type=str(dist.ReduceOp.SUM), group=None): + op_type = "reduce" + data_feeder.check_variable_and_dtype( + tensor, + 'tensor', + [ + 'float16', + 'float32', + 'float64', + 'int32', + 'int64', + 'int8', + 'uint8', + 'bool', + ], + op_type, + ) + + ring_id = 0 if group is None else group.id + + helper = framework.LayerHelper(op_type, **locals()) + if not reduce_type.isdigit(): + raise ValueError("The type of 'reduce_type' for reduce should be int.") + helper.append_op( + type=op_type, + inputs={'x': [tensor]}, + outputs={'out': [tensor]}, + attrs={ + 'ring_id': ring_id, + 'root_id': dst, + 'reduce_type': int(reduce_type), + }, + ) + return None + + class TestCollectiveReduceAPI(TestCollectiveAPIRunnerBase): def __init__(self): self.global_ring_id = 0 @@ -33,6 +72,22 @@ class 
TestCollectiveReduceAPI(TestCollectiveAPIRunnerBase): paddle.distributed.reduce(tindata, dst=0) return [tindata] + def get_model_new( + self, + main_prog, + startup_program, + rank, + dtype='float32', + reduce_type=str(dist.ReduceOp.SUM), + ): + with fluid.program_guard(main_prog, startup_program): + tindata = paddle.static.data( + name="tindata", shape=[10, 1000], dtype=dtype + ) + tindata.desc.set_need_check_feed(False) + reduce_new(tindata, dst=0, reduce_type=reduce_type) + return [tindata] + if __name__ == "__main__": runtime_main(TestCollectiveReduceAPI, "reduce") diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py index 82bdfdbc92f..97850b5552b 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_allreduce_api.py @@ -19,6 +19,7 @@ from test_collective_api_base import TestDistBase import paddle paddle.enable_static() +import paddle.distributed as dist class TestCollectiveAllreduceAPI(TestDistBase): @@ -31,6 +32,33 @@ class TestCollectiveAllreduceAPI(TestDistBase): "collective_allreduce_api.py", "allreduce", "nccl" ) + def test_allreduce_nccl_with_comm_context(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + red_types_to_test = [ + dist.ReduceOp.SUM, + ] + if self._nccl_version >= 2100: + dtypes_to_test.append("bfloat16") + for dtype in dtypes_to_test: + for red_type in red_types_to_test: + self.check_with_place( + "collective_allreduce_api.py", + "allreduce", + "nccl", + dtype=dtype, + reduce_type=red_type, + need_envs={"USE_COMM_CONTEXT": "1"}, + ) + def test_allreduce_bkcl(self): if paddle.fluid.core.is_compiled_with_xpu(): self.check_with_place( @@ -42,6 +70,32 @@ class TestCollectiveAllreduceAPI(TestDistBase): "collective_allreduce_api.py", "allreduce", "gloo", "2" ) + def test_allreduce_gloo_with_comm_context(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + red_types_to_test = [ + dist.ReduceOp.SUM, + ] + for dtype in dtypes_to_test: + for red_type in red_types_to_test: + self.check_with_place( + "collective_allreduce_api.py", + "allreduce", + "gloo", + "2", + dtype=dtype, + reduce_type=red_type, + need_envs={"USE_COMM_CONTEXT": "1"}, + ) + def test_allreduce_nccl_dygraph(self): dtypes_to_test = [ "float16", diff --git a/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py b/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py index 301e57431e4..70ae163054a 100644 --- a/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py +++ b/python/paddle/fluid/tests/unittests/collective/test_collective_reduce_api.py @@ -19,6 +19,7 @@ from test_collective_api_base import TestDistBase import paddle paddle.enable_static() +import paddle.distributed as dist class TestCollectiveReduceAPI(TestDistBase): @@ -29,6 +30,34 @@ class TestCollectiveReduceAPI(TestDistBase): if paddle.fluid.core.is_compiled_with_cuda(): self.check_with_place("collective_reduce_api.py", "reduce", "nccl") + def test_reduce_nccl_with_comm_context(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + red_types_to_test = [ + dist.ReduceOp.SUM, + ] + if self._nccl_version >= 2100: + 
dtypes_to_test.append("bfloat16") + for dtype in dtypes_to_test: + if paddle.fluid.core.is_compiled_with_cuda(): + for red_type in red_types_to_test: + self.check_with_place( + "collective_reduce_api.py", + "reduce", + "nccl", + dtype=dtype, + reduce_type=red_type, + need_envs={"USE_COMM_CONTEXT": "1"}, + ) + def test_reduce_bkcl(self): if paddle.fluid.core.is_compiled_with_xpu(): self.check_with_place("collective_reduce_api.py", "reduce", "bkcl") @@ -36,6 +65,32 @@ class TestCollectiveReduceAPI(TestDistBase): def test_reduce_gloo(self): self.check_with_place("collective_reduce_api.py", "reduce", "gloo", "1") + def test_reduce_gloo_with_comm_context(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + red_types_to_test = [ + dist.ReduceOp.SUM, + ] + for dtype in dtypes_to_test: + for red_type in red_types_to_test: + self.check_with_place( + "collective_reduce_api.py", + "reduce", + "gloo", + "1", + dtype=dtype, + reduce_type=red_type, + need_envs={"USE_COMM_CONTEXT": "1"}, + ) + def test_reduce_nccl_dygraph(self): dtypes_to_test = [ "float16", diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py index 0454242a4ac..b8379843c30 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py @@ -133,7 +133,7 @@ class TestCollectiveAPIRunnerBase: train_prog, startup_prog, rank, - dtype=args["dtype"], + dtype=args['dtype'], reduce_type=args['reduce_type'], ) if args["use_comm_context"] @@ -373,7 +373,6 @@ class TestDistBase(unittest.TestCase): need_result = np.amin([input1, input2], 0) elif reduce_type == dist.ReduceOp.PROD: need_result = np.prod([input1, input2], 0) - need_result = input1 + input2 # bfloat16 precision loss comes from truncating the last 16 bits of float32, # which sums (\sum_{i=-23}^{-8}2^{i}) to about 0.0078 if dtype == "bfloat16": -- GitLab
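Note: for reference, the static-graph calling convention established by the test helpers in this patch (all_reduce_new and reduce_new), condensed into two sketches. The op names, input/output keys, and attrs are taken from the patch; the wrapper names here are illustrative, not part of the patch:

    import paddle
    import paddle.distributed as dist
    import paddle.framework as framework

    paddle.enable_static()

    def static_all_reduce(tensor, reduce_type=int(dist.ReduceOp.SUM), ring_id=0):
        # In-place all_reduce through the new phi 'all_reduce' op.
        helper = framework.LayerHelper('all_reduce', **locals())
        helper.append_op(
            type='all_reduce',
            inputs={'x': [tensor]},
            outputs={'out': [tensor]},
            attrs={'ring_id': ring_id, 'reduce_type': reduce_type},
        )

    def static_reduce(tensor, dst, reduce_type=int(dist.ReduceOp.SUM), ring_id=0):
        # In-place reduce through the new phi 'reduce' op; dst maps to root_id.
        helper = framework.LayerHelper('reduce', **locals())
        helper.append_op(
            type='reduce',
            inputs={'x': [tensor]},
            outputs={'out': [tensor]},
            attrs={'ring_id': ring_id, 'root_id': dst, 'reduce_type': reduce_type},
        )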