diff --git a/paddle/operators/nccl/nccl_gpu_common.h b/paddle/operators/nccl/nccl_gpu_common.h new file mode 100644 index 0000000000000000000000000000000000000000..017492a0d8fc6652c6d11d4eb3b8829ea0509daa --- /dev/null +++ b/paddle/operators/nccl/nccl_gpu_common.h @@ -0,0 +1,39 @@ +#pragma once +#include + +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace platform { + +class NCCLManager { + public: + static NCCLManager* Get() { + static NCCLManager m; + return &m; + } + + NCCLManager() { _comms.resize(_gpu_worlds.size()); } + ~NCCLManager() {} + + private: + // clang-format off + std::vector _comms; + std::vector _gpu_worlds; + // clang-format on +}; + +class NCCLContext : public DeviceContext { + public: + explicit NCCLContext(GPUPlace place); + virtual ~NCCLContext(); + + private: + // clang-format off + std::vector _gpu_ids; + std::vector _streams; + int root_gpu; + // clang-format on +}; +} +} diff --git a/paddle/operators/nccl/nccl_ops.cc b/paddle/operators/nccl/nccl_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4bd8b9c0f613aef84448f7ffce8be47b6575376 --- /dev/null +++ b/paddle/operators/nccl/nccl_ops.cc @@ -0,0 +1,48 @@ +#include "paddle/framework/op_registry.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" + +namespace paddle { +namespace operators { + +// AllreduceOp +class NCCLAllreduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + // allreduce do nothing in infershape + void InferShape(const framework::InferShapeContext &ctx) const override {} +}; + +template +class NCCLAllreduceOp : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *ctx = static_cast(context.device_context()); + // auto *comm = ; + // auto *src = ; + // ncclAllReduce(src, dest, ) + } +}; + +// BcastSendOp +template +class NCCLBroadcastSendOp final : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override {} +}; + +// BcastRecvOp +template +class NCCLBroadcastRecvOp final : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override {} +}; +} +} diff --git a/paddle/operators/nccl/nccl_ops.h b/paddle/operators/nccl/nccl_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0d78c606395f36d0ca9f50784e67f9f7dcfde6aa --- /dev/null +++ b/paddle/operators/nccl/nccl_ops.h @@ -0,0 +1,7 @@ +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" + +namespace paddle { +namespace operators {} +}