提交 52200523 编写于 作者: D Dong Zhihong

"polish code based on comment"

上级 6cce5268
...@@ -94,6 +94,11 @@ class NCCLReduceOp : public framework::OperatorWithKernel { ...@@ -94,6 +94,11 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of Reduce op input should not be NULL"); " Input(X) of Reduce op input should not be NULL");
std::string reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"),
"invalid reduction.");
auto x_dims = ctx->GetInputsDim("X"); auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims); ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
...@@ -150,6 +155,9 @@ class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -150,6 +155,9 @@ class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The input of Reduce op"); AddInput("X", "The input of Reduce op");
AddInput("Communicator", "Communicator for communicating between gpus"); AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Reduce op"); AddOutput("Out", "The output of Reduce op");
AddAttr<std::string>("reduction",
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
.SetDefault("ncclSum");
AddAttr<int>("root", AddAttr<int>("root",
"root gpu of the parameter. if not " "root gpu of the parameter. if not "
"set(platform::kInvalidGPUId). hashed by name.") "set(platform::kInvalidGPUId). hashed by name.")
......
...@@ -49,7 +49,6 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -49,7 +49,6 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
auto outs = ctx.MultiOutput<LoDTensor>("Out"); auto outs = ctx.MultiOutput<LoDTensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction"); std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum; ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") { if (reduction == "ncclMin") {
...@@ -101,8 +100,23 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -101,8 +100,23 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<LoDTensor>("X"); // x0, x1, x2 auto ins = ctx.MultiInput<LoDTensor>("X"); // x0, x1, x2
auto outs = ctx.MultiOutput<LoDTensor>("Out"); auto outs = ctx.MultiOutput<LoDTensor>("Out");
int root = ctx.Attr<int>("root");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") {
reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") {
reduction_op_ = ncclMax;
} else if (reduction == "ncclSum") {
reduction_op_ = ncclSum;
} else if (reduction == "ncclProd") {
reduction_op_ = ncclProd;
} else {
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum.");
}
int root = ctx.Attr<int>("root");
auto* comm = ctx.Input<Communicator>("Communicator"); auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>( auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
...@@ -128,7 +142,8 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -128,7 +142,8 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(platform::dynload::ncclReduce( PADDLE_ENFORCE(platform::dynload::ncclReduce(
ins[i]->data<T>(), recvbuffer, ins[i]->numel(), ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, ncclSum, root, comm->comms_[idx], stream)); NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished reduce. send " VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册