未验证 提交 14fe40aa 编写于 作者: D dzhwinter 提交者: GitHub

Refine/nccl (#9009)

* "Refine nccl op"

* "refine code "

* "refine nccl code"
上级 788c600e
...@@ -104,19 +104,38 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel { ...@@ -104,19 +104,38 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
" Input(Communicator) of AllReduce op input should not be NULL"); " Input(Communicator) of AllReduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Output(Out) of AllReduce op output should not be NULL"); " Output(Out) of AllReduce op output should not be NULL");
auto x_dims = ctx->GetInputsDim("X");
std::string reduction = ctx->Attrs().Get<std::string>("reduction"); std::string reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"), reduction == "ncclMin" || reduction == "ncclMax"),
"invalid reduction."); "invalid reduction.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims); ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
// AllReduceOp
// Proto maker for the NCCL all-reduce operator: declares its inputs,
// its output, and the "reduction" attribute.
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Tensor to be reduced across the participating GPUs.
    AddInput("X", "The input of AllReduce op");
    // Handle to the initialized NCCL communicator shared by the ranks.
    AddInput("Communicator", "Communicator for communicating between gpus");
    // Receives the reduction result.
    AddOutput("Out", "The output of AllReduce op");
    // Reduction kind; must be one of the four NCCL reduction ops
    // (the op's InferShape rejects any other string).
    AddAttr<std::string>("reduction",
                         "(string, default 'ncclSum') "
                         "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
        .SetDefault("ncclSum");
    AddComment(R"DOC(
NCCLAllReduce Operator.
AllReduce the input tensors.
)DOC");
  }
};
// ReduceOp // ReduceOp
class NCCLReduceOp : public framework::OperatorWithKernel { class NCCLReduceOp : public framework::OperatorWithKernel {
public: public:
...@@ -143,50 +162,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel { ...@@ -143,50 +162,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
} }
}; };
// BcastOp
// Shape inference for the NCCL broadcast operator: verifies that the
// inputs, the communicator and the outputs are present, that a valid
// root GPU id was chosen, and mirrors the input dims/LoD onto "Out".
class NCCLBcastOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {
    // All required inputs and outputs must exist.
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   " Input(X) of Bcast op input should not be NULL");
    PADDLE_ENFORCE(ctx->HasInput("Communicator"),
                   " Input(Communicator) of Bcast op input should not be NULL");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   " Output(Out) of Bcast op output should not be NULL");
    // A broadcast needs an explicit, valid root device.
    const int root_id = ctx->Attrs().Get<int>("root");
    PADDLE_ENFORCE(root_id != platform::kInvalidGPUId,
                   "Bcast root must be set.");
    // Each output tensor has the same shape as its input counterpart.
    const auto in_dims = ctx->GetInputsDim("X");
    ctx->SetOutputsDim("Out", in_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};
// AllReduceOp
// Proto maker for the NCCL all-reduce operator: declares its inputs,
// its output, and the "reduction" attribute.
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Tensor to be reduced across the participating GPUs.
    AddInput("X", "The input of AllReduce op");
    // Handle to the initialized NCCL communicator shared by the ranks.
    AddInput("Communicator", "Communicator for communicating between gpus");
    // Receives the reduction result.
    AddOutput("Out", "The output of AllReduce op");
    // Reduction kind; must be one of the four NCCL reduction ops
    // (the op's InferShape rejects any other string).
    AddAttr<std::string>("reduction",
                         "(string, default 'ncclSum') "
                         "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
        .SetDefault("ncclSum");
    AddComment(R"DOC(
NCCLAllReduce Operator.
AllReduce the input tensors.
)DOC");
  }
};
// ReduceOp // ReduceOp
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
...@@ -213,6 +188,29 @@ Reduce the tensors. ...@@ -213,6 +188,29 @@ Reduce the tensors.
} }
}; };
// BcastOp
class NCCLBcastOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Validates that X, Communicator and Out exist and that a broadcast
  // root GPU id was provided, then copies X's dims and LoD to Out.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   " Input(X) of Bcast op input should not be NULL");
    PADDLE_ENFORCE(ctx->HasInput("Communicator"),
                   " Input(Communicator) of Bcast op input should not be NULL");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   " Output(Out) of Bcast op output should not be NULL");
    // kInvalidGPUId marks "root not chosen"; a broadcast cannot run then.
    int root = ctx->Attrs().Get<int>("root");
    PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
    // Outputs take the shape (and LoD) of the corresponding inputs.
    auto x_dims = ctx->GetInputsDim("X");
    ctx->SetOutputsDim("Out", x_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};
// BcastOp // BcastOp
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
......
...@@ -43,13 +43,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -43,13 +43,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto* x = ctx.Input<LoDTensor>("X");
auto ins = ctx.MultiInput<LoDTensor>("X"); auto* out = ctx.Output<LoDTensor>("Out");
auto outs = ctx.MultiOutput<LoDTensor>("Out"); auto* comm = ctx.Input<Communicator>("Communicator");
std::string reduction = ctx.Attr<std::string>("reduction"); std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") { if (reduction == "ncclMin") {
reduction_op_ = ncclMin; reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") { } else if (reduction == "ncclMax") {
...@@ -61,30 +60,19 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -61,30 +60,19 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_THROW("Invalid reduction. default ncclSum."); PADDLE_THROW("Invalid reduction. default ncclSum.");
} }
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = ctx.cuda_device_context().stream();
// device id // device id
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId(); int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id); int idx = comm->GetCommId(gpu_id);
VLOG(3) << "gpu : "
for (size_t i = 0; i < ins.size(); ++i) { << " invoke allreduce. send " << x->numel() << " recv "
VLOG(1) << "gpu : " << out->numel();
<< " invoke allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclAllReduce( PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()), x->data<T>(), out->mutable_data<T>(ctx.GetPlace()), out->numel(),
outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_, NCCLTypeWrapper<T>::type, reduction_op_, comm->comms().at(idx),
comm->comms().at(idx), stream)); ctx.cuda_device_context().stream()));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); VLOG(3) << "gpu : "
<< " finished allreduce. send " << x->numel() << " recv "
VLOG(1) << "gpu : " << out->numel();
<< " finished allreduce. send " << ins[i]->numel() << " recv "
<< outs[i]->numel();
}
} }
}; };
...@@ -94,13 +82,13 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -94,13 +82,13 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto x = ctx.Input<LoDTensor>("X"); // x0, x1, x2
auto ins = ctx.MultiInput<LoDTensor>("X"); // x0, x1, x2 auto out = ctx.Output<LoDTensor>("Out");
auto outs = ctx.MultiOutput<LoDTensor>("Out"); auto* comm = ctx.Input<Communicator>("Communicator");
int root = ctx.Attr<int>("root");
std::string reduction = ctx.Attr<std::string>("reduction"); std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") { if (reduction == "ncclMin") {
reduction_op_ = ncclMin; reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") { } else if (reduction == "ncclMax") {
...@@ -112,40 +100,21 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -112,40 +100,21 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
} else { } else {
PADDLE_THROW("Invalid reduction. default ncclSum."); PADDLE_THROW("Invalid reduction. default ncclSum.");
} }
int root = ctx.Attr<int>("root");
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id // device id
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId(); int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id); int idx = comm->GetCommId(gpu_id);
auto ins_names = ctx.Inputs("X");
std::hash<std::string> hasher;
for (size_t i = 0; i < ins.size(); ++i) {
if (root == platform::kInvalidGPUId) {
root = hasher(ins_names[i]) % comm->comms().size();
}
T* recvbuffer = nullptr; T* recvbuffer = nullptr;
if (root == gpu_id) { if (root == gpu_id) {
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace()); recvbuffer = out->mutable_data<T>(ctx.GetPlace());
} }
VLOG(3) << "gpu : " << gpu_id << " invoke reduce. send " << x->numel()
VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send " << " recv " << out->numel();
<< ins[i]->numel() << " recv " << outs[i]->numel();
PADDLE_ENFORCE(platform::dynload::ncclReduce( PADDLE_ENFORCE(platform::dynload::ncclReduce(
ins[i]->data<T>(), recvbuffer, ins[i]->numel(), x->data<T>(), recvbuffer, x->numel(), NCCLTypeWrapper<T>::type,
NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx), reduction_op_, root, comm->comms().at(idx),
stream)); ctx.cuda_device_context().stream()));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); VLOG(3) << "gpu : " << gpu_id << " finished reduce. send " << x->numel()
<< " recv " << out->numel();
VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
<< ins[i]->numel() << " recv " << outs[i]->numel();
}
} }
}; };
...@@ -155,47 +124,27 @@ class NCCLBcastKernel : public framework::OpKernel<T> { ...@@ -155,47 +124,27 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
int root = ctx.Attr<int>("root"); int root = ctx.Attr<int>("root");
auto* comm = ctx.Input<Communicator>("Communicator"); auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id // device id
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId(); int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id); int idx = comm->GetCommId(gpu_id);
if (idx == root) { if (idx == root) {
auto ins = ctx.MultiInput<LoDTensor>("X"); auto* x = ctx.Input<LoDTensor>("X");
for (size_t i = 0; i < ins.size(); ++i) { VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. send " << x->numel();
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send "
<< ins[i]->numel();
VLOG(1) << " before ncclBcast";
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE(platform::dynload::ncclBcast(
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type, (void*)x->data<T>(), x->numel(), NCCLTypeWrapper<T>::type, root,
root, comm->comms().at(idx), stream)); comm->comms().at(idx), ctx.cuda_device_context().stream()));
VLOG(1) << " after ncclBcast"; VLOG(3) << "gpu : " << gpu_id << " finished Bcast.";
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished Bcast.";
}
} else { } else {
auto outs = ctx.MultiOutput<LoDTensor>("Out"); auto* out = ctx.Output<LoDTensor>("Out");
for (size_t i = 0; i < outs.size(); ++i) { VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. recv buffer "
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer " << framework::product(out->dims());
<< framework::product(outs[i]->dims());
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE(platform::dynload::ncclBcast(
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(), out->mutable_data<T>(ctx.GetPlace()), out->numel(),
NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream)); NCCLTypeWrapper<T>::type, root, comm->comms().at(idx),
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); ctx.cuda_device_context().stream()));
VLOG(3) << "gpu : " << gpu_id << " finished Bcast. recv " << out->numel();
VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
<< outs[i]->numel();
}
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册