未验证 提交 ce72c3ff 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #10476 from chengduoZH/refine_parallel_exe

Clean Parallel exe
...@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() { ...@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() {
out_var_handles.size(), places_.size(), out_var_handles.size(), places_.size(),
"The number of output should equal to the number of places."); "The number of output should equal to the number of places.");
// Wait input done, this Wait is asynchronous operation platform::Place WaitInputVarGenerated();
// &in_place;
WaitInputVarGenerated(*in_var_handle);
std::vector<const Scope *> var_scopes; std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) { for (auto *s : local_scopes_) {
...@@ -50,29 +48,9 @@ void BroadcastOpHandle::RunImpl() { ...@@ -50,29 +48,9 @@ void BroadcastOpHandle::RunImpl() {
auto *in_var = auto *in_var =
var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_); var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
// NOTE: The tensors' Place of input and output must be all on GPU or all on InitOutputValue(*in_var_handle, out_var_handles);
// CPU.
for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
continue;
}
auto t_out_p = out_var_handle->place_;
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var);
if (platform::is_gpu_place(in_tensor.place())) {
PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
"Places of input and output must be all on GPU.");
} else {
t_out_p = platform::CPUPlace();
}
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
in_tensor.type());
}
if (platform::is_cpu_place(in_tensor.place())) { if (platform::is_cpu_place(in_tensor.place())) {
for (auto *out_var_handle : out_var_handles) { for (auto *out_var_handle : out_var_handles) {
...@@ -147,11 +125,37 @@ void BroadcastOpHandle::RunImpl() { ...@@ -147,11 +125,37 @@ void BroadcastOpHandle::RunImpl() {
} }
} }
void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) { void BroadcastOpHandle::InitOutputValue(
if (in_var.generated_op_) { const VarHandle &in_var_handle,
for (auto &pair : dev_ctxes_) { const std::vector<VarHandle *> &out_var_handles) const {
in_var.generated_op_->Wait(pair.second); std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto *in_var =
var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
// NOTE: The tensors' Place of input and output must be all on GPU or all on
// CPU.
for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue;
}
auto t_out_p = out_var_handle->place_;
auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
->FindVar(out_var_handle->name_);
PADDLE_ENFORCE_NOT_NULL(out_var);
if (is_gpu_place(in_tensor.place())) {
PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
"Places of input and output must be all on GPU.");
} else {
t_out_p = platform::CPUPlace();
} }
VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
in_tensor.type());
} }
} }
......
...@@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
void WaitInputVarGenerated(const VarHandle &in_var);
private: private:
const std::vector<Scope *> &local_scopes_; const std::vector<Scope *> &local_scopes_;
...@@ -65,6 +64,9 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -65,6 +64,9 @@ struct BroadcastOpHandle : public OpHandleBase {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_; const platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
void InitOutputValue(const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles) const;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, ...@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
place_(place) {} place_(place) {}
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
auto *cur_ctx = dev_ctxes_[place_]; WaitInputVarGenerated(place_);
for (auto *in : inputs_) {
bool need_wait = in->generated_op_ &&
in->generated_op_->DeviceContext(place_) != cur_ctx;
if (need_wait) {
in->generated_op_->Wait(cur_ctx);
}
}
this->RunAndRecordEvent([this] { this->RunAndRecordEvent([this] {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
}); });
} }
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
in_var && in_var->generated_op_ &&
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
return need_wait;
}
std::string ComputationOpHandle::Name() const { return op_->Type(); } std::string ComputationOpHandle::Name() const { return op_->Type(); }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
virtual bool NeedWait(VarHandleBase *in_var);
private: private:
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
Scope *scope_; Scope *scope_;
......
...@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() { ...@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() {
} }
} }
void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) { void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error"); PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
} }
...@@ -45,14 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const { ...@@ -45,14 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
} }
void FetchOpHandle::RunImpl() { void FetchOpHandle::RunImpl() {
auto cpu_ctx = WaitInputVarGenerated(platform::CPUPlace());
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
if (var->generated_op_) {
var->generated_op_->Wait(cpu_ctx);
}
}
tensors_.resize(inputs_.size()); tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]); auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var_handle->name_; auto &var_name = var_handle->name_;
...@@ -79,6 +73,15 @@ void FetchOpHandle::RunImpl() { ...@@ -79,6 +73,15 @@ void FetchOpHandle::RunImpl() {
this->WaitAndMergeCPUTensors(); this->WaitAndMergeCPUTensors();
} }
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
for (auto *input : inputs_) {
if (input->generated_op_) {
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
}
}
}
std::string FetchOpHandle::Name() const { return "Fetch"; } std::string FetchOpHandle::Name() const { return "Fetch"; }
} // namespace details } // namespace details
......
...@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase { ...@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase {
~FetchOpHandle(); ~FetchOpHandle();
void Wait(platform::DeviceContext *waited_dev) override; void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
void WaitAndMergeCPUTensors() const; void WaitAndMergeCPUTensors() const;
...@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase { ...@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
virtual void WaitInputVarGenerated(const platform::Place &place);
private: private:
FeedFetchList *data_; FeedFetchList *data_;
size_t offset_; size_t offset_;
......
...@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() { ...@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() {
"Currently, gather_op only can gather SelectedRows."); "Currently, gather_op only can gather SelectedRows.");
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles); WaitInputVarGenerated();
auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>(); auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
std::vector<int64_t> out_rows; std::vector<int64_t> out_rows;
...@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() { ...@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() {
}); });
} }
void GatherOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}
std::string GatherOpHandle::Name() const { return "gather"; } std::string GatherOpHandle::Name() const { return "gather"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase { ...@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
private: private:
const std::vector<Scope *> &local_scopes_; const std::vector<Scope *> &local_scopes_;
......
...@@ -34,12 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -34,12 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
return; // No need to all reduce when GPU count = 1; return; // No need to all reduce when GPU count = 1;
} else { } else {
// Wait input done // Wait input done
for (auto *in : inputs_) { WaitInputVarGenerated();
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]);
}
}
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_; auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
int dtype = -1; int dtype = -1;
......
...@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) { ...@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) {
RunImpl(); RunImpl();
} }
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) { void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) { if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
for (auto &dev_ctx : dev_ctxes_) { for (auto &dev_ctx : dev_ctxes_) {
dev_ctx.second->Wait(); dev_ctx.second->Wait();
} }
} else { } else {
auto stream = auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream(); static_cast<platform::CUDADeviceContext *>(waited_ctx)->stream();
for (auto &ev : events_) { for (auto &ev : events_) {
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0)); PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
} }
...@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) { ...@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
out->generated_op_ = this; out->generated_op_ = this;
} }
void OpHandleBase::WaitInputVarGenerated() {
for (auto in_var : inputs_) {
if (NeedWait(in_var)) {
for (auto &pair : dev_ctxes_) {
in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
}
}
}
}
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) {
if (NeedWait(in)) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
}
}
}
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
return in_var && in_var->generated_op_;
}
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) { void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event if (!events_.empty()) { // Use event
......
...@@ -38,12 +38,24 @@ class OpHandleBase { ...@@ -38,12 +38,24 @@ class OpHandleBase {
void Run(bool use_event); void Run(bool use_event);
virtual void Wait(platform::DeviceContext *waited_dev); virtual void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx);
void AddInput(VarHandleBase *in); void AddInput(VarHandleBase *in);
void AddOutput(VarHandleBase *out); void AddOutput(VarHandleBase *out);
// This method adds the wait events of all the input on all the device
// context.
// NODE: This Wait is asynchronous operation.
virtual void WaitInputVarGenerated();
// This method adds the wait events of all the input on the specified device
// context.
// NODE: This Wait is asynchronous operation.
virtual void WaitInputVarGenerated(const platform::Place &place);
virtual bool NeedWait(VarHandleBase *in_var);
// If the Op involves data transfer of multiple devices that // If the Op involves data transfer of multiple devices that
// will likely block other computations. // will likely block other computations.
virtual bool IsMultiDeviceTransfer() { return false; } virtual bool IsMultiDeviceTransfer() { return false; }
......
...@@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -51,7 +51,7 @@ void ReduceOpHandle::RunImpl() {
PADDLE_ENFORCE_NOT_NULL(pre_in_var); PADDLE_ENFORCE_NOT_NULL(pre_in_var);
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated(in_var_handles); WaitInputVarGenerated();
// NOTE: The Places of all input tensor must be all on CPU or all on GPU. // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std::vector<platform::Place> in_places; // used to get dev_ctx std::vector<platform::Place> in_places; // used to get dev_ctx
...@@ -80,19 +80,21 @@ void ReduceOpHandle::RunImpl() { ...@@ -80,19 +80,21 @@ void ReduceOpHandle::RunImpl() {
} }
if (pre_in_var->IsType<framework::SelectedRows>()) { if (pre_in_var->IsType<framework::SelectedRows>()) {
this->RunAndRecordEvent([&] {
std::vector<const SelectedRows *> in_selected_rows = std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes); GetInputValues<SelectedRows>(in_var_handles, var_scopes);
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p, GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
out_var->GetMutable<framework::SelectedRows>()); out_var->GetMutable<framework::SelectedRows>());
});
} else { } else {
std::vector<const LoDTensor *> lod_tensors = std::vector<const LoDTensor *> lod_tensors =
GetInputValues<LoDTensor>(in_var_handles, var_scopes); GetInputValues<LoDTensor>(in_var_handles, var_scopes);
if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) { if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) {
this->RunAndRecordEvent([&] {
ReduceLoDTensor func(lod_tensors, ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>()); out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(ToDataType(lod_tensors[0]->type()), func);
});
} else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) { } else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto pre_in = pre_in_var->Get<framework::LoDTensor>(); auto pre_in = pre_in_var->Get<framework::LoDTensor>();
...@@ -157,17 +159,6 @@ std::vector<const T *> ReduceOpHandle::GetInputValues( ...@@ -157,17 +159,6 @@ std::vector<const T *> ReduceOpHandle::GetInputValues(
return in_selected_rows; return in_selected_rows;
} }
void ReduceOpHandle::WaitInputVarGenerated(
const std::vector<VarHandle *> &in_var_handles) {
for (auto *in : in_var_handles) {
if (in->generated_op_) {
for (auto pair : dev_ctxes_) {
in->generated_op_->Wait(pair.second);
}
}
}
}
std::string ReduceOpHandle::Name() const { return "reduce"; } std::string ReduceOpHandle::Name() const { return "reduce"; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase { ...@@ -60,8 +60,6 @@ struct ReduceOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
template <typename T> template <typename T>
std::vector<const T *> GetInputValues( std::vector<const T *> GetInputValues(
const std::vector<VarHandle *> &in_var_handles, const std::vector<VarHandle *> &in_var_handles,
......
...@@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, ...@@ -29,6 +29,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {} ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
void ScaleLossGradOpHandle::RunImpl() { void ScaleLossGradOpHandle::RunImpl() {
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_; std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
...@@ -26,6 +26,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, ...@@ -26,6 +26,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
place_(place) {} place_(place) {}
void SendOpHandle::RunImpl() { void SendOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done // Wait input done
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
...@@ -33,7 +34,7 @@ void SendOpHandle::RunImpl() { ...@@ -33,7 +34,7 @@ void SendOpHandle::RunImpl() {
continue; continue;
} }
if (in->generated_op_) { if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]); in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
} }
} }
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -45,73 +43,33 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -45,73 +43,33 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// Should revisit it if overlapping is available. // Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops; std::unordered_set<OpHandleBase *> delayed_ops;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars.insert(&var);
if (var.generated_op_ == nullptr) {
ready_vars.Push(&var);
}
};
auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
pending_ops.insert({&op_instance, op_instance.Inputs().size()});
};
// Transform SSAGraph to pending_ops & pending_vars // Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->vars_) { for (auto &var_map : graph_->vars_) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(*version_pair); InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
} }
} }
} }
for (auto &var : graph_->dep_vars_) { for (auto &var : graph_->dep_vars_) {
InsertPendingVar(*var); InsertPendingVar(&pending_vars, &ready_vars, var.get());
} }
for (auto &op : graph_->ops_) { for (auto &op : graph_->ops_) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
InsertPendingOp(*op); InsertPendingOp(&pending_ops, op.get());
} }
} }
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
FeedFetchList fetch_data(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
}
}
}
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies; std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
for (size_t i = 0; i < fetch_tensors.size(); ++i) { FeedFetchList fetch_data(fetch_tensors.size());
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name);
auto *op = new FetchOpHandle(&fetch_data, i, &local_scopes_);
fetch_ops.emplace_back(op);
for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
}
for (auto *var : vars) {
op->AddInput(var);
}
auto *fetch_dummy = new DummyVarHandle(); InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
op->AddOutput(fetch_dummy); &pending_vars, &ready_vars, &fetch_data);
fetch_dependencies.emplace(fetch_dummy);
InsertPendingVar(*fetch_dummy);
InsertPendingOp(*op);
}
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) { for (auto *op : set) {
...@@ -174,6 +132,60 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -174,6 +132,60 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
return fetch_data; return fetch_data;
} }
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
}
}
}
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars.at(var_name);
auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
}
for (auto *var : vars) {
op->AddInput(var);
}
auto *fetch_dummy = new DummyVarHandle();
op->AddOutput(fetch_dummy);
fetch_dependencies->emplace(fetch_dummy);
this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy);
this->InsertPendingOp(pending_ops, op);
}
}
void ThreadedSSAGraphExecutor::InsertPendingOp(
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const {
pending_ops->insert({op_instance, op_instance->Inputs().size()});
}
void ThreadedSSAGraphExecutor::InsertPendingVar(
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, VarHandleBase *var) const {
pending_vars->insert(var);
if (var->generated_op_ == nullptr) {
ready_vars->Push(var);
}
}
void ThreadedSSAGraphExecutor::RunOp( void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) { BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <functional> #include <functional>
#include "ThreadPool.h" // ThreadPool in thrird party #include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h"
namespace paddle { namespace paddle {
...@@ -58,6 +59,21 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -58,6 +59,21 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unique_ptr<platform::EnforceNotMet> exception_; std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_; std::atomic<int> running_ops_;
bool allow_op_delay_; bool allow_op_delay_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const;
void InsertPendingVar(std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars,
VarHandleBase *var) const;
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data);
}; };
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册