提交 f251a58e 编写于 作者: Y Yu Yang

Use base class manage events

上级 1dd216dc
...@@ -68,6 +68,8 @@ struct OpHandle { ...@@ -68,6 +68,8 @@ struct OpHandle {
platform::PlaceHash> platform::PlaceHash>
dev_ctx_; dev_ctx_;
std::unordered_map<int, cudaEvent_t> events_;
std::string DebugString() { std::string DebugString() {
std::stringstream ss; std::stringstream ss;
ss << "("; ss << "(";
...@@ -84,32 +86,57 @@ struct OpHandle { ...@@ -84,32 +86,57 @@ struct OpHandle {
virtual ~OpHandle() {} virtual ~OpHandle() {}
virtual void Run() = 0; void Run() {
virtual void Wait(platform::DeviceContext *waited_dev) = 0; if (events_.empty()) {
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
cudaSetDevice(dev_id);
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming);
}
}
RunImpl();
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto stream =
static_cast<platform::CUDADeviceContext *>(p.second)->stream();
cudaEventRecord(events_.at(dev_id), stream);
}
}
virtual void Wait(platform::DeviceContext *waited_dev) {
if (platform::is_cpu_place(waited_dev->GetPlace())) {
for (auto &dev_ctx : dev_ctx_) {
dev_ctx.second->Wait();
}
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
}
}
}
protected:
virtual void RunImpl() = 0;
}; };
struct ComputationOpHandle : public OpHandle { struct ComputationOpHandle : public OpHandle {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
cudaEvent_t event_;
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope, explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place) platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)), : op_(framework::OpRegistry::CreateOp(op_desc)),
scope_(scope), scope_(scope),
place_(place) { place_(place) {}
if (platform::is_gpu_place(place)) {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
cudaEventCreateWithFlags(&event_, cudaEventDisableTiming);
}
}
~ComputationOpHandle() {
// FIXME: Destroy Event
}
void Run() override { protected:
void RunImpl() override {
// Wait other op if necessary // Wait other op if necessary
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
int dev_id = boost::get<platform::CUDAPlace>(place_).device; int dev_id = boost::get<platform::CUDAPlace>(place_).device;
...@@ -123,22 +150,6 @@ struct ComputationOpHandle : public OpHandle { ...@@ -123,22 +150,6 @@ struct ComputationOpHandle : public OpHandle {
} }
op_->Run(*scope_, place_); op_->Run(*scope_, place_);
if (platform::is_gpu_place(place_)) {
auto stream = static_cast<platform::CUDADeviceContext *>(dev_ctx_[place_])
->stream();
PADDLE_ENFORCE(cudaEventRecord(event_, stream));
}
}
void Wait(platform::DeviceContext *waited_dev) override {
if (platform::is_cpu_place(waited_dev->GetPlace()) ||
platform::is_cpu_place(place_)) {
this->dev_ctx_.at(place_)->Wait();
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, event_, 0));
}
} }
}; };
...@@ -146,7 +157,6 @@ struct ScaleLossGradOpHandle : public OpHandle { ...@@ -146,7 +157,6 @@ struct ScaleLossGradOpHandle : public OpHandle {
float coeff_; float coeff_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
cudaEvent_t ev_;
explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope, explicit ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
platform::Place place) platform::Place place)
...@@ -154,16 +164,14 @@ struct ScaleLossGradOpHandle : public OpHandle { ...@@ -154,16 +164,14 @@ struct ScaleLossGradOpHandle : public OpHandle {
scope_(scope), scope_(scope),
place_(place) { place_(place) {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device); cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
// Must set device before create event
PADDLE_ENFORCE(cudaEventCreateWithFlags(&ev_, cudaEventDisableTiming));
} }
~ScaleLossGradOpHandle() { ~ScaleLossGradOpHandle() {
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device); cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
PADDLE_ENFORCE(cudaEventDestroy(ev_));
} }
void Run() override { protected:
void RunImpl() override {
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_; std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
float *tmp = scope_->FindVar(var_name) float *tmp = scope_->FindVar(var_name)
...@@ -176,20 +184,8 @@ struct ScaleLossGradOpHandle : public OpHandle { ...@@ -176,20 +184,8 @@ struct ScaleLossGradOpHandle : public OpHandle {
auto stream = auto stream =
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_]) static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
->stream(); ->stream();
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp, memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream); platform::CPUPlace(), &coeff_, sizeof(float), stream);
PADDLE_ENFORCE(cudaEventRecord(ev_, stream));
}
}
void Wait(platform::DeviceContext *waited_dev) override {
if (platform::is_cpu_place(waited_dev->GetPlace())) {
dev_ctx_.at(place_)->Wait();
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev_, 0));
} }
} }
}; };
...@@ -216,7 +212,12 @@ struct FetchOpHandle : public OpHandle { ...@@ -216,7 +212,12 @@ struct FetchOpHandle : public OpHandle {
MergeTensors(); MergeTensors();
} }
void Run() override { void Wait(platform::DeviceContext *waited_dev) override {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
}
protected:
void RunImpl() override {
for (auto *input : inputs_) { for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input); auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(this->dev_ctx_[var->place_]); var->generated_op_->Wait(this->dev_ctx_[var->place_]);
...@@ -240,10 +241,6 @@ struct FetchOpHandle : public OpHandle { ...@@ -240,10 +241,6 @@ struct FetchOpHandle : public OpHandle {
} }
} }
void Wait(platform::DeviceContext *waited_dev) override {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
}
private: private:
void MergeTensors() const { void MergeTensors() const {
std::vector<const LoDTensor *> tensors_ptr; std::vector<const LoDTensor *> tensors_ptr;
...@@ -256,8 +253,8 @@ struct FetchOpHandle : public OpHandle { ...@@ -256,8 +253,8 @@ struct FetchOpHandle : public OpHandle {
class ParallelExecutorPrivate { class ParallelExecutorPrivate {
public: public:
explicit ParallelExecutorPrivate(size_t num_threads = 12) explicit ParallelExecutorPrivate(size_t num_threads = 0)
: pool_(num_threads) {} : pool_(num_threads == 0 ? nullptr : new ThreadPool(num_threads)) {}
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
...@@ -333,7 +330,7 @@ class ParallelExecutorPrivate { ...@@ -333,7 +330,7 @@ class ParallelExecutorPrivate {
std::vector<std::unique_ptr<OpHandle>> ops_; std::vector<std::unique_ptr<OpHandle>> ops_;
// Use a simpler thread pool, might be faster. // Use a simpler thread pool, might be faster.
ThreadPool pool_; std::unique_ptr<ThreadPool> pool_;
std::unique_ptr<platform::EnforceNotMet> exception_; std::unique_ptr<platform::EnforceNotMet> exception_;
}; };
...@@ -353,25 +350,12 @@ ncclDataType_t ToNCCLDataType(std::type_index type) { ...@@ -353,25 +350,12 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
struct NCCLAllReduceOpHandle : public OpHandle { struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::unordered_map<int, cudaEvent_t> events_;
explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member) explicit NCCLAllReduceOpHandle(ParallelExecutorPrivate *member)
: member_(member) { : member_(member) {}
for (auto &nccl : member_->communication_streams_) {
int dev_id = nccl.second.device_id();
cudaSetDevice(dev_id);
PADDLE_ENFORCE(cudaEventCreate(&events_[dev_id], cudaEventDisableTiming));
}
}
~NCCLAllReduceOpHandle() { protected:
for (auto &ev : events_) { void RunImpl() override {
cudaSetDevice(ev.first);
PADDLE_ENFORCE(cudaEventDestroy(ev.second));
}
}
void Run() override {
if (this->inputs_.size() == 1) { if (this->inputs_.size() == 1) {
return; // No need to all reduce when GPU count = 1; return; // No need to all reduce when GPU count = 1;
} else { } else {
...@@ -403,34 +387,11 @@ struct NCCLAllReduceOpHandle : public OpHandle { ...@@ -403,34 +387,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
} }
auto &nccl_ctx = member_->communication_streams_.at(dev_id); auto &nccl_ctx = member_->communication_streams_.at(dev_id);
cudaSetDevice(dev_id);
platform::dynload::ncclAllReduce( platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum, buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream()); nccl_ctx.comm, nccl_ctx.stream());
} }
platform::dynload::ncclGroupEnd(); platform::dynload::ncclGroupEnd();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaEventRecord(
ev.second, member_->communication_streams_.at(ev.first).stream()));
}
}
}
void Wait(platform::DeviceContext *waited_dev) override {
if (platform::is_cpu_place(
waited_dev->GetPlace())) { // Wait by CPU, just sync stream
for (auto &pair : member_->communication_streams_) {
pair.second.ctx_->Wait();
}
} else {
if (events_.size() > 1) {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
}
}
} }
} }
}; };
...@@ -851,8 +812,11 @@ void ParallelExecutor::RunOp( ...@@ -851,8 +812,11 @@ void ParallelExecutor::RunOp(
LOG(FATAL) << "Unknown exception catched"; LOG(FATAL) << "Unknown exception catched";
} }
}; };
if (member_->pool_) {
member_->pool_->enqueue(op_run);
} else {
op_run(); op_run();
// member_->pool_.enqueue(op_run); }
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册