提交 9eec2c75 编写于 作者: C chengduoZH

refine pe

上级 f4851f14
......@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() {
out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places.");
// Wait input done, this Wait is asynchronous operation platform::Place
// &in_place;
WaitInputVarGenerated(*in_var_handle);
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
......@@ -147,14 +145,6 @@ void BroadcastOpHandle::RunImpl() {
}
}
void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
if (in_var.generated_op_) {
for (auto &pair : dev_ctxes_) {
in_var.generated_op_->Wait(pair.second);
}
}
}
std::string BroadcastOpHandle::Name() const { return "broadcast"; }
} // namespace details
} // namespace framework
......
......@@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void WaitInputVarGenerated(const VarHandle &in_var);
private:
const std::vector<Scope *> &local_scopes_;
......
......@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
place_(place) {}
void ComputationOpHandle::RunImpl() {
auto *cur_ctx = dev_ctxes_[place_];
for (auto *in : inputs_) {
bool need_wait = in->generated_op_ &&
in->generated_op_->DeviceContext(place_) != cur_ctx;
if (need_wait) {
in->generated_op_->Wait(cur_ctx);
}
}
WaitInputVarGenerated(place_);
this->RunAndRecordEvent([this] {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
});
}
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
dynamic_cast<VarHandle *>(in_var) && in_var->generated_op_ &&
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
return need_wait;
}
std::string ComputationOpHandle::Name() const { return op_->Type(); }
} // namespace details
} // namespace framework
......
......@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
virtual bool NeedWait(VarHandleBase *in_var);
private:
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
......
......@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() {
}
}
void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) {
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
}
......@@ -45,12 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
}
void FetchOpHandle::RunImpl() {
auto cpu_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(cpu_ctx);
}
WaitInputVarGenerated(platform::CPUPlace());
tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var_handle->name_;
......@@ -77,6 +73,15 @@ void FetchOpHandle::RunImpl() {
this->WaitAndMergeCPUTensors();
}
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
for (auto *input : inputs_) {
if (input->generated_op_) {
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
}
}
}
std::string FetchOpHandle::Name() const { return "Fetch"; }
} // namespace details
......
......@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase {
~FetchOpHandle();
void Wait(platform::DeviceContext *waited_dev) override;
void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
void WaitAndMergeCPUTensors() const;
......@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
virtual void WaitInputVarGenerated(const platform::Place &place);
private:
FeedFetchList *data_;
size_t offset_;
......
......@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() {
"Currently, gather_op only can gather SelectedRows.");
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles);
WaitInputVarGenerated();
auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
std::vector<int64_t> out_rows;
......@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() {
});
}
void GatherOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}
std::string GatherOpHandle::Name() const { return "gather"; }
} // namespace details
} // namespace framework
......
......@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
private:
const std::vector<Scope *> &local_scopes_;
......
......@@ -34,10 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
return; // No need to all reduce when GPU count = 1;
} else {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
in->generated_op_->Wait(dev_ctxes_[p]);
}
WaitInputVarGenerated();
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
int dtype = -1;
......
......@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) {
RunImpl();
}
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctxes_) {
dev_ctx.second->Wait();
}
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
static_cast<platform::CUDADeviceContext *>(waited_ctx)->stream();
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
}
......@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
out->generated_op_ = this;
}
void OpHandleBase::WaitInputVarGenerated() {
for (auto in_var : inputs_) {
if (NeedWait(in_var)) {
for (auto &pair : dev_ctxes_) {
in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
}
}
}
}
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) {
if (NeedWait(in)) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
}
}
}
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return dynamic_cast<VarHandle *>(in_var) && in_var->generated_op_;
}
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
......
......@@ -38,12 +38,20 @@ class OpHandleBase {
void Run(bool use_event);
virtual void Wait(platform::DeviceContext *waited_dev);
virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx);
void AddInput(VarHandleBase *in);
void AddOutput(VarHandleBase *out);
// Wait inputs are generated, this Wait is asynchronous operation.
virtual void WaitInputVarGenerated();
// Wait inputs are generated, this Wait is asynchronous operation.
virtual void WaitInputVarGenerated(const platform::Place &place);
virtual bool NeedWait(VarHandleBase *in_var);
// If the Op involves data transfer of multiple devices that
// will likely block other computations.
virtual bool IsMultiDeviceTransfer() { return false; }
......
......@@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() {
PADDLE_ENFORCE_NOT_NULL(pre_in_var);
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles);
WaitInputVarGenerated();
// NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std::vector<platform::Place> in_places; // used to get dev_ctx
......@@ -157,17 +157,6 @@ std::vector<const T *> ReduceOpHandle::GetInputValues(
return in_selected_rows;
}
void ReduceOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}
std::string ReduceOpHandle::Name() const { return "reduce"; }
} // namespace details
} // namespace framework
......
......@@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
template <typename T>
std::vector<const T *> GetInputValues(
const std::vector<VarHandle *> &in_var_handles,
......
......@@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
void ScaleLossGradOpHandle::RunImpl() {
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
......@@ -27,13 +27,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
void SendOpHandle::RunImpl() {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
in->generated_op_->Wait(dev_ctxes_[p]);
}
WaitInputVarGenerated();
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
// lock.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册