提交 0fa34db7 编写于 作者: D dzhwinter

nccl init

上级 408e21af
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace platform {
} // namespace operators
} // namespace paddle
#pragma once #pragma once
#include <nccl.h> #include <nccl.h>
#include <memory>
#include <mutex>
#include <condition_variable>
#include <vector>
#include <unordered_map>
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
// class NCCLContext : public DeviceContext {
// public:
// explicit NCCLContext(GPUPlace place);
// virtual ~NCCLContext();
// private:
// std::vector<int> gpu_ids_;
// std::vector<cudaStream_t> streams_;
// };
class Communicator;
class NCCLManager { class NCCLManager {
public: public:
static NCCLManager* Get() { static NCCLManager* Get() {
...@@ -13,23 +33,28 @@ class NCCLManager { ...@@ -13,23 +33,28 @@ class NCCLManager {
return &m; return &m;
} }
NCCLManager() { _comms.resize(_gpu_worlds.size()); } NCCLManager() {
}
~NCCLManager() {} ~NCCLManager() {}
// for each card only have one communicator
Communicator* GetCommunicator() const;
private: private:
std::vector<ncclComm_t> _comms; struct Communicator {
std::vector<int> _gpu_worlds; std::vector<ncclComm_t> comms_;
}; std::vector<cudaStream_t*> streams_; // do not own
std::vector<cudaEvent_t> events_;
int root_gpu;
};
class NCCLContext : public DeviceContext { // the gpu id list available. Note that only support
public: // whole world communication.
explicit NCCLContext(GPUPlace place); std::vector<int> _gpu_worlds;
virtual ~NCCLContext();
private: // communicator list
std::vector<int> _gpu_ids; std::unordered_map<std::string /* key*/, Communicator*> comms_;
std::vector<cudaStream_t> _streams;
int root_gpu;
}; };
}
} } // namespace operators
} // namespace paddle
#include "paddle/framework/op_registry.h" #include "paddle/operators/nccl/nccl_ops.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// AllreduceOp // AllreduceOp
class NCCLAllreduceOp : public framework::OperatorWithKernel { class NCCLAllReduceOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
// allreduce do nothing in infershape // allreduce do nothing in infershape
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
" Input(X) of AllReduce op input should not be NULL");
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
PADDLE_ENFORCE(ins.size() == outs.size(), "Input(X) and Output(Out) must have same size");
for(size_t i=0; i < ins.size(); ++i) {
outs[i]->Resize(ins[i]->dims());
}
std::string reduction = ctx.Attr<std::string>("reduction");
PADDLE_ENFORCE( (reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"), "invalid reduction!");
}
}; };
template <typename T> template <typename T>
...@@ -19,30 +30,67 @@ class NCCLAllreduceOp : public framework::OpKernel { ...@@ -19,30 +30,67 @@ class NCCLAllreduceOp : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto *ctx = static_cast<NCCLContext *>(context.device_context()); auto *ctx = static_cast<NCCLContext *>(context.device_context());
// auto *comm = ;
// auto *src = ;
// ncclAllReduce(src, dest, )
} }
}; };
// BcastSendOp // BcastSendOp
template <typename T> template <typename T>
class NCCLBroadcastSendOp final : public framework::OperatorWithKernel { class NCCLBcastSendOp final : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
" Input(X) of BcastSend op input should not be NULL");
}
}; };
// BcastRecvOp // BcastRecvOp
template <typename T> template <typename T>
class NCCLBroadcastRecvOp final : public framework::OperatorWithKernel { class NCCLBcastRecvOp final : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
" Input(X) of BcastRecv op input should not be NULL");
}
};
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op");
AddOutput("Out", "The output of AllReduce op");
AddAttr<std::string>("reduction: {'min', 'max', 'prod', 'sum'}.");
AddComment(R"DOC(
AllReduce the input tensors.
)DOC");
}
}; };
class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op");
AddComment(R"DOC(
BcastSend the tensors.
)DOC");
}
};
class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
NCCLAllReduceOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The output of BcastRecv op");
AddComment(R"DOC(
BcastRecv the tensors.
)DOC");
}
};
} }
} }
...@@ -2,6 +2,59 @@ ...@@ -2,6 +2,59 @@
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h" #include "paddle/operators/nccl/nccl_gpu_common.h"
#include <string.h>
namespace paddle { namespace paddle {
namespace operators {} namespace operators {
template<typename Type>
class NCCLTypeWrapper;
template<>
class NCCLTypeWrapper<float> {
static const ncclDataType_t type = ncclFloat;
};
template<>
class NCCLTypeWrapper<double> {
static const ncclDataType_t type = ncclDouble;
};
template<typename T>
class NCCLAllReduceKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<Tensor>("X");
auto outs = ctx.MultiOutput<Tensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t op_type;
if (reduction == "ncclSum") {
op_type = ncclSum;
} else if (reduction == "ncclProd") {
op_type = ncclProd;
} else if (reduction == "ncclMin") {
op_type = ncclMin;
} else (reduction == "ncclMax") {
op_type = ncclMax;
}
auto dev_ctx = ctx.device_context();
for( size_t i=0; i < ins.size(); ++i) {
ncclAllReduce(ins[i]->data<T>(),
outs[i]->mutable_data<T>(),
outs[i]->numel() * sizeof(T),
NCCLTypeWrapper<T>::type,
op_type,
comm,
stream);
}
}
};
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册