Commit f251a58e authored by Yu Yang

Use base class to manage events

Parent 1dd216dc
@@ -68,6 +68,8 @@ struct OpHandle {
platform::PlaceHash>
dev_ctx_;
std::unordered_map<int, cudaEvent_t> events_;
std::string DebugString() {
std::stringstream ss;
ss << "(";
@@ -84,32 +86,57 @@ struct OpHandle {
virtual ~OpHandle() {}
virtual void Run() = 0;
virtual void Wait(platform::DeviceContext *waited_dev) = 0;
void Run() {
if (events_.empty()) {
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
cudaSetDevice(dev_id);
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming);
}
}
RunImpl();
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto stream =
static_cast<platform::CUDADeviceContext *>(p.second)->stream();
cudaEventRecord(events_.at(dev_id), stream);
}
}
virtual void Wait(platform::DeviceContext *waited_dev) {
if (platform::is_cpu_place(waited_dev->GetPlace())) {
for (auto &dev_ctx : dev_ctx_) {
dev_ctx.second->Wait();
}
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
}
}
}
protected:
virtual void RunImpl() = 0;
};
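The refactor above is the template-method pattern: OpHandle::Run() lazily creates one cudaEvent_t per device on first use, delegates the real work to the protected RunImpl() hook, and then records the events on every device stream, so subclasses no longer carry their own event plumbing. A minimal standalone sketch of that shape (illustrative names only, CUDA calls replaced by prints; not the real classes):

#include <iostream>

struct HandleBase {
  // Fixed bookkeeping wrapped around the virtual hook.
  void Run() {
    if (!events_ready_) {  // mirrors the events_.empty() lazy-init check
      std::cout << "create per-device events once\n";
      events_ready_ = true;
    }
    RunImpl();  // subclass-specific work
    std::cout << "record an event on each stream\n";  // mirrors cudaEventRecord
  }
  virtual ~HandleBase() = default;

 protected:
  virtual void RunImpl() = 0;

 private:
  bool events_ready_ = false;
};

struct ComputeHandle : public HandleBase {
 protected:
  void RunImpl() override { std::cout << "run the operator\n"; }
};

int main() {
  ComputeHandle h;
  h.Run();  // the base class brackets the hook with event bookkeeping
  return 0;
}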
struct ComputationOpHandle : public OpHandle {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
cudaEvent_t event_;
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
scope_(scope),
place_(place) {
if (platform::is_gpu_place(place)) {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
cudaEventCreateWithFlags(&event_, cudaEventDisableTiming);
}
}
~ComputationOpHandle() {
// FIXME: Destroy Event
}
place_(place) {}
void Run() override {
protected:
void RunImpl() override {
// Wait for the ops that produce our inputs, if necessary
if (platform::is_gpu_place(place_)) {
int dev_id = boost::get<platform::CUDAPlace>(place_).device;
@@ -123,22 +150,6 @@ struct ComputationOpHandle : public OpHandle {
}
op_->Run(*scope_, place_);
if (platform::is_gpu_place(place_)) {
auto stream = static_cast<platform::CUDADeviceContext *>(dev_ctx_[place_])
->stream();
PADDLE_ENFORCE(cudaEventRecord(event_, stream));
}
}
void Wait(platform::DeviceContext *waited_dev) override {
if (platform::is_cpu_place(waited_dev->GetPlace()) ||
platform::is_cpu_place(place_)) {
this->dev_ctx_.at(place_)->Wait();
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, event_, 0));
}
}
};
@@ -146,7 +157,6 @@ struct ScaleLossGradOpHandle : public OpHandle {
float coeff_;
Scope *scope_;
platform::Place place_;
cudaEvent_t ev_;
explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
platform::Place place)
@@ -154,16 +164,14 @@ struct ScaleLossGradOpHandle : public OpHandle {
scope_(scope),
place_(place) {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
// Must set the device before creating the event
PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming));
}
~ScaleLossGradOpHandle() {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaEventDestroy(ev_));
}
void Run() override {
protected:
void RunImpl() override {
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
float *tmp = scope_->FindVar(var_name)
@@ -176,20 +184,8 @@ struct ScaleLossGradOpHandle : public OpHandle {
auto stream =
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
->stream();
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream);
PADDLE_ENFORCE(cudaEventRecord(ev_, stream));
}
}
void Wait(platform::DeviceContext *waited_dev) override {
if (platform::is_cpu_place(waited_dev->GetPlace())) {
dev_ctx_.at(place_)->Wait();
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev_, 0));
}
}
};
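ScaleLossGradOpHandle::RunImpl now only stages the scaling coefficient on the device; the cudaEventRecord that used to follow the copy is done once in the base class. A self-contained sketch of that asynchronous host-to-device copy using the plain CUDA runtime (illustrative values; error checks elided; the host buffer is pageable, so the copy may synchronize in practice):

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  float coeff = 1.0f;  // e.g. 1.0f / num_dev in the real op
  float *dst = nullptr;
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMalloc(&dst, sizeof(float));
  // Copy on the op's stream; downstream work is ordered behind it by the
  // event the base class records on this same stream after RunImpl returns.
  cudaMemcpyAsync(dst, &coeff, sizeof(float), cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);  // only so this toy program can exit safely
  cudaFree(dst);
  cudaStreamDestroy(stream);
  printf("coeff staged on device\n");
  return 0;
}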
@@ -216,7 +212,12 @@ struct FetchOpHandle : public OpHandle {
MergeTensors();
}
void Run() override {
void Wait(platform::DeviceContext *waited_dev) override {
PADDLE_THROW("Nobody should wait on FetchOp. Unexpected Error");
}
protected:
void RunImpl() override {
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(this->dev_ctx_[var->place_]);
@@ -240,10 +241,6 @@ struct FetchOpHandle : public OpHandle {
}
}
void Wait(platform::DeviceContext *waited_dev) override {
PADDLE_THROW("Nobody should wait on FetchOp. Unexpected Error");
}
private:
void MergeTensors() const {
std::vector<const LoDTensor *> tensors_ptr;
@@ -256,8 +253,8 @@ class ParallelExecutorPrivate {
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(size_t num_threads = 12)
: pool_(num_threads) {}
explicit ParallelExecutorPrivate(size_t num_threads = 0)
: pool_(num_threads == 0 ? nullptr : new ThreadPool(num_threads)) {}
std::vector<platform::Place> places_;
@@ -333,7 +330,7 @@ class ParallelExecutorPrivate {
std::vector<std::unique_ptr<OpHandle>> ops_;
// Use a simpler thread pool, might be faster.
ThreadPool pool_;
std::unique_ptr<ThreadPool> pool_;
std::unique_ptr<platform::EnforceNotMet> exception_;
};
@@ -353,25 +350,12 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_;
std::unordered_map<int, cudaEvent_t> events_;
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
: member_(member) {
for (auto &nccl : member_->communication_streams_) {
int dev_id = nccl.second.device_id();
cudaSetDevice(dev_id);
PADDLE_ENFORCE(cudaEventCreate(&events_[dev_id], cudaEventDisableTiming));
}
}
: member_(member) {}
~NCCLAllReduceOpHandle() {
for (auto &ev : events_) {
cudaSetDevice(ev.first);
PADDLE_ENFORCE(cudaEventDestroy(ev.second));
}
}
void Run() override {
protected:
void RunImpl() override {
if (this->inputs_.size() == 1) {
return;  // No need to all-reduce when there is only one GPU
} else {
@@ -403,34 +387,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
}
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
cudaSetDevice(dev_id);
platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream());
}
platform::dynload::ncclGroupEnd();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaEventRecord(
ev.second, member_->communication_streams_.at(ev.first).stream()));
}
}
}
void Wait(platform::DeviceContext *waited_dev) override {
if (platform::is_cpu_place(
waited_dev->GetPlace())) { // Wait by CPU, just sync stream
for (auto &pair : member_->communication_streams_) {
pair.second.ctx_->Wait();
}
} else {
if (events_.size() > 1) {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
}
}
}
}
};
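The body above follows the standard grouped NCCL pattern: one in-place ncclAllReduce per device between ncclGroupStart() and ncclGroupEnd(), each on its own communicator and stream; the per-op event recording that used to follow the group is now handled by OpHandle::Run(). A standalone single-process sketch with the stock NCCL API (error checks elided):

#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

int main() {
  int ndev = 0;
  cudaGetDeviceCount(&ndev);
  const size_t numel = 1024;
  std::vector<ncclComm_t> comms(ndev);
  std::vector<cudaStream_t> streams(ndev);
  std::vector<float *> bufs(ndev);
  ncclCommInitAll(comms.data(), ndev, nullptr);  // one comm per visible GPU
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaStreamCreate(&streams[i]);
    cudaMalloc(&bufs[i], numel * sizeof(float));
    cudaMemset(bufs[i], 0, numel * sizeof(float));
  }
  // Group the calls so NCCL launches them as a single collective.
  ncclGroupStart();
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    ncclAllReduce(bufs[i], bufs[i], numel, ncclFloat, ncclSum, comms[i],
                  streams[i]);
  }
  ncclGroupEnd();
  for (int i = 0; i < ndev; ++i) {
    cudaSetDevice(i);
    cudaStreamSynchronize(streams[i]);
    cudaFree(bufs[i]);
    cudaStreamDestroy(streams[i]);
    ncclCommDestroy(comms[i]);
  }
  return 0;
}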
@@ -851,8 +812,11 @@ void ParallelExecutor::RunOp(
LOG(FATAL) << "Unknown exception caught";
}
};
op_run();
// member_->pool_.enqueue(op_run);
if (member_->pool_) {
member_->pool_->enqueue(op_run);
} else {
op_run();
}
}
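With pool_ now a std::unique_ptr<ThreadPool>, num_threads == 0 means no pool is created and every op runs inline on the caller's thread (useful for debugging); otherwise ops are enqueued. A minimal sketch of that dispatch policy (this ThreadPool is a stand-in with an enqueue() shaped like the one used above, not Paddle's class):

#include <functional>
#include <iostream>
#include <memory>

// Stand-in pool: a real implementation would hand fn to a worker thread.
struct ThreadPool {
  explicit ThreadPool(size_t /*num_threads*/) {}
  void enqueue(std::function<void()> fn) { fn(); }
};

void RunOp(const std::unique_ptr<ThreadPool> &pool,
           std::function<void()> op_run) {
  if (pool) {
    pool->enqueue(std::move(op_run));  // asynchronous path
  } else {
    op_run();  // synchronous fallback when num_threads == 0
  }
}

int main() {
  std::unique_ptr<ThreadPool> no_pool;  // the num_threads == 0 case
  RunOp(no_pool, [] { std::cout << "ran inline\n"; });
  std::unique_ptr<ThreadPool> pool(new ThreadPool(4));
  RunOp(pool, [] { std::cout << "ran via pool\n"; });
  return 0;
}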
} // namespace framework
} // namespace paddle