Commit fdfc8f9b authored by: D Dong Zhihong

"switch to Init op"

Parent 23cb8259
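This change replaces the implicit NCCLManager lookup inside the AllReduce kernel with an explicit init step: a new NCCLInitOp takes a "gpus" attribute and produces a "Communicator" output, which NCCLAllReduceOp now consumes through a new "Communicator" input. For reference, the snippet below is a minimal, hypothetical sketch (not part of this commit) of how such a communicator group can be built with the stock NCCL/CUDA APIs; the member names mirror the Communicator struct in the hunks that follow, while the CommunicatorSketch type and its InitAll helper are assumed names for illustration.

// Hypothetical sketch (not part of this commit): building a communicator
// group for a list of gpu ids with plain NCCL/CUDA calls. The members mirror
// the Communicator struct below; CommunicatorSketch/InitAll are assumed names.
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

struct CommunicatorSketch {
  std::vector<ncclComm_t> comms_;
  std::vector<cudaStream_t> streams_;
  std::vector<cudaEvent_t> events_;
  std::vector<int> gpus_;

  void InitAll(const std::vector<int>& gpus) {
    gpus_ = gpus;
    comms_.resize(gpus.size());
    streams_.resize(gpus.size());
    events_.resize(gpus.size());
    // one call creates the whole clique of communicators for the gpu list
    ncclCommInitAll(comms_.data(), static_cast<int>(gpus.size()), gpus.data());
    // per-gpu stream and event, created on the owning device
    for (size_t i = 0; i < gpus.size(); ++i) {
      cudaSetDevice(gpus[i]);
      cudaStreamCreate(&streams_[i]);
      cudaEventCreate(&events_[i]);
    }
  }
};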
......@@ -79,7 +79,22 @@ struct Communicator {
streams_.resize(gpus.size());
events_.resize(gpus.size());
}
// Communicator(int num_device): comms_.resize(num_device) {}
~Communicator() {
for (size_t i = 0; i < gpus_.size(); ++i) {
int gid = gpus_[i];
platform::SetDeviceId(gid);
int idx = gid % gpus_.size();
// make the stream wait on its recorded event before releasing resources
PADDLE_ENFORCE(
cudaStreamWaitEvent(streams_[idx], events_[idx], 0));
PADDLE_ENFORCE(cudaEventDestroy(events_[idx]));
PADDLE_ENFORCE(ncclCommDestroy(comms_[idx]));
}
}
inline int get_root_gpu() const { return root_gpu_; }
......
......@@ -14,7 +14,33 @@
namespace paddle {
namespace operators {
// AllreduceOp
// NCCLInitOp
class NCCLInitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Communicator"),
" Output(Communicator) of Init op should not be NULL");
}
};
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLInitOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<std::vector<int>>("gpus", "gpu id lists");
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
AddComment(R"DOC(
create communicator.
)DOC");
}
};
// AllReduceOp
class NCCLAllReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -23,6 +49,9 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
" Input(X) of AllReduce op input should not be NULL");
PADDLE_ENFORCE(
ctx->HasInput("Communicator"),
" Input(Communicator) of AllReduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of AllReduce op input should not be NULL");
......@@ -45,6 +74,7 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of AllReduce op");
AddInput("Communicator", "Communicator for communicating between gpus");
AddOutput("Out", "The output of AllReduce op");
AddAttr<std::string>("reduction",
"{'ncclmin', 'ncclmax', 'ncclprod', 'ncclsum'}.");
......@@ -55,31 +85,31 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
// BcastSendOp
class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLBcastSendOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op");
AddComment(R"DOC(
BcastSend the tensors.
)DOC");
}
};
// // BcastSendOp
// class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
// public:
// NCCLBcastSendOpMaker(framework::OpProto *proto,
// framework::OpAttrChecker *op_checker)
// : OpProtoAndCheckerMaker(proto, op_checker) {
// AddInput("X", "The input of BcastSend op");
// AddComment(R"DOC(
// BcastSend the tensors.
// )DOC");
// }
// };
// BcastRecvOp
class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLBcastRecvOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "The output of BcastRecv op");
AddComment(R"DOC(
BcastRecv the tensors.
)DOC");
}
};
// // BcastRecvOp
// class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
// public:
// NCCLBcastRecvOpMaker(framework::OpProto *proto,
// framework::OpAttrChecker *op_checker)
// : OpProtoAndCheckerMaker(proto, op_checker) {
// AddOutput("Out", "The output of BcastRecv op");
// AddComment(R"DOC(
// BcastRecv the tensors.
// )DOC");
// }
// };
} // namespace operators
} // namespace paddle
......
......@@ -35,6 +35,16 @@ class NCCLTypeWrapper<double> {
static const ncclDataType_t type = ncclDouble;
};
template <typename T>
class NCCLInitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto gpus = ctx.Attr<std::vector<int>>("gpus");
auto* comm = ctx.Output<Communicator>("Communicator");
// fill the Communicator output with the group built for this gpu list
*comm = *platform::NCCLManager::Get()->GetCommunicator(gpus);
}
};
template <typename T>
class NCCLAllReduceKernel : public framework::OpKernel<T> {
public:
......@@ -54,13 +64,15 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
op_type = ncclMax;
}
auto* comm = ctx.Input<Communicator>("Communicator");
auto& dev_ctx =
static_cast<const platform::CUDADeviceContext&>(ctx.device_context());
platform::NCCLManager* m = platform::NCCLManager::Get();
// platform::NCCLManager* m = platform::NCCLManager::Get();
auto* comm = m->GetCommunicator(gpus);
comm->wg_.Add(1);
// auto* comm = m->GetCommunicator(gpus);
// comm->wg_.Add(1);
auto stream = dev_ctx.stream();
......@@ -76,14 +88,14 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
op_type, comm->comms_[idx], comm->streams_[idx]));
PADDLE_ENFORCE(cudaEventRecord(comm->events_[idx], comm->streams_[idx]));
// wait finish
PADDLE_ENFORCE(
cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
// // wait finish
// PADDLE_ENFORCE(
// cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
}
comm->wg_.Done();
// comm->wg_.Done();
comm->wg_.Wait();
// comm->wg_.Wait();
}
};
......
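For context, the per-device all-reduce issued by the kernel above can be sketched on its own with the public NCCL API. Everything in this snippet is a hypothetical illustration rather than part of the commit: the buffer pointers, the fixed ncclSum/ncclFloat choice, and the AllReduceSketch name are assumptions, and the ncclGroupStart/ncclGroupEnd bracketing assumes NCCL 2.0 or newer.

#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

// Issue the same all-reduce on every gpu of the group from a single thread,
// then drain the streams before the results are read. Illustrative only.
void AllReduceSketch(const std::vector<int>& gpus,
                     const std::vector<ncclComm_t>& comms,
                     const std::vector<cudaStream_t>& streams,
                     const std::vector<float*>& sendbufs,
                     const std::vector<float*>& recvbufs, size_t count) {
  ncclGroupStart();
  for (size_t i = 0; i < gpus.size(); ++i) {
    cudaSetDevice(gpus[i]);
    ncclAllReduce(sendbufs[i], recvbufs[i], count, ncclFloat, ncclSum,
                  comms[i], streams[i]);
  }
  ncclGroupEnd();
  // collectives run asynchronously on each stream; synchronize before use
  for (size_t i = 0; i < gpus.size(); ++i) {
    cudaSetDevice(gpus[i]);
    cudaStreamSynchronize(streams[i]);
  }
}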