提交 6cce5268 编写于 作者: D Dong Zhihong

"fixed based on comment"

上级 16a39d24
......@@ -290,11 +290,12 @@ class ExecutionContext {
return device_context_;
}
//! Get a input which has multiple variables.
//! Get variables vector with same input name.
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}
//! Get an output which has multiple variables.
//! Get variables vector with same output name.
const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}
......
......@@ -30,6 +30,8 @@
namespace paddle {
namespace platform {
constexpr int kInvalidGPUId = -1;
struct Communicator {
std::vector<ncclComm_t> comms_;
std::unordered_map<int, int> comm_id_map_;
......
......@@ -69,10 +69,10 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputsDim("X");
// std::string reduction = ctx->Attrs().Get<std::string>("reduction");
// PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
// reduction == "ncclMin" || reduction == "ncclMax"),
// "invalid reduction.");
std::string reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
reduction == "ncclMin" || reduction == "ncclMax"),
"invalid reduction.");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
......@@ -115,7 +115,7 @@ class NCCLBcastOp : public framework::OperatorWithKernel {
" Output(Out) of Bcast op output should not be NULL");
int root = ctx->Attrs().Get<int>("root");
PADDLE_ENFORCE(root != -1, "Bcast root must be set.");
PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
......@@ -132,9 +132,9 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of AllReduce op");
// AddAttr<std::string>("reduction",
// "{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}.");
// AddAttr<std::vector<int>>("gpus", "gpu id lists");
AddAttr<std::string>("reduction",
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
.SetDefault("ncclSum");
AddComment(R"DOC(
AllReduce the input tensors.
)DOC");
......@@ -151,8 +151,9 @@ class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Reduce op");
AddAttr<int>("root",
"root gpu of the parameter. if not set(-1). hashed by name.")
.SetDefault(-1);
"root gpu of the parameter. if not "
"set(platform::kInvalidGPUId). hashed by name.")
.SetDefault(platform::kInvalidGPUId);
AddComment(R"DOC(
Reduce the tensors)DOC");
}
......@@ -168,8 +169,9 @@ class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of Bcast");
AddAttr<int>("root",
"root gpu of the parameter. if not set(-1). hashed by name.")
.SetDefault(-1);
"root gpu of the parameter. if not "
"set(platform::kInvalidGPUId). hashed by name.")
.SetDefault(platform::kInvalidGPUId);
AddComment(R"DOC(
Bcast the tensors.
)DOC");
......
......@@ -48,11 +48,28 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<LoDTensor>("X");
auto outs = ctx.MultiOutput<LoDTensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t reduction_op_ = ncclSum;
if (reduction == "ncclMin") {
reduction_op_ = ncclMin;
} else if (reduction == "ncclMax") {
reduction_op_ = ncclMax;
} else if (reduction == "ncclSum") {
reduction_op_ = ncclSum;
} else if (reduction == "ncclProd") {
reduction_op_ = ncclProd;
} else {
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum.");
}
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
......@@ -64,7 +81,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel(), NCCLTypeWrapper<T>::type, ncclSum,
outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
......@@ -98,7 +115,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
auto ins_names = ctx.Inputs("X");
std::hash<std::string> hasher;
for (size_t i = 0; i < ins.size(); ++i) {
if (root == -1) {
if (root == platform::kInvalidGPUId) {
root = hasher(ins_names[i]) % comm->comms_.size();
}
T* recvbuffer = nullptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册