From a873fa84ceca411a5a776ff8ae303f8be24df95a Mon Sep 17 00:00:00 2001 From: Yi Liu Date: Tue, 2 Jul 2019 15:31:21 +0800 Subject: [PATCH] supports collective training with programs (#18392) 1. Since the allreduce op has four reduce types, we split it into four ops, one per reduce type. 2. We also refined the collective op code, e.g. we separated each collective op kernel into a CPUKernel and a CUDAKernel, and removed the device-specific DeviceContext template parameter, since the target DeviceContext is already known. 3. We removed the newly added Collective op role to reduce the complexity of program and graph analysis. --- paddle/fluid/framework/op_proto_maker.cc | 1 - paddle/fluid/framework/op_proto_maker.h | 3 - .../operators/collective/c_allgather_op.cc | 21 ++- .../operators/collective/c_allgather_op.cu.cc | 64 +++++++- .../operators/collective/c_allgather_op.h | 47 +----- .../collective/c_allreduce_max_op.cc | 39 +++++ ...duce_op.cu.cc => c_allreduce_max_op.cu.cc} | 10 +- .../collective/c_allreduce_min_op.cc | 39 +++++ .../collective/c_allreduce_min_op.cu.cc | 25 +++ .../operators/collective/c_allreduce_op.cc | 83 ---------- .../operators/collective/c_allreduce_op.h | 107 ++++++++---- .../collective/c_allreduce_prod_op.cc | 39 +++++ .../collective/c_allreduce_prod_op.cu.cc | 25 +++ .../collective/c_allreduce_sum_op.cc | 54 +++++++ .../collective/c_allreduce_sum_op.cu.cc | 25 +++ .../operators/collective/c_broadcast_op.cc | 18 +-- .../operators/collective/c_broadcast_op.cu.cc | 74 ++++++++- .../operators/collective/c_broadcast_op.h | 61 +------ .../operators/collective/c_comm_init_op.cc | 1 + .../operators/collective/c_gen_nccl_id_op.cc | 4 + .../collective/c_reducescatter_op.cc | 18 +-- .../collective/c_reducescatter_op.cu.cc | 62 ++++++- .../operators/collective/c_reducescatter_op.h | 45 +----- .../collective/c_sync_calc_stream_op.cc | 13 +- .../collective/c_sync_comm_stream_op.cc | 12 +- paddle/fluid/platform/collective_helper.cc | 31 ++-- paddle/fluid/platform/collective_helper.h | 16 +- paddle/fluid/pybind/const_value.cc | 1 - .../unittests/collective_allreduce_op.py | 6 +- python/paddle/fluid/transpiler/collective.py | 153 ++++++++++-------- 30 files changed, 669 insertions(+), 428 deletions(-) create mode 100644 paddle/fluid/operators/collective/c_allreduce_max_op.cc rename paddle/fluid/operators/collective/{c_allreduce_op.cu.cc => c_allreduce_max_op.cu.cc} (69%) create mode 100644 paddle/fluid/operators/collective/c_allreduce_min_op.cc create mode 100644 paddle/fluid/operators/collective/c_allreduce_min_op.cu.cc delete mode 100644 paddle/fluid/operators/collective/c_allreduce_op.cc create mode 100644 paddle/fluid/operators/collective/c_allreduce_prod_op.cc create mode 100644 paddle/fluid/operators/collective/c_allreduce_prod_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_allreduce_sum_op.cc create mode 100644 paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 27922c73047..b502ef7a7c6 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -74,7 +74,6 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, static_cast(OpRole::kBackward), static_cast(OpRole::kOptimize) | static_cast(OpRole::kLRSched), - static_cast(OpRole::kCollective), static_cast(OpRole::kNotSpecified)}) .SetDefault(static_cast(OpRole::kNotSpecified)); AddAttr>(OpRoleVarAttrName(), diff --git a/paddle/fluid/framework/op_proto_maker.h
b/paddle/fluid/framework/op_proto_maker.h index bf6528b2377..5f3ce60e1d9 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -34,9 +34,6 @@ enum class OpRole { kDist = 0x0008, // Tag all learning rate scheduler operators. kLRSched = 0x0010, - // Collective role is for all collective operators and other operators used - // for collective training - kCollective = 0x0020, kLoss = 0x0100, // The default value of op's role. This should be only used for unittests and diff --git a/paddle/fluid/operators/collective/c_allgather_op.cc b/paddle/fluid/operators/collective/c_allgather_op.cc index 6f915953dab..18c8f5d6423 100644 --- a/paddle/fluid/operators/collective/c_allgather_op.cc +++ b/paddle/fluid/operators/collective/c_allgather_op.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_allgather_op.h" -#include // NOLINT + #include -#include namespace paddle { namespace operators { @@ -25,8 +24,7 @@ class CAllGatherOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of SyncFCGather op should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); int nranks = ctx->Attrs().Get("nranks"); PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2"); framework::DDim dim = ctx->GetInputDim("X"); @@ -49,10 +47,10 @@ class CAllGatherOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("nranks", "Total trainer count of the distributed training job"); AddComment(R"DOC( -***CAllGather Operator*** +CAllGather Operator each rank receives the aggregation of data from all ranks in the order of the ranks -Call NCCL collective AllGather internally.https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/api/colls.html#c.ncclAllGather +reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allgather )DOC"); } }; @@ -81,9 +79,8 @@ namespace plat = paddle::platform; REGISTER_OPERATOR(c_allgather, ops::CAllGatherOp, ops::CAllGatherOpGradMaker, ops::CAllGatherOpMaker); -REGISTER_OP_CPU_KERNEL( - c_allgather, ops::CAllGatherOpKernel, - ops::CAllGatherOpKernel, - ops::CAllGatherOpKernel, - ops::CAllGatherOpKernel, - ops::CAllGatherOpKernel); +REGISTER_OP_CPU_KERNEL(c_allgather, ops::CAllGatherOpCPUKernel, + ops::CAllGatherOpCPUKernel, + ops::CAllGatherOpCPUKernel, + ops::CAllGatherOpCPUKernel, + ops::CAllGatherOpCPUKernel); diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc index 8b13ceeb404..330219cd1f8 100644 --- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc +++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc @@ -14,12 +14,64 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/collective/c_allgather_op.h" +#include + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CAllGatherOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + ncclDataType_t dtype = platform::ToNCCLDataType(in->type()); + + int nranks = ctx.Attr("nranks"); + int rid = ctx.Attr("ring_id"); + auto comm = platform::NCCLCommContext::Instance().Get(rid); + PADDLE_ENFORCE_EQ(nranks, comm->nranks()); + + auto place = ctx.GetPlace(); + framework::DDim out_dims = in->dims(); + out_dims[0] *= nranks; + out->mutable_data(out_dims, place); + + int64_t send_numel = in->numel(); + const T* send_buff = in->data(); + T* recv_buff = out->data(); + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + PADDLE_ENFORCE(platform::dynload::ncclAllGather( + send_buff, recv_buff, send_numel, static_cast(dtype), + comm->comm(), stream)); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - c_allgather, ops::CAllGatherOpKernel, - ops::CAllGatherOpKernel, - ops::CAllGatherOpKernel, - ops::CAllGatherOpKernel, - ops::CAllGatherOpKernel); +REGISTER_OP_CUDA_KERNEL(c_allgather, ops::CAllGatherOpCUDAKernel, + ops::CAllGatherOpCUDAKernel, + ops::CAllGatherOpCUDAKernel, + ops::CAllGatherOpCUDAKernel, + ops::CAllGatherOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/c_allgather_op.h b/paddle/fluid/operators/collective/c_allgather_op.h index 8becbba0185..fe99a9e128d 100644 --- a/paddle/fluid/operators/collective/c_allgather_op.h +++ b/paddle/fluid/operators/collective/c_allgather_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include #include @@ -22,52 +23,14 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) -#include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/nccl_helper.h" -#endif - namespace paddle { namespace operators { -template -class CAllGatherOpKernel : public framework::OpKernel { +template +class CAllGatherOpCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto place = ctx.GetPlace(); - PADDLE_ENFORCE(is_gpu_place(place), - "CAllGatherOp can run on gpu place only for now."); -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - auto in = ctx.Input("X"); - auto out = ctx.Output("Out"); - ncclDataType_t dtype = platform::ToNCCLDataType(in->type()); - - int rid = ctx.Attr("ring_id"); - auto comm = platform::NCCLCommContext::Instance().Get(rid); - int nranks = comm->nranks(); - - framework::DDim out_dims = in->dims(); - out_dims[0] *= nranks; - out->mutable_data(out_dims, place); - - int64_t send_numel = in->numel(); - const T* send_buff = in->data(); - T* recv_buff = out->data(); - - cudaStream_t stream = nullptr; - if (ctx.Attr("use_calc_stream")) { - auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); - stream = static_cast(dev_ctx)->stream(); - } else { - stream = comm->stream(); - } - - PADDLE_ENFORCE(platform::dynload::ncclAllGather( - send_buff, recv_buff, send_numel, static_cast(dtype), - comm->comm(), stream)); -#else - PADDLE_THROW("PaddlePaddle should compile with GPU."); -#endif + PADDLE_THROW("unimplemented cpu kernel for CAllGatherOp."); } }; diff --git a/paddle/fluid/operators/collective/c_allreduce_max_op.cc b/paddle/fluid/operators/collective/c_allreduce_max_op.cc new file mode 100644 index 00000000000..bcb529f1570 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_max_op.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace paddle { +namespace operators { + +class CAllReduceMaxOpMaker : public CAllReduceOpMaker { + protected: + std::string GetName() const override { return "Max"; } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_max, ops::CAllReduceOp, + ops::CAllReduceMaxOpMaker); + +REGISTER_OP_CPU_KERNEL(c_allreduce_max, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel); diff --git a/paddle/fluid/operators/collective/c_allreduce_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_max_op.cu.cc similarity index 69% rename from paddle/fluid/operators/collective/c_allreduce_op.cu.cc rename to paddle/fluid/operators/collective/c_allreduce_max_op.cu.cc index 8b3246d95ac..34054103aa0 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.cu.cc +++ b/paddle/fluid/operators/collective/c_allreduce_max_op.cu.cc @@ -18,8 +18,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - c_allreduce, ops::CAllReduceOpKernel, - ops::CAllReduceOpKernel, - ops::CAllReduceOpKernel, - ops::CAllReduceOpKernel, - ops::CAllReduceOpKernel); + c_allreduce_max, ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel) diff --git a/paddle/fluid/operators/collective/c_allreduce_min_op.cc b/paddle/fluid/operators/collective/c_allreduce_min_op.cc new file mode 100644 index 00000000000..9d27a9ceb30 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_min_op.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace paddle { +namespace operators { + +class CAllReduceMinOpMaker : public CAllReduceOpMaker { + protected: + std::string GetName() const override { return "Min"; } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_min, ops::CAllReduceOp, + ops::CAllReduceMinOpMaker); + +REGISTER_OP_CPU_KERNEL(c_allreduce_min, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel); diff --git a/paddle/fluid/operators/collective/c_allreduce_min_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_min_op.cu.cc new file mode 100644 index 00000000000..4e8b6f9d0a9 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_min_op.cu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + c_allreduce_min, ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel) diff --git a/paddle/fluid/operators/collective/c_allreduce_op.cc b/paddle/fluid/operators/collective/c_allreduce_op.cc deleted file mode 100644 index 8af1135701b..00000000000 --- a/paddle/fluid/operators/collective/c_allreduce_op.cc +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include // NOLINT -#include - -#include "paddle/fluid/operators/collective/c_allreduce_op.h" - -namespace paddle { -namespace operators { - -class CAllReduceOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); - } -}; - -class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() { - AddInput("X", "(Tensor), tensor to be allreduced."); - AddOutput("Out", "(Tensor) the allreduced result."); - AddAttr("reduce_type", "(int default 0) determin the reduce type.") - .SetDefault(0); - AddAttr("ring_id", "(int default 0) communication ring id.") - .SetDefault(0); - AddAttr( - "use_calc_stream", - "(bool default false) eject CUDA operations to calculation stream.") - .SetDefault(false); - AddComment(R"DOC( -***CAllReduce Operator*** - -Call NCCL collective AllReduce internally. Note that this op must be used when one -thread is managing one GPU device. - -For speed reasons, reduce_type should be an integer: - -0: sum -1: prod -2: max -3: min -If input and output are the same variable, in-place allreduce will be used. 
-)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_WITHOUT_GRADIENT(c_allreduce, ops::CAllReduceOp, - ops::CAllReduceOpMaker); - -REGISTER_OP_CPU_KERNEL( - c_allreduce, ops::CAllReduceOpKernel, - ops::CAllReduceOpKernel, - ops::CAllReduceOpKernel, - ops::CAllReduceOpKernel, - ops::CAllReduceOpKernel); diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 0cd4b857ffd..1db5f15595e 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include -#include -#include + +#include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -29,17 +28,41 @@ limitations under the License. */ namespace paddle { namespace operators { -template -class CAllReduceOpKernel : public framework::OpKernel { +enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd }; + +class CAllReduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +template +class CAllReduceOpCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW("CAllReduce op do not support CPUKernel for now."); + } +}; + +template +class CAllReduceOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto place = ctx.GetPlace(); - PADDLE_ENFORCE(is_gpu_place(place), - "CAllReduce op can run on gpu place only for now."); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) auto in = ctx.Input("X"); auto out = ctx.Output("Out"); + auto place = ctx.GetPlace(); ncclDataType_t dtype = platform::ToNCCLDataType(in->type()); int64_t numel = in->numel(); const void* sendbuff = in->data(); @@ -49,23 +72,6 @@ class CAllReduceOpKernel : public framework::OpKernel { int rid = ctx.Attr("ring_id"); auto comm = platform::NCCLCommContext::Instance().Get(rid); - int reduce_type = ctx.Attr("reduce_type"); - ncclRedOp_t red_type = ncclSum; - switch (reduce_type) { - case 0: - red_type = ncclSum; - break; - case 1: - red_type = ncclProd; - break; - case 2: - red_type = ncclMax; - break; - case 3: - red_type = ncclMin; - break; - } - cudaStream_t stream = nullptr; if (ctx.Attr("use_calc_stream")) { auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); @@ -74,13 +80,60 @@ class CAllReduceOpKernel : public framework::OpKernel { stream = comm->stream(); } + ncclRedOp_t nccl_red_type = ncclSum; + switch (red_type) { + case kRedSum: + nccl_red_type = ncclSum; + break; + + case kRedMax: + nccl_red_type = ncclMax; + break; + + case kRedMin: + nccl_red_type = 
ncclMin; + break; + + case kRedProd: + nccl_red_type = ncclProd; + break; + + default: + PADDLE_THROW("Invalid reduce type: %d", red_type); + } + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( - sendbuff, recvbuff, numel, dtype, red_type, comm->comm(), stream)); + sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream)); #else PADDLE_THROW("PaddlePaddle should compile with GPU."); #endif } }; +class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor), tensor to be allreduced."); + AddOutput("Out", "(Tensor) the allreduced result."); + AddAttr("ring_id", "(int default 0) communication ring id.") + .SetDefault(0); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddComment(string::Sprintf(R"DOC( +CAllReduce %s Operator + +Call collective AllReduce with reduce type %s. If input and output are +the same variable, in-place allreduce will be used. +Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allreduce +)DOC", + GetName(), GetName())); + } + + protected: + virtual std::string GetName() const = 0; +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/collective/c_allreduce_prod_op.cc b/paddle/fluid/operators/collective/c_allreduce_prod_op.cc new file mode 100644 index 00000000000..3cfb1723f18 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_prod_op.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace paddle { +namespace operators { + +class CAllReduceProdOpMaker : public CAllReduceOpMaker { + protected: + std::string GetName() const override { return "Prod"; } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_prod, ops::CAllReduceOp, + ops::CAllReduceProdOpMaker); + +REGISTER_OP_CPU_KERNEL(c_allreduce_prod, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel) diff --git a/paddle/fluid/operators/collective/c_allreduce_prod_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_prod_op.cu.cc new file mode 100644 index 00000000000..61f76c178d0 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_prod_op.cu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + c_allreduce_prod, ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel) diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op.cc new file mode 100644 index 00000000000..c80c585a832 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_sum_op.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace paddle { +namespace operators { + +class CAllReduceSumOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr retv(new framework::OpDesc()); + retv->SetType("c_allreduce_sum"); + retv->SetInput("X", OutputGrad("Out")); + retv->SetOutput("Out", InputGrad("X")); + retv->SetAttrMap(Attrs()); + return retv; + } +}; + +class CAllReduceSumOpMaker : public CAllReduceOpMaker { + protected: + std::string GetName() const override { return "Sum"; } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(c_allreduce_sum, ops::CAllReduceOp, + ops::CAllReduceSumOpGradMaker, ops::CAllReduceSumOpMaker); + +REGISTER_OP_CPU_KERNEL(c_allreduce_sum, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel, + ops::CAllReduceOpCPUKernel) diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc new file mode 100644 index 00000000000..8fe7fce21e4 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_sum_op.cu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + c_allreduce_sum, ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel, + ops::CAllReduceOpCUDAKernel) diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cc b/paddle/fluid/operators/collective/c_broadcast_op.cc index ab8ed3d8695..72d330306cc 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cc @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include // NOLINT -#include - #include "paddle/fluid/operators/collective/c_broadcast_op.h" namespace paddle { @@ -50,9 +47,9 @@ class CBroadcastOpMaker : public framework::OpProtoAndCheckerMaker { "(bool default false) eject CUDA operations to calculation stream.") .SetDefault(false); AddComment(R"DOC( -***CBroadcast Operator*** +CBroadcast Operator -Call ncclBcast internally. +Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#broadcast )DOC"); } }; @@ -66,9 +63,8 @@ namespace plat = paddle::platform; REGISTER_OP_WITHOUT_GRADIENT(c_broadcast, ops::CBroadcastOp, ops::CBroadcastOpMaker); -REGISTER_OP_CPU_KERNEL( - c_broadcast, ops::CBroadcastOpKernel, - ops::CBroadcastOpKernel, - ops::CBroadcastOpKernel, - ops::CBroadcastOpKernel, - ops::CBroadcastOpKernel); +REGISTER_OP_CPU_KERNEL(c_broadcast, ops::CBroadcastOpCPUKernel, + ops::CBroadcastOpCPUKernel, + ops::CBroadcastOpCPUKernel, + ops::CBroadcastOpCPUKernel, + ops::CBroadcastOpCPUKernel); diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc index 23b0fb01ec3..c0f5bbd2c2f 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc @@ -14,12 +14,74 @@ limitations under the License. */ #include "paddle/fluid/operators/collective/c_broadcast_op.h" +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CBroadcastOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + int numel = x->numel(); + ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); + + int rid = ctx.Attr("ring_id"); + auto comm = platform::NCCLCommContext::Instance().Get(rid); + + auto place = ctx.GetPlace(); + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + int root = ctx.Attr("root"); + if (root == comm->rank()) { + PADDLE_ENFORCE(platform::dynload::ncclBcast( + reinterpret_cast(const_cast(x->data())), numel, dtype, + root, comm->comm(), stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. 
sent " + << x->numel(); + + if (out != x) { + framework::TensorCopy( + *static_cast(x), place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(out)); + } + } else { + PADDLE_ENFORCE(platform::dynload::ncclBcast(out->mutable_data(place), + numel, dtype, root, + comm->comm(), stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved " + << framework::product(out->dims()); + } + + out->Resize(x->dims()); + out->set_lod(x->lod()); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - c_broadcast, ops::CBroadcastOpKernel, - ops::CBroadcastOpKernel, - ops::CBroadcastOpKernel, - ops::CBroadcastOpKernel, - ops::CBroadcastOpKernel); +REGISTER_OP_CUDA_KERNEL(c_broadcast, ops::CBroadcastOpCUDAKernel, + ops::CBroadcastOpCUDAKernel, + ops::CBroadcastOpCUDAKernel, + ops::CBroadcastOpCUDAKernel, + ops::CBroadcastOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/c_broadcast_op.h b/paddle/fluid/operators/collective/c_broadcast_op.h index c93c459b75f..4ceb0aa835f 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.h +++ b/paddle/fluid/operators/collective/c_broadcast_op.h @@ -22,69 +22,14 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) -#include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/nccl_helper.h" -#endif - namespace paddle { namespace operators { -template -class CBroadcastOpKernel : public framework::OpKernel { +template +class CBroadcastOpCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto place = ctx.GetPlace(); - PADDLE_ENFORCE(is_gpu_place(place), - "CBroadcastOp can run on gpu place only for now."); -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - auto x = ctx.Input("X"); - auto out = ctx.Output("Out"); - int numel = x->numel(); - ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); - - int rid = ctx.Attr("ring_id"); - auto comm = platform::NCCLCommContext::Instance().Get(rid); - - cudaStream_t stream = nullptr; - if (ctx.Attr("use_calc_stream")) { - auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); - stream = static_cast(dev_ctx)->stream(); - } else { - stream = comm->stream(); - } - - int root = ctx.Attr("root"); - int nranks = comm->nranks(); - PADDLE_ENFORCE(root >= 0 && root < nranks, - "Expected root in range of [0,%d),but get %d", nranks, root); - if (root == comm->rank()) { - PADDLE_ENFORCE(platform::dynload::ncclBcast( - reinterpret_cast(const_cast(x->data())), numel, dtype, - root, comm->comm(), stream)); - VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent " - << x->numel(); - - if (out != x) { - // TODO(liuyi05): check inplace - framework::TensorCopy( - *static_cast(x), place, - *platform::DeviceContextPool::Instance().Get(place), - static_cast(out)); - } - } else { - PADDLE_ENFORCE(platform::dynload::ncclBcast(out->mutable_data(place), - numel, dtype, root, - comm->comm(), stream)); - VLOG(3) << "rank " << comm->rank() << " invoke Bcast. 
recieved " - << framework::product(out->dims()); - } - - out->Resize(x->dims()); - out->set_lod(x->lod()); -#else - PADDLE_THROW("PaddlePaddle should compile with GPU."); -#endif + PADDLE_THROW("Unimplemented cpu kernel for CBroadcastOp."); } }; diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc index 9dace1725f7..16ca6e5238e 100644 --- a/paddle/fluid/operators/collective/c_comm_init_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_op.cc @@ -28,6 +28,7 @@ limitations under the License. */ #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/nccl_helper.h" #endif + namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index a19a3fe1a38..d576ca7d6a3 100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -11,9 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include #endif + #include #include #include @@ -24,9 +26,11 @@ limitations under the License. */ #include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h" + #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include "paddle/fluid/platform/nccl_helper.h" #endif + namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cc index feb9dcd5a48..1194ac71b32 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op.cc +++ b/paddle/fluid/operators/collective/c_reducescatter_op.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/collective/c_reducescatter_op.h" -#include // NOLINT + #include -#include namespace paddle { namespace operators { @@ -54,9 +53,9 @@ class CReduceScatterOpMaker : public framework::OpProtoAndCheckerMaker { "(bool default false) eject CUDA operations to calculation stream.") .SetDefault(false); AddComment(R"DOC( -***CReduceScatter Operator*** +CReduceScatter Operator -Call NCCL collective ReduceScatter internally. 
+Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#reducescatter )DOC"); } }; @@ -85,9 +84,8 @@ namespace plat = paddle::platform; REGISTER_OPERATOR(c_reducescatter, ops::CReduceScatterOp, ops::CReduceScatterOpMaker); -REGISTER_OP_CPU_KERNEL( - c_reducescatter, ops::CReduceScatterOpKernel, - ops::CReduceScatterOpKernel, - ops::CReduceScatterOpKernel, - ops::CReduceScatterOpKernel, - ops::CReduceScatterOpKernel); +REGISTER_OP_CPU_KERNEL(c_reducescatter, ops::CReduceScatterOpCPUKernel, + ops::CReduceScatterOpCPUKernel, + ops::CReduceScatterOpCPUKernel, + ops::CReduceScatterOpCPUKernel, + ops::CReduceScatterOpCPUKernel); diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc index ef9eed2aabf..7244aa949eb 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc +++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc @@ -14,13 +14,61 @@ limitations under the License. */ #include "paddle/fluid/operators/collective/c_reducescatter_op.h" +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CReduceScatterOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + + int rid = ctx.Attr("ring_id"); + auto comm = platform::NCCLCommContext::Instance().Get(rid); + int nranks = comm->nranks(); + + auto place = ctx.GetPlace(); + auto out_dims = in->dims(); + out_dims[0] = out_dims[0] / nranks; + out->mutable_data(out_dims, place); + + int64_t recv_numel = in->numel() / nranks; + const T* send_buff = in->data(); + T* recv_buff = out->data(); + int dtype = platform::ToNCCLDataType(in->type()); + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + PADDLE_ENFORCE(platform::dynload::ncclReduceScatter( + send_buff, recv_buff, recv_numel, static_cast(dtype), + ncclSum, comm->comm(), stream)); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - c_reducescatter, - ops::CReduceScatterOpKernel, - ops::CReduceScatterOpKernel, - ops::CReduceScatterOpKernel, - ops::CReduceScatterOpKernel, - ops::CReduceScatterOpKernel); +REGISTER_OP_CUDA_KERNEL(c_reducescatter, ops::CReduceScatterOpCUDAKernel, + ops::CReduceScatterOpCUDAKernel, + ops::CReduceScatterOpCUDAKernel, + ops::CReduceScatterOpCUDAKernel, + ops::CReduceScatterOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.h b/paddle/fluid/operators/collective/c_reducescatter_op.h index 93d623ff2e3..ee308080677 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op.h +++ b/paddle/fluid/operators/collective/c_reducescatter_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include #include @@ -22,52 +23,14 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) -#include "paddle/fluid/platform/collective_helper.h" -#include "paddle/fluid/platform/nccl_helper.h" -#endif - namespace paddle { namespace operators { -template -class CReduceScatterOpKernel : public framework::OpKernel { +template +class CReduceScatterOpCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto place = ctx.GetPlace(); - PADDLE_ENFORCE(is_gpu_place(place), - "CAllReduce op can run on gpu place only for now."); -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - auto in = ctx.Input("X"); - auto out = ctx.Output("Out"); - - int rid = ctx.Attr("ring_id"); - auto comm = platform::NCCLCommContext::Instance().Get(rid); - int nranks = comm->nranks(); - - auto out_dims = in->dims(); - out_dims[0] = out_dims[0] / nranks; - out->mutable_data(out_dims, place); - - int64_t recv_numel = in->numel() / nranks; - const T* send_buff = in->data(); - T* recv_buff = out->data(); - int dtype = platform::ToNCCLDataType(in->type()); - - cudaStream_t stream = nullptr; - if (ctx.Attr("use_calc_stream")) { - auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); - stream = static_cast(dev_ctx)->stream(); - } else { - stream = comm->stream(); - } - - PADDLE_ENFORCE(platform::dynload::ncclReduceScatter( - send_buff, recv_buff, recv_numel, static_cast(dtype), - ncclSum, comm->comm(), stream)); -#else - PADDLE_THROW("PaddlePaddle should compile with GPU."); -#endif + PADDLE_THROW("Unimplemented cpu kernel for CReduceScatterOp."); } }; diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc index 965761dc158..fe74fc59773 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc @@ -15,12 +15,12 @@ limitations under the License. */ #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include #endif -#include -#include + #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" + #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include "paddle/fluid/platform/collective_helper.h" #endif @@ -40,7 +40,6 @@ class CSyncCalcStreamOp : public framework::OperatorBase { const platform::Place& place) const override { PADDLE_ENFORCE(is_gpu_place(place), "Sync stream op can run on gpu place only for now."); - #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) auto dev_ctx = static_cast( platform::DeviceContextPool::Instance().Get(place)); @@ -57,12 +56,12 @@ class CSyncCalcStreamOp : public framework::OperatorBase { class CSyncCalcStreamOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { - AddInput("X", "(Tensor) Dependency of last param need to sync"); - AddOutput("Out", "(Tensor) Dependency of last param need to sync"); + AddInput("X", "(Tensor) Dependency of the variable need to sync"); + AddOutput("Out", "(Tensor) Dependency of the variable need to sync"); AddComment(R"DOC( -***Sync Operator*** +CSyncCalcStream Operator -Call cuda stream synchronize. +Call calculation stream synchronization. 
)DOC"); } }; diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index 6fbb5b8cb11..5170356165f 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -11,11 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include #endif -#include -#include + #include #include "paddle/fluid/framework/lod_tensor.h" @@ -57,13 +57,13 @@ class CSyncCommStreamOp : public framework::OperatorBase { class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { - AddInput("X", "(Tensor) Dependency of last param need to sync"); - AddOutput("Out", "(Tensor) Dependency of last param need to sync"); + AddInput("X", "(Tensor) Dependency of the variable need to sync"); + AddOutput("Out", "(Tensor) Dependency of the variable need to sync"); AddAttr("ring_id", "(int default 0) ring id.").SetDefault(0); AddComment(R"DOC( -***Sync Operator*** +CSyncCommStream Operator -Call nccl stream synchronize. +Call communication stream synchronization. )DOC"); } }; diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index 49f3e0c7369..ddd242cda83 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -// #ifndef _WIN32 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include "paddle/fluid/platform/collective_helper.h" -#include +#include +#include #include "paddle/fluid/platform/dynload/nccl.h" @@ -34,24 +34,23 @@ class NCCLCommImpl : public NCCLComm { void set_rank(int rank) { rank_ = rank; } int rank() const override { return rank_; } - void set_local_rank(int local_rank) { local_rank_ = local_rank; } - int local_rank() const override { return local_rank_; } - - void set_comm(ncclComm_t comm) { comm_ = comm; } - ncclComm_t comm() const override { return comm_; } + int device_id() const override { + return boost::get(dev_ctx_->GetPlace()).device; + } - void set_dev_ctx(CUDADeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } - CUDADeviceContext* DevCtx() const override { return dev_ctx_; } + ncclComm_t comm() const override { return dev_ctx_->nccl_comm(); } cudaStream_t stream() const override { return dev_ctx_->stream(); } + void set_dev_ctx(std::unique_ptr&& dev_ctx) { + dev_ctx_ = std::move(dev_ctx); + } + private: int ring_id_; int nranks_; int rank_; - int local_rank_; - ncclComm_t comm_; - CUDADeviceContext* dev_ctx_; + std::unique_ptr dev_ctx_; }; // NOTE: not thread-safe @@ -73,13 +72,15 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, PADDLE_ENFORCE( platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank)); + std::unique_ptr dev_ctx( + new CUDADeviceContext(CUDAPlace(dev_id))); + dev_ctx->set_nccl_comm(comm); + NCCLCommImpl* communicator = new NCCLCommImpl; communicator->set_ring_id(ring_id); communicator->set_nranks(nranks); communicator->set_rank(rank); - communicator->set_local_rank(dev_id); - communicator->set_comm(comm); - communicator->set_dev_ctx(dev_ctx_map_.at(dev_id).get()); + 
communicator->set_dev_ctx(std::move(dev_ctx)); comm_map_.emplace(ring_id, std::unique_ptr(communicator)); diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h index 97d94175928..7479ebaf7d2 100644 --- a/paddle/fluid/platform/collective_helper.h +++ b/paddle/fluid/platform/collective_helper.h @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -// #ifndef _WIN32 -#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #pragma once +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #include #include #include @@ -53,10 +52,9 @@ class NCCLComm { virtual int ring_id() const = 0; virtual int nranks() const = 0; virtual int rank() const = 0; - virtual int local_rank() const = 0; + virtual int device_id() const = 0; virtual ncclComm_t comm() const = 0; virtual cudaStream_t stream() const = 0; - virtual CUDADeviceContext* DevCtx() const = 0; virtual ~NCCLComm() = default; }; @@ -73,16 +71,6 @@ class NCCLCommContext { NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank, int dev_id, int ring_id = 0); - CUDADeviceContext* DevCtx(int dev_id) const { - PADDLE_ENFORCE(dev_ctx_map_.count(dev_id), - "CUDADeviceContext at device %d has not been initialized"); - return dev_ctx_map_.at(dev_id).get(); - } - - CUDADeviceContext* DevCtx(platform::Place p) const { - return DevCtx(boost::get(p).device); - } - // retrieve a communicator by the ring id NCCLComm* Get(int ring_id) const { PADDLE_ENFORCE(comm_map_.count(ring_id), diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 3f0fe62fec3..633e3259ada 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -46,7 +46,6 @@ void BindConstValue(pybind11::module* m) { .value("Loss", framework::OpRole::kLoss) .value("RPC", framework::OpRole::kRPC) .value("Dist", framework::OpRole::kDist) - .value("Collective", framework::OpRole::kCollective) .value("LRSched", framework::OpRole::kLRSched); op_proto_and_checker_maker.def( diff --git a/python/paddle/fluid/tests/unittests/collective_allreduce_op.py b/python/paddle/fluid/tests/unittests/collective_allreduce_op.py index 69bd6f99044..9aef8879cab 100644 --- a/python/paddle/fluid/tests/unittests/collective_allreduce_op.py +++ b/python/paddle/fluid/tests/unittests/collective_allreduce_op.py @@ -42,7 +42,6 @@ class TestCollectiveAllreduce(TestCollectiveRunnerBase): def get_model(self, main_prog, startup_program): ring_id = 0 - reduce_type = 0 with fluid.program_guard(main_prog, startup_program): tindata = layers.data( name="tindata", shape=[10, 1000], dtype='float32') @@ -53,10 +52,9 @@ class TestCollectiveAllreduce(TestCollectiveRunnerBase): persistable=False, stop_gradient=False) main_prog.global_block().append_op( - type="c_allreduce", + type="c_allreduce_sum", inputs={'X': tindata}, - attrs={'ring_id': ring_id, - 'reduce_type': reduce_type}, + attrs={'ring_id': ring_id}, outputs={'Out': toutdata}) main_prog.global_block().append_op( type="c_sync_comm_stream", diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index df5cdbc104f..18cf1fec417 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the 'License'); # you may not use this file except in compliance with the License. @@ -37,8 +37,8 @@ class Collective(object): ''' ''' - def __init__(self): - self.global_ring_id = 0 + def __init__(self, nrings): + self.nrings = nrings self.endpoints = None self.current_endpoint = None self.nranks = None @@ -90,9 +90,10 @@ class Collective(object): raise NotImplementedError('call the inherited method of subclasses') def _transpile_startup_program(self): - self._init_communicator(self.startup_program, self.current_endpoint, - self.endpoints, self.rank, self.global_ring_id, - self.wait_port) + for ring_id in range(self.nrings): + self._init_communicator(self.startup_program, self.current_endpoint, + self.endpoints, self.rank, ring_id, + self.wait_port) self._broadcast_params() def _init_communicator(self, program, current_endpoint, endpoints, rank, @@ -116,7 +117,7 @@ class Collective(object): 'rank': rank, 'endpoint': current_endpoint, 'other_endpoints': other_endpoints, - self.op_role_key: OpRole.Collective + self.op_role_key: OpRole.Forward }) block.append_op( type='c_comm_init', @@ -126,29 +127,31 @@ class Collective(object): 'nranks': nranks, 'rank': rank, 'ring_id': ring_id, - self.op_role_key: OpRole.Collective + self.op_role_key: OpRole.Forward }) def _broadcast_params(self): block = self.startup_program.global_block() - for var in block.iter_parameters(): + ring_id = -1 + for param in block.iter_parameters(): + ring_id = (ring_id + 1) % self.nrings block.append_op( type='c_broadcast', - inputs={'X': var}, - outputs={'Out': var}, + inputs={'X': param}, + outputs={'Out': param}, attrs={ - 'ring_id': self.global_ring_id, + 'ring_id': ring_id, 'root': 0, - self.op_role_key: OpRole.Collective + self.op_role_key: OpRole.Forward }) - block.append_op( - type='c_sync_comm_stream', - inputs={'X': var}, - outputs={'Out': var}, - attrs={ - 'ring_id': self.global_ring_id, - self.op_role_key: OpRole.Collective - }) + + for ring_id in range(self.nrings): + block.append_op( + type='c_sync_comm_stream', + inputs={'X': param}, + outputs={'Out': param}, + attrs={'ring_id': ring_id, + self.op_role_key: OpRole.Forward}) def _is_loss_grad_op(self, op): if self.op_role_key not in op.attr_names: @@ -173,8 +176,8 @@ class GradAllReduce(Collective): ''' ''' - def __init__(self): - Collective.__init__(self) + def __init__(self, nrings=2): + Collective.__init__(self, nrings) def _transpile_main_program(self): self._insert_scale_loss_grad_ops() @@ -196,11 +199,13 @@ class GradAllReduce(Collective): outputs={'Out': loss_grad_var}, attrs={ 'scale': 1.0 / self.nranks, - self.op_role_key: OpRole.Collective + self.op_role_key: OpRole.Backward }) def _insert_allreduce_ops(self): block = self.main_program.global_block() + ring_id = -1 + grad = None for idx, op in reversed(list(enumerate(block.ops))): if self._is_backward_op(op) and \ self.op_role_var_key in op.attr_names: @@ -208,41 +213,50 @@ class GradAllReduce(Collective): if len(op_role_var) == 0: continue - assert len(op_role_var) % 2 == 0 - block._insert_op( - idx + 1, - type='c_sync_calc_stream', - inputs={'X': block.vars[grad]}, - outputs={'Out': block.vars[grad]}, - attrs={self.op_role_key: OpRole.Collective}) - - offset = 2 + offset = idx for i in range(0, len(op_role_var), 2): - grad = op_role_var[i + 1] + param = block.vars[op_role_var[i]] + grad = block.vars[op_role_var[i + 1]] + if offset == idx: + offset += 1 + block._insert_op( + offset, + type='c_sync_calc_stream', + inputs={'X': grad}, + 
outputs={'Out': grad}, + attrs={self.op_role_key: OpRole.Backward}) + offset += 1 + + # As we search ops reversedly, we should insert c_allreduce_sum + # op in the same way to keep the ring_id alternate + ring_id = (ring_id + 1) % self.nrings block._insert_op( - idx + offset, - type='c_allreduce', - inputs={'X': [block.vars[grad]]}, - outputs={'Out': [block.vars[grad]]}, + offset, + type='c_allreduce_sum', + inputs={'X': grad}, + outputs={'Out': grad}, attrs={ - 'reduce_type': 0, - self.op_role_key: OpRole.Collective + 'ring_id': ring_id, + self.op_role_key: OpRole.Backward }) - offset += 1 + + if grad is None: + return for idx, op in enumerate(block.ops): if self._is_optimizer_op(op): - block._insert_op( - idx, - type='c_sync_comm_stream', - inputs={'X': block.vars[grad]}, - outputs={'Out': block.vars[grad]}, - attrs={ - 'ring_id': self.global_ring_id, - self.op_role_key: OpRole.Collective - }) + for ring_id in range(self.nrings): + block._insert_op( + idx + ring_id, + type='c_sync_comm_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': ring_id, + self.op_role_key: OpRole.Backward + }) break @@ -250,8 +264,8 @@ class LocalSGD(Collective): ''' ''' - def __init__(self): - Collective.__init__(self) + def __init__(self, nrings=2): + Collective.__init__(self, nrings) self.snapshot_key = '@SNAPSHOT' def _transpile_startup_program(self): @@ -268,7 +282,7 @@ class LocalSGD(Collective): type='assign', inputs={'X': [param]}, outputs={'Out': [snapshot]}, - attrs={self.op_role_key: OpRole.Collective}) + attrs={self.op_role_key: OpRole.Forward}) def snapshot_name(self, param_name): return param_name + self.snapshot_key @@ -276,6 +290,7 @@ class LocalSGD(Collective): def _transpile_main_program(self): block = self.main_program.global_block() ordered_param_snapshot = [] + ring_id = -1 for idx, op in reversed(list(enumerate(block.ops))): if self._is_update_op(op): param = block.vars[op.input('Param')[0]] @@ -291,33 +306,33 @@ class LocalSGD(Collective): inputs={'X': [snapshot], 'Y': [param]}, outputs={'Out': [param]}, - attrs={self.op_role_key: OpRole.Collective}) + attrs={self.op_role_key: OpRole.Optimize}) block._insert_op( idx + 2, type='c_sync_calc_stream', inputs={'X': param}, outputs={'Out': param}, - attrs={self.op_role_key: OpRole.Collective}) + attrs={self.op_role_key: OpRole.Optimize}) + ring_id = (ring_id + 1) % self.nrings block._insert_op( idx + 3, - type='c_allreduce', + type='c_allreduce_sum', inputs={'X': [param]}, outputs={'Out': [param]}, attrs={ - 'reduce_type': 0, - self.op_role_key: OpRole.Collective + 'ring_id': ring_id, + self.op_role_key: OpRole.Optimize }) ordered_param_snapshot.append((param, snapshot)) - block.append_op( - type='c_sync_comm_stream', - inputs={'X': param}, - outputs={'Out': param}, - attrs={ - 'ring_id': self.global_ring_id, - self.op_role_key: OpRole.Collective - }) + for ring_id in range(self.nrings): + block.append_op( + type='c_sync_comm_stream', + inputs={'X': param}, + outputs={'Out': param}, + attrs={'ring_id': ring_id, + self.op_role_key: OpRole.Optimize}) for param_snapshot in reversed(ordered_param_snapshot): param = param_snapshot[0] @@ -328,16 +343,16 @@ class LocalSGD(Collective): outputs={'Out': [param]}, attrs={ 'scale': 1.0 / self.nranks, - self.op_role_key: OpRole.Collective + self.op_role_key: OpRole.Optimize }) block.append_op( type='elementwise_sub', inputs={'X': [snapshot], 'Y': [param]}, outputs={'Out': [param]}, - attrs={self.op_role_key: OpRole.Collective}) + attrs={self.op_role_key: OpRole.Optimize}) 
block.append_op( type='assign', inputs={'X': [param]}, outputs={'Out': [snapshot]}, - attrs={self.op_role_key: OpRole.Collective}) + attrs={self.op_role_key: OpRole.Optimize}) -- GitLab
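
For reference, below is a minimal usage sketch modeled on the updated unittest python/paddle/fluid/tests/unittests/collective_allreduce_op.py in this patch. It shows how a program now uses a per-reduce-type allreduce op together with the stream-sync op. The output variable name is illustrative, the exact inputs of c_sync_comm_stream are an assumption here, and actually running such a program still requires the NCCL communicator to be set up (c_gen_nccl_id / c_comm_init), which the transpilers in collective.py insert.

import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core

main_prog = fluid.Program()
startup_prog = fluid.Program()
ring_id = 0

with fluid.program_guard(main_prog, startup_prog):
    tindata = layers.data(name="tindata", shape=[10, 1000], dtype='float32')
    toutdata = main_prog.current_block().create_var(
        name="out_of_allreduce",  # illustrative name
        dtype='float32',
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=False)

    # The reduce type is now encoded in the op type (c_allreduce_sum,
    # c_allreduce_prod, c_allreduce_max, c_allreduce_min); the old
    # 'reduce_type' attribute of the removed c_allreduce op is gone.
    main_prog.global_block().append_op(
        type="c_allreduce_sum",
        inputs={'X': tindata},
        outputs={'Out': toutdata},
        attrs={'ring_id': ring_id})

    # Synchronize the communication stream before the result is consumed.
    main_prog.global_block().append_op(
        type="c_sync_comm_stream",
        inputs={'X': toutdata},
        outputs={'Out': toutdata},
        attrs={'ring_id': ring_id})

In training programs these ops are not appended by hand: the GradAllReduce and LocalSGD transpilers above insert c_allreduce_sum, c_sync_calc_stream and c_sync_comm_stream automatically and distribute the tensors across the nrings communication rings in a round-robin fashion.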