Commit 6cce5268 authored by Dong Zhihong

"fixed based on comment"

Parent 16a39d24
@@ -290,11 +290,12 @@ class ExecutionContext {
     return device_context_;
   }
-  //! Get a input which has multiple variables.
+  //! Get variables vector with same input name.
   const std::vector<std::string>& Inputs(const std::string& name) const {
     return op_.Inputs(name);
   }
-  //! Get an output which has multiple variables.
+  //! Get variables vector with same output name.
   const std::vector<std::string>& Outputs(const std::string& name) const {
     return op_.Outputs(name);
   }
......
@@ -30,6 +30,8 @@
 namespace paddle {
 namespace platform {
+constexpr int kInvalidGPUId = -1;
+
 struct Communicator {
   std::vector<ncclComm_t> comms_;
   std::unordered_map<int, int> comm_id_map_;
......
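As a side note, here is a minimal, self-contained sketch (not part of this commit; the helper name ResolveRoot is hypothetical) of how such a sentinel is typically used. It mirrors the root-selection fallback that NCCLReduceKernel performs further down in this diff.

// Hedged sketch (not in this commit): the named sentinel replaces the magic -1.
#include <functional>
#include <string>

namespace paddle {
namespace platform {
constexpr int kInvalidGPUId = -1;
}  // namespace platform
}  // namespace paddle

// Hypothetical helper: use the "root" attribute if it was set, otherwise fall
// back to hashing the parameter name so every rank picks the same root GPU.
int ResolveRoot(int root, const std::string& param_name, size_t num_gpus) {
  if (root == paddle::platform::kInvalidGPUId) {
    root = static_cast<int>(std::hash<std::string>{}(param_name) % num_gpus);
  }
  return root;
}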
@@ -69,10 +69,10 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputsDim("X");
-    // std::string reduction = ctx->Attrs().Get<std::string>("reduction");
-    // PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
-    //                 reduction == "ncclMin" || reduction == "ncclMax"),
-    //                "invalid reduction.");
+    std::string reduction = ctx->Attrs().Get<std::string>("reduction");
+    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
+                    reduction == "ncclMin" || reduction == "ncclMax"),
+                   "invalid reduction.");
     ctx->SetOutputsDim("Out", x_dims);
     ctx->ShareLoD("X", /*->*/ "Out");
@@ -115,7 +115,7 @@ class NCCLBcastOp : public framework::OperatorWithKernel {
                    " Output(Out) of Bcast op output should not be NULL");
     int root = ctx->Attrs().Get<int>("root");
-    PADDLE_ENFORCE(root != -1, "Bcast root must be set.");
+    PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
     auto x_dims = ctx->GetInputsDim("X");
     ctx->SetOutputsDim("Out", x_dims);
@@ -132,9 +132,9 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The input of AllReduce op");
     AddInput("Communicator", "Communicator for communicating between gpus");
     AddOutput("Out", "The output of AllReduce op");
-    // AddAttr<std::string>("reduction",
-    //                      "{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}.");
-    // AddAttr<std::vector<int>>("gpus", "gpu id lists");
+    AddAttr<std::string>("reduction",
+                         "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
+        .SetDefault("ncclSum");
     AddComment(R"DOC(
         AllReduce the input tensors.
     )DOC");
@@ -151,8 +151,9 @@ class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Communicator", "Communicator for communicating between gpus");
     AddOutput("Out", "The output of Reduce op");
     AddAttr<int>("root",
-                 "root gpu of the parameter. if not set(-1). hashed by name.")
-        .SetDefault(-1);
+                 "root gpu of the parameter. if not "
+                 "set(platform::kInvalidGPUId). hashed by name.")
+        .SetDefault(platform::kInvalidGPUId);
     AddComment(R"DOC(
         Reduce the tensors)DOC");
   }
@@ -168,8 +169,9 @@ class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Communicator", "Communicator for communicating between gpus");
     AddOutput("Out", "The output of Bcast");
     AddAttr<int>("root",
-                 "root gpu of the parameter. if not set(-1). hashed by name.")
-        .SetDefault(-1);
+                 "root gpu of the parameter. if not "
+                 "set(platform::kInvalidGPUId). hashed by name.")
+        .SetDefault(platform::kInvalidGPUId);
     AddComment(R"DOC(
         Bcast the tensors.
     )DOC");
......
@@ -48,11 +48,28 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
     auto ins = ctx.MultiInput<LoDTensor>("X");
     auto outs = ctx.MultiOutput<LoDTensor>("Out");
+    std::string reduction = ctx.Attr<std::string>("reduction");
+    ncclRedOp_t reduction_op_ = ncclSum;
+    if (reduction == "ncclMin") {
+      reduction_op_ = ncclMin;
+    } else if (reduction == "ncclMax") {
+      reduction_op_ = ncclMax;
+    } else if (reduction == "ncclSum") {
+      reduction_op_ = ncclSum;
+    } else if (reduction == "ncclProd") {
+      reduction_op_ = ncclProd;
+    } else {
+      PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum.");
+    }
     auto* comm = ctx.Input<Communicator>("Communicator");
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                       ctx.device_context())
                       .stream();
     // device id
     int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
     int idx = comm->GetCommId(gpu_id);
@@ -64,7 +81,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
           ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
-          outs[i]->numel(), NCCLTypeWrapper<T>::type, ncclSum,
+          outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
           comm->comms_[idx], stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
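For reference, the string-to-ncclRedOp_t mapping used by these two hunks could also be factored into a small standalone helper. A minimal sketch, assuming NCCL's <nccl.h> header; the helper name StrToRedOp is hypothetical and not part of this commit:

// Hedged sketch (not in this commit): map the "reduction" attribute string to
// the NCCL reduction op, defaulting to an error on unknown values.
#include <nccl.h>
#include <stdexcept>
#include <string>

inline ncclRedOp_t StrToRedOp(const std::string& reduction) {
  if (reduction == "ncclMin") return ncclMin;
  if (reduction == "ncclMax") return ncclMax;
  if (reduction == "ncclProd") return ncclProd;
  if (reduction == "ncclSum") return ncclSum;
  throw std::invalid_argument("Invalid reduction: " + reduction);
}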
@@ -98,7 +115,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     auto ins_names = ctx.Inputs("X");
     std::hash<std::string> hasher;
     for (size_t i = 0; i < ins.size(); ++i) {
-      if (root == -1) {
+      if (root == platform::kInvalidGPUId) {
         root = hasher(ins_names[i]) % comm->comms_.size();
       }
       T* recvbuffer = nullptr;
......