diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc
index 4f3a2f2768f8cc1d1f257977d806ed9b7fa8b8cb..3744d1b4707be68185565d8e94f2a165ec42739b 100644
--- a/paddle/operators/nccl_op.cc
+++ b/paddle/operators/nccl_op.cc
@@ -94,6 +94,11 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    " Input(X) of Reduce op input should not be NULL");
 
+    std::string reduction = ctx->Attrs().Get<std::string>("reduction");
+    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
+                    reduction == "ncclMin" || reduction == "ncclMax"),
+                   "invalid reduction.");
+
     auto x_dims = ctx->GetInputsDim("X");
     ctx->SetOutputsDim("Out", x_dims);
     ctx->ShareLoD("X", /*->*/ "Out");
@@ -150,6 +155,9 @@ class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The input of Reduce op");
     AddInput("Communicator", "Communicator for communicating between gpus");
     AddOutput("Out", "The output of Reduce op");
+    AddAttr<std::string>("reduction",
+                         "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
+        .SetDefault("ncclSum");
     AddAttr<int>("root",
                  "root gpu of the parameter. if not "
                  "set(platform::kInvalidGPUId). hashed by name.")
diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu
index cc01db80ca1a14d1dce8f33ee766b402794d73b0..f8b3b8a8baf77a1091629969b56ace826cdbd371 100644
--- a/paddle/operators/nccl_op.cu
+++ b/paddle/operators/nccl_op.cu
@@ -49,7 +49,6 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
     auto outs = ctx.MultiOutput<LoDTensor>("Out");
 
     std::string reduction = ctx.Attr<std::string>("reduction");
-
     ncclRedOp_t reduction_op_ = ncclSum;
 
     if (reduction == "ncclMin") {
@@ -101,8 +100,23 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
 
     auto ins = ctx.MultiInput<LoDTensor>("X");  // x0, x1, x2
     auto outs = ctx.MultiOutput<LoDTensor>("Out");
 
-    int root = ctx.Attr<int>("root");
+    std::string reduction = ctx.Attr<std::string>("reduction");
+    ncclRedOp_t reduction_op_ = ncclSum;
+
+    if (reduction == "ncclMin") {
+      reduction_op_ = ncclMin;
+    } else if (reduction == "ncclMax") {
+      reduction_op_ = ncclMax;
+    } else if (reduction == "ncclSum") {
+      reduction_op_ = ncclSum;
+    } else if (reduction == "ncclProd") {
+      reduction_op_ = ncclProd;
+    } else {
+      PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum.");
+    }
+
+    int root = ctx.Attr<int>("root");
     auto* comm = ctx.Input<Communicator>("Communicator");
 
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
@@ -128,7 +142,8 @@
 
       PADDLE_ENFORCE(platform::dynload::ncclReduce(
           ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
-          NCCLTypeWrapper<T>::type, ncclSum, root, comm->comms_[idx], stream));
+          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
+          stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
       VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
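
For readers skimming the patch: both kernels now translate the new "reduction" string attribute into an ncclRedOp_t before launching the collective, and the attribute is validated once more at InferShape time. Below is a minimal standalone sketch of that dispatch, in case it is easier to read outside diff context. The helper name ToNcclRedOp and the use of std::invalid_argument in place of PADDLE_ENFORCE are illustrative assumptions, not part of this patch, which inlines the if/else chain in each kernel.

#include <nccl.h>  // ncclRedOp_t, ncclMin, ncclMax, ncclSum, ncclProd
#include <stdexcept>
#include <string>

// Map the "reduction" attribute string to the NCCL reduction enum.
// Unknown strings are rejected, mirroring the PADDLE_ENFORCE(false, ...)
// fallback branch added to NCCLAllReduceKernel and NCCLReduceKernel.
inline ncclRedOp_t ToNcclRedOp(const std::string& reduction) {
  if (reduction == "ncclMin") return ncclMin;
  if (reduction == "ncclMax") return ncclMax;
  if (reduction == "ncclSum") return ncclSum;
  if (reduction == "ncclProd") return ncclProd;
  throw std::invalid_argument("invalid reduction: " + reduction);
}

Since the attribute defaults to "ncclSum" via .SetDefault("ncclSum") and is checked again in NCCLReduceOp::InferShape, the fallback branch in the kernels should be unreachable for ops built through the op maker; it only guards programs that set the attribute by hand.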