提交 9eec2c75 编写于 作者: C chengduoZH

refine pe

上级 f4851f14
...@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() { ...@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() {
out_var_handles.size(), places_.size(), out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places."); "The number of output should equal to the number of places.");
// Wait input done, this Wait is asynchronous operation platform::Place WaitInputVarGenerated();
// &in_place;
WaitInputVarGenerated(*in_var_handle);
std::vector<const Scope *> var_scopes; std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) { for (auto *s : local_scopes_) {
...@@ -147,14 +145,6 @@ void BroadcastOpHandle::RunImpl() { ...@@ -147,14 +145,6 @@ void BroadcastOpHandle::RunImpl() {
} }
} }
void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
if (in_var.generated_op_) {
for (auto &pair : dev_ctxes_) {
in_var.generated_op_->Wait(pair.second);
}
}
}
std::string BroadcastOpHandle::Name() const { return "broadcast"; } std::string BroadcastOpHandle::Name() const { return "broadcast"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
void WaitInputVarGenerated(const VarHandle &in_var);
private: private:
const std::vector<Scope *> &local_scopes_; const std::vector<Scope *> &local_scopes_;
......
...@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, ...@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
place_(place) {} place_(place) {}
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
auto *cur_ctx = dev_ctxes_[place_]; WaitInputVarGenerated(place_);
for (auto *in : inputs_) {
bool need_wait = in->generated_op_ &&
in->generated_op_->DeviceContext(place_) != cur_ctx;
if (need_wait) {
in->generated_op_->Wait(cur_ctx);
}
}
this->RunAndRecordEvent([this] { this->RunAndRecordEvent([this] {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
}); });
} }
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
dynamic_cast<VarHandle *>(in_var) && in_var->generated_op_ &&
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
return need_wait;
}
std::string ComputationOpHandle::Name() const { return op_->Type(); } std::string ComputationOpHandle::Name() const { return op_->Type(); }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
virtual bool NeedWait(VarHandleBase *in_var);
private: private:
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
Scope *scope_; Scope *scope_;
......
...@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() { ...@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() {
} }
} }
void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) { void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error"); PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
} }
...@@ -45,12 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const { ...@@ -45,12 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
} }
void FetchOpHandle::RunImpl() { void FetchOpHandle::RunImpl() {
auto cpu_ctx = WaitInputVarGenerated(platform::CPUPlace());
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(cpu_ctx);
}
tensors_.resize(inputs_.size()); tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]); auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var_handle->name_; auto &var_name = var_handle->name_;
...@@ -77,6 +73,15 @@ void FetchOpHandle::RunImpl() { ...@@ -77,6 +73,15 @@ void FetchOpHandle::RunImpl() {
this->WaitAndMergeCPUTensors(); this->WaitAndMergeCPUTensors();
} }
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
for (auto *input : inputs_) {
if (input->generated_op_) {
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
}
}
}
std::string FetchOpHandle::Name() const { return "Fetch"; } std::string FetchOpHandle::Name() const { return "Fetch"; }
} // namespace details } // namespace details
......
...@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase { ...@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase {
~FetchOpHandle(); ~FetchOpHandle();
void Wait(platform::DeviceContext *waited_dev) override; void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
void WaitAndMergeCPUTensors() const; void WaitAndMergeCPUTensors() const;
...@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase { ...@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
virtual void WaitInputVarGenerated(const platform::Place &place);
private: private:
FeedFetchList *data_; FeedFetchList *data_;
size_t offset_; size_t offset_;
......
...@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() { ...@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() {
"Currently, gather_op only can gather SelectedRows."); "Currently, gather_op only can gather SelectedRows.");
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles); WaitInputVarGenerated();
auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>(); auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
std::vector<int64_t> out_rows; std::vector<int64_t> out_rows;
...@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() { ...@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() {
}); });
} }
void GatherOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}
std::string GatherOpHandle::Name() const { return "gather"; } std::string GatherOpHandle::Name() const { return "gather"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase { ...@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
private: private:
const std::vector<Scope *> &local_scopes_; const std::vector<Scope *> &local_scopes_;
......
...@@ -34,10 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -34,10 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
return; // No need to all reduce when GPU count = 1; return; // No need to all reduce when GPU count = 1;
} else { } else {
// Wait input done // Wait input done
for (auto *in : inputs_) { WaitInputVarGenerated();
auto &p = static_cast<VarHandle *>(in)->place_;
in->generated_op_->Wait(dev_ctxes_[p]);
}
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_; auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
int dtype = -1; int dtype = -1;
......
...@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) { ...@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) {
RunImpl(); RunImpl();
} }
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) { void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) { if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctxes_) { for (auto &dev_ctx : dev_ctxes_) {
dev_ctx.second->Wait(); dev_ctx.second->Wait();
} }
} else { } else {
auto stream = auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream(); static_cast<platform::CUDADeviceContext *>(waited_ctx)->stream();
for (auto &ev : events_) { for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0)); PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
} }
...@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) { ...@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
out->generated_op_ = this; out->generated_op_ = this;
} }
void OpHandleBase::WaitInputVarGenerated() {
for (auto in_var : inputs_) {
if (NeedWait(in_var)) {
for (auto &pair : dev_ctxes_) {
in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
}
}
}
}
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) {
if (NeedWait(in)) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
}
}
}
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return dynamic_cast<VarHandle *>(in_var) && in_var->generated_op_;
}
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) { void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event if (!events_.empty()) { // Use event
......
...@@ -38,12 +38,20 @@ class OpHandleBase { ...@@ -38,12 +38,20 @@ class OpHandleBase {
void Run(bool use_event); void Run(bool use_event);
virtual void Wait(platform::DeviceContext *waited_dev); virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx);
void AddInput(VarHandleBase *in); void AddInput(VarHandleBase *in);
void AddOutput(VarHandleBase *out); void AddOutput(VarHandleBase *out);
// Wait inputs are generated, this Wait is asynchronous operation.
virtual void WaitInputVarGenerated();
// Wait inputs are generated, this Wait is asynchronous operation.
virtual void WaitInputVarGenerated(const platform::Place &place);
virtual bool NeedWait(VarHandleBase *in_var);
// If the Op involves data transfer of multiple devices that // If the Op involves data transfer of multiple devices that
// will likely block other computations. // will likely block other computations.
virtual bool IsMultiDeviceTransfer() { return false; } virtual bool IsMultiDeviceTransfer() { return false; }
......
...@@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() {
PADDLE_ENFORCE_NOT_NULL(pre_in_var); PADDLE_ENFORCE_NOT_NULL(pre_in_var);
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles); WaitInputVarGenerated();
// NOTE: The Places of all input tensor must be all on CPU or all on GPU. // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std::vector<platform::Place> in_places; // used to get dev_ctx std::vector<platform::Place> in_places; // used to get dev_ctx
...@@ -157,17 +157,6 @@ std::vector<const T *> ReduceOpHandle::GetInputValues( ...@@ -157,17 +157,6 @@ std::vector<const T *> ReduceOpHandle::GetInputValues(
return in_selected_rows; return in_selected_rows;
} }
void ReduceOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}
std::string ReduceOpHandle::Name() const { return "reduce"; } std::string ReduceOpHandle::Name() const { return "reduce"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase { ...@@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
template <typename T> template <typename T>
std::vector<const T *> GetInputValues( std::vector<const T *> GetInputValues(
const std::vector<VarHandle *> &in_var_handles, const std::vector<VarHandle *> &in_var_handles,
......
...@@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, ...@@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {} ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
void ScaleLossGradOpHandle::RunImpl() { void ScaleLossGradOpHandle::RunImpl() {
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_; std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
...@@ -27,13 +27,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, ...@@ -27,13 +27,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
void SendOpHandle::RunImpl() { void SendOpHandle::RunImpl() {
// Wait input done // Wait input done
for (auto *in : inputs_) { WaitInputVarGenerated();
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
in->generated_op_->Wait(dev_ctxes_[p]);
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
// lock. // lock.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册