diff --git a/paddle/operators/nccl/nccl_gpu_common.cc b/paddle/operators/nccl/nccl_gpu_common.cc new file mode 100644 index 0000000000000000000000000000000000000000..0144d93969a774bcf528ec1b07e132c36fcb0c26 --- /dev/null +++ b/paddle/operators/nccl/nccl_gpu_common.cc @@ -0,0 +1,9 @@ +#include "paddle/operators/nccl/nccl_gpu_common.h" + +namespace paddle { +namespace platform { + + + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/nccl/nccl_gpu_common.h b/paddle/operators/nccl/nccl_gpu_common.h index 55e7d8db661f29683889b84e86c5abba050d4c1b..cace878079709751a8a50450c8c07e9521a47017 100644 --- a/paddle/operators/nccl/nccl_gpu_common.h +++ b/paddle/operators/nccl/nccl_gpu_common.h @@ -1,11 +1,31 @@ #pragma once #include +#include +#include +#include +#include +#include + #include "paddle/platform/device_context.h" namespace paddle { namespace platform { + +// class NCCLContext : public DeviceContext { +// public: +// explicit NCCLContext(GPUPlace place); +// virtual ~NCCLContext(); + +// private: +// std::vector gpu_ids_; +// std::vector streams_; +// }; + + +class Communicator; + class NCCLManager { public: static NCCLManager* Get() { @@ -13,23 +33,28 @@ class NCCLManager { return &m; } - NCCLManager() { _comms.resize(_gpu_worlds.size()); } + NCCLManager() { + } ~NCCLManager() {} + // for each card only have one communicator + Communicator* GetCommunicator() const; + private: - std::vector _comms; - std::vector _gpu_worlds; -}; + struct Communicator { + std::vector comms_; + std::vector streams_; // do not own + std::vector events_; + int root_gpu; + }; -class NCCLContext : public DeviceContext { - public: - explicit NCCLContext(GPUPlace place); - virtual ~NCCLContext(); + // the gpu id list available. Note that only support + // whole world communication. + std::vector _gpu_worlds; - private: - std::vector _gpu_ids; - std::vector _streams; - int root_gpu; + // communicator list + std::unordered_map comms_; }; -} -} + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/nccl/nccl_ops.cc b/paddle/operators/nccl/nccl_ops.cc index a4bd8b9c0f613aef84448f7ffce8be47b6575376..4b7bfa7234aae127fad1c57002fbe81a2430abf2 100644 --- a/paddle/operators/nccl/nccl_ops.cc +++ b/paddle/operators/nccl/nccl_ops.cc @@ -1,17 +1,28 @@ -#include "paddle/framework/op_registry.h" -#include "paddle/operators/nccl/nccl_gpu_common.h" +#include "paddle/operators/nccl/nccl_ops.h" namespace paddle { namespace operators { // AllreduceOp -class NCCLAllreduceOp : public framework::OperatorWithKernel { +class NCCLAllReduceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: // allreduce do nothing in infershape - void InferShape(const framework::InferShapeContext &ctx) const override {} + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + " Input(X) of AllReduce op input should not be NULL"); + auto ins = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput("Out"); + PADDLE_ENFORCE(ins.size() == outs.size(), "Input(X) and Output(Out) must have same size"); + for(size_t i=0; i < ins.size(); ++i) { + outs[i]->Resize(ins[i]->dims()); + } + std::string reduction = ctx.Attr("reduction"); + PADDLE_ENFORCE( (reduction == "ncclSum" || reduction == "ncclProd" || + reduction == "ncclMin" || reduction == "ncclMax"), "invalid reduction!"); + } }; template @@ -19,30 +30,67 @@ class NCCLAllreduceOp : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto *ctx = static_cast(context.device_context()); - // auto *comm = ; - // auto *src = ; - // ncclAllReduce(src, dest, ) } }; // BcastSendOp template -class NCCLBroadcastSendOp final : public framework::OperatorWithKernel { +class NCCLBcastSendOp final : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + " Input(X) of BcastSend op input should not be NULL"); + } }; // BcastRecvOp template -class NCCLBroadcastRecvOp final : public framework::OperatorWithKernel { +class NCCLBcastRecvOp final : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + " Input(X) of BcastRecv op input should not be NULL"); + } +}; + + +class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { + NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of AllReduce op"); + AddOutput("Out", "The output of AllReduce op"); + AddAttr("reduction: {'min', 'max', 'prod', 'sum'}."); + AddComment(R"DOC( + AllReduce the input tensors. + )DOC"); + } }; + +class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker { + NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of BcastSend op"); + AddComment(R"DOC( + BcastSend the tensors. + )DOC"); + } +}; + +class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker { + NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Out", "The output of BcastRecv op"); + AddComment(R"DOC( + BcastRecv the tensors. + )DOC"); + } +}; + } } diff --git a/paddle/operators/nccl/nccl_ops.h b/paddle/operators/nccl/nccl_ops.h index 0d78c606395f36d0ca9f50784e67f9f7dcfde6aa..3664d2f55cfbea06c8c5fe7236cbaabb5fc5172e 100644 --- a/paddle/operators/nccl/nccl_ops.h +++ b/paddle/operators/nccl/nccl_ops.h @@ -2,6 +2,59 @@ #include "paddle/framework/op_registry.h" #include "paddle/operators/nccl/nccl_gpu_common.h" +#include + namespace paddle { -namespace operators {} +namespace operators { + + +template +class NCCLTypeWrapper; + +template<> +class NCCLTypeWrapper { + static const ncclDataType_t type = ncclFloat; +}; + +template<> +class NCCLTypeWrapper { + static const ncclDataType_t type = ncclDouble; +}; + + + +template +class NCCLAllReduceKernel : public framework::OpKernel { +public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput("Out"); + std::string reduction = ctx.Attr("reduction"); + ncclRedOp_t op_type; + if (reduction == "ncclSum") { + op_type = ncclSum; + } else if (reduction == "ncclProd") { + op_type = ncclProd; + } else if (reduction == "ncclMin") { + op_type = ncclMin; + } else (reduction == "ncclMax") { + op_type = ncclMax; + } + + auto dev_ctx = ctx.device_context(); + + for( size_t i=0; i < ins.size(); ++i) { + ncclAllReduce(ins[i]->data(), + outs[i]->mutable_data(), + outs[i]->numel() * sizeof(T), + NCCLTypeWrapper::type, + op_type, + comm, + stream); + } + } +}; + + +} }