From 37e514432b0d6453906aaaacbf54d8ae51ebddc4 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 12 Jul 2018 15:30:21 +0800 Subject: [PATCH] op compose node and update nodes. --- .../framework/details/all_reduce_op_handle.cc | 13 +++- .../framework/details/all_reduce_op_handle.h | 4 +- .../framework/details/broadcast_op_handle.h | 11 ++- .../details/broadcast_op_handle_test.cc | 32 +++++--- .../details/computation_op_handle.cc | 11 +-- .../framework/details/computation_op_handle.h | 2 +- .../details/data_balance_op_handle.cc | 8 +- .../details/data_balance_op_handle.h | 4 +- .../framework/details/fetch_op_handle.cc | 13 ++-- .../fluid/framework/details/fetch_op_handle.h | 2 +- .../framework/details/fuse_vars_op_handle.h | 6 +- .../framework/details/gather_op_handle.cc | 5 +- .../framework/details/gather_op_handle.h | 2 +- .../details/gather_op_handle_test.cc | 22 ++++-- .../details/multi_devices_graph_builder.cc | 73 ++++++++++++------- .../fluid/framework/details/op_handle_base.cc | 12 +-- .../fluid/framework/details/op_handle_base.h | 6 +- .../framework/details/reduce_op_handle.h | 11 ++- .../details/reduce_op_handle_test.cc | 26 ++++--- .../fluid/framework/details/rpc_op_handle.cc | 9 ++- .../fluid/framework/details/rpc_op_handle.h | 5 +- .../details/scale_loss_grad_op_handle.cc | 8 +- .../details/scale_loss_grad_op_handle.h | 3 +- .../framework/details/ssa_graph_builder.cc | 18 +++-- .../framework/details/ssa_graph_checker.cc | 4 +- .../details/threaded_ssa_graph_executor.cc | 19 +++-- .../details/threaded_ssa_graph_executor.h | 1 + paddle/fluid/framework/details/var_handle.h | 49 ++++++++++++- paddle/fluid/framework/ir/node.h | 36 ++------- 29 files changed, 262 insertions(+), 153 deletions(-) diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index b335d3a0d36..700c73c745b 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -23,10 +23,14 @@ namespace framework { namespace details { #ifdef PADDLE_WITH_CUDA -AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, +AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, + const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs) - : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) { + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + nccl_ctxs_(ctxs) { if (nccl_ctxs_) { for (auto &p : places_) { this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p); @@ -34,9 +38,10 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, } } #else -AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, +AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, + const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif void AllReduceOpHandle::RunImpl() { diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h index fdd250b0d3e..f6ef3a1367b 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/all_reduce_op_handle.h @@ -30,11 +30,11 @@ namespace details { struct AllReduceOpHandle : public OpHandleBase { #ifdef PADDLE_WITH_CUDA - AllReduceOpHandle(const std::vector &local_scopes, + AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs); #else - AllReduceOpHandle(const std::vector &local_scopes, + AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 8036f756b6d..fe4e733e434 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -35,10 +35,13 @@ namespace details { struct BroadcastOpHandle : public OpHandleBase { public: #ifdef PADDLE_WITH_CUDA - BroadcastOpHandle(const std::vector &local_scopes, + BroadcastOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *nccl_ctxs) - : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) { + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + nccl_ctxs_(nccl_ctxs) { if (nccl_ctxs_) { for (auto &p_ctx : nccl_ctxs_->contexts_) { dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); @@ -46,9 +49,9 @@ struct BroadcastOpHandle : public OpHandleBase { } } #else - BroadcastOpHandle(const std::vector &local_scopes, + BroadcastOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index c6e923ef77f..90ee3f7d93e 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,48 +96,56 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); + std::unique_ptr n(new ir::Node(ir::Node::Type::kOperation)); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, + nccl_ctxs_.get())); #else PADDLE_THROW("CUDA is not support."); #endif } else { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, + nccl_ctxs_.get())); #else - op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); + op_handle_.reset( + new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_)); #endif } - auto* in_var_handle = - new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]); + std::unique_ptr v(new ir::Node(ir::Node::Type::kVariable)); + auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", + gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); op_handle_->AddInput(in_var_handle); // add dummy var - vars_.emplace_back(new DummyVarHandle()); + + std::unique_ptr v2(new ir::Node(ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); - dummy_var_handle->generated_op_ = nullptr; + dummy_var_handle->ClearGeneratedOp(); op_handle_->AddInput(dummy_var_handle); for (size_t j = 0; j < gpu_list_.size(); ++j) { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]); + std::unique_ptr v3(new ir::Node(ir::Node::Type::kVariable)); + VarHandle* out_var_handle = + new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); } // add dummy var - vars_.emplace_back(new DummyVarHandle()); + std::unique_ptr v4(new ir::Node(ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast(vars_.back().get()); - out_dummy_var_handle->generated_op_ = nullptr; + out_dummy_var_handle->ClearGeneratedOp(); op_handle_->AddOutput(out_dummy_var_handle); } diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index df05bb06333..16ad30d491b 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -19,9 +19,10 @@ namespace paddle { namespace framework { namespace details { -ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, - platform::Place place) - : op_(framework::OpRegistry::CreateOp(op_desc)), +ComputationOpHandle::ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, + Scope *scope, platform::Place place) + : OpHandleBase(node), + op_(framework::OpRegistry::CreateOp(op_desc)), scope_(scope), place_(place) {} @@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() { bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { bool need_wait = - in_var && in_var->generated_op_ && - in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_]; + in_var && in_var->GeneratedOp() && + in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_]; return need_wait; } diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index f048f973fde..9ca1d927b8d 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -28,7 +28,7 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { public: - ComputationOpHandle(const OpDesc &op_desc, Scope *scope, + ComputationOpHandle(ir::Node *node, const OpDesc &op_desc, Scope *scope, platform::Place place); std::string Name() const override; diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index 68896c8ac1b..525d2432244 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -22,10 +22,10 @@ namespace details { #ifdef PADDLE_WITH_CUDA DataBalanceOpHandle::DataBalanceOpHandle( - const std::vector &local_scopes, + ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs) - : local_scopes_(local_scopes), places_(places) { + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) { if (ctxs) { for (auto &p : places_) { this->dev_ctxes_[p] = ctxs->DevCtx(p); @@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle( } #else DataBalanceOpHandle::DataBalanceOpHandle( - const std::vector &local_scopes, + ir::Node *node, const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif std::string DataBalanceOpHandle::Name() const { return "data balance"; } diff --git a/paddle/fluid/framework/details/data_balance_op_handle.h b/paddle/fluid/framework/details/data_balance_op_handle.h index 76a407e3610..0462fb6ec71 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.h +++ b/paddle/fluid/framework/details/data_balance_op_handle.h @@ -30,11 +30,11 @@ namespace details { struct DataBalanceOpHandle : public OpHandleBase { public: #ifdef PADDLE_WITH_CUDA - DataBalanceOpHandle(const std::vector &local_scopes, + DataBalanceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs); #else - DataBalanceOpHandle(const std::vector &local_scopes, + DataBalanceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); #endif diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index d646c944601..fe18b2060c5 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -21,13 +21,16 @@ namespace paddle { namespace framework { namespace details { -FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset, +FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, std::vector *local_scopes) - : data_(data), offset_(offset), local_scopes_(local_scopes) {} + : OpHandleBase(node), + data_(data), + offset_(offset), + local_scopes_(local_scopes) {} FetchOpHandle::~FetchOpHandle() { for (auto *input_var : inputs_) { - input_var->pending_ops_.erase(this); + input_var->RemoveOutput(this, this->Node()); } } @@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() { void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) { auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place); for (auto *input : inputs_) { - if (input->generated_op_) { - input->generated_op_->RecordWaitEventOnCtx(cpu_ctx); + if (input->GeneratedOp()) { + input->GeneratedOp()->RecordWaitEventOnCtx(cpu_ctx); } } } diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h index e09bdd1d333..6ce42f92d7f 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.h +++ b/paddle/fluid/framework/details/fetch_op_handle.h @@ -28,7 +28,7 @@ namespace details { struct FetchOpHandle : public OpHandleBase { public: - FetchOpHandle(FeedFetchList *data, size_t offset, + FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, std::vector *local_scopes); ~FetchOpHandle(); diff --git a/paddle/fluid/framework/details/fuse_vars_op_handle.h b/paddle/fluid/framework/details/fuse_vars_op_handle.h index 140fb5bb49a..3f360c510a4 100644 --- a/paddle/fluid/framework/details/fuse_vars_op_handle.h +++ b/paddle/fluid/framework/details/fuse_vars_op_handle.h @@ -30,10 +30,12 @@ namespace details { struct FuseVarsOpHandle : public OpHandleBase { public: - FuseVarsOpHandle(Scope *local_scope, const platform::Place &place, + FuseVarsOpHandle(ir::Node *node, Scope *local_scope, + const platform::Place &place, const std::unordered_map &inputs_numel, const std::type_index &var_type) - : local_scope_(local_scope), + : OpHandleBase(node), + local_scope_(local_scope), place_(place), inputs_numel_(inputs_numel), type_(var_type) { diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 2be02304566..9aae19fc73d 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -20,9 +20,10 @@ namespace paddle { namespace framework { namespace details { -GatherOpHandle::GatherOpHandle(const std::vector &local_scopes, +GatherOpHandle::GatherOpHandle(ir::Node *node, + const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} void GatherOpHandle::RunImpl() { if (places_.size() == 1) return; diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h index d11ef8556aa..d9afbc6547e 100644 --- a/paddle/fluid/framework/details/gather_op_handle.h +++ b/paddle/fluid/framework/details/gather_op_handle.h @@ -30,7 +30,7 @@ namespace details { struct GatherOpHandle : public OpHandleBase { public: - GatherOpHandle(const std::vector &local_scopes, + GatherOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); std::string Name() const override; diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index 3cce2cc1640..5b11f8cdc74 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -70,6 +70,7 @@ struct TestGatherOpHandle { } void InitGatherOp(size_t input_scope_idx) { + std::vector> nodes; for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); Scope& local_scope = local_scopes_.back()->NewScope(); @@ -81,30 +82,37 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_)); + nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + op_handle_.reset( + new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto* in_var_handle = + new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); op_handle_->AddInput(in_var_handle); } // add dummy var - vars_.emplace_back(new DummyVarHandle()); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast(vars_.back().get()); - in_dummy_var_handle->generated_op_ = nullptr; + in_dummy_var_handle->ClearGeneratedOp(); op_handle_->AddInput(in_dummy_var_handle); // add output - auto* out_var_handle = - new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, + "out", gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); // add dummy var - vars_.emplace_back(new DummyVarHandle()); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); op_handle_->AddOutput(dummy_var_handle); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 0a953704193..cb2ab905163 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -328,12 +328,16 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_); + auto *op_handle = new BroadcastOpHandle(result->nodes.back().get(), + local_scopes_, places_, nccl_ctxs_); #else - auto *op_handle = new BroadcastOpHandle(local_scopes_, places_); + auto *op_handle = + new BroadcastOpHandle(result->nodes.back().get(), local_scopes_, places_); #endif result->Get("ops").emplace_back(op_handle); + auto *in = result->Get("vars").at(src_dev_id).at(p_name).back().get(); op_handle->AddInput(in); @@ -341,8 +345,10 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); auto &vars = result->Get("vars").at(i).at(p_name); - auto *out_var = new VarHandle(vars.size(), i, p_name, p); + auto *out_var = + new VarHandle(result->nodes.back().get(), vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } @@ -351,19 +357,21 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const { - result->Get("ops").emplace_back( - new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + result->Get("ops").emplace_back(new ComputationOpHandle( + result->nodes.back().get(), op, local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, op, dev_id); } void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back( - new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new AllReduceOpHandle( + result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back( - new AllReduceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back(new AllReduceOpHandle( + result->nodes.back().get(), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -375,7 +383,8 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); - auto var = new VarHandle(vars.size(), i, og, p); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto var = new VarHandle(result->nodes.back().get(), vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -383,12 +392,13 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, void MultiDevSSAGraphBuilder::InsertDataBalanceOp( Graph *result, const std::vector &datas) const { + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back( - new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); #else - result->Get("ops").emplace_back( - new DataBalanceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->nodes.back().get(), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { @@ -398,7 +408,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( auto &vars = result->Get("vars")[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); - auto var = new VarHandle(vars.size(), i, d_name, p); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto var = + new VarHandle(result->nodes.back().get(), vars.size(), i, d_name, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -452,10 +464,10 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { auto *communication_dev_ctx = platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif - - auto *op_handle = - new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i], - places_[i], communication_dev_ctx); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + auto *op_handle = new ScaleLossGradOpHandle( + result->nodes.back().get(), local_scopes_.size(), local_scopes_[i], + places_[i], communication_dev_ctx); result->Get("ops").emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale @@ -475,8 +487,9 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); result->Get("ops").emplace_back( - new ComputationOpHandle(op, s, p)); + new ComputationOpHandle(result->nodes.back().get(), op, s, p)); CreateOpHandleIOs(result, op, scope_idx); } } @@ -484,12 +497,13 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const { + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); #ifdef PADDLE_WITH_CUDA - result->Get("ops").emplace_back( - new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new ReduceOpHandle( + result->nodes.back().get(), local_scopes_, places_, nccl_ctxs_)); #else result->Get("ops").emplace_back( - new ReduceOpHandle(local_scopes_, places_)); + new ReduceOpHandle(result->nodes.back().get(), local_scopes_, places_)); #endif auto *op_handle = result->Get("ops").back().get(); @@ -502,7 +516,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, op_handle->AddInput(prev_grad.get()); } auto &vars = result->Get("vars")[dst_dev_id][og]; - auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto var = new VarHandle(result->nodes.back().get(), vars.size(), dst_dev_id, + og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); return var; @@ -514,7 +530,8 @@ void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - auto *dep_var = new DummyVarHandle(); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *dep_var = new DummyVarHandle(result->nodes.back().get()); prev_op->AddOutput(dep_var); result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); @@ -587,8 +604,10 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", op.Type()); - result->Get("ops").emplace_back(new RPCOpHandle( - op, local_scopes_[op_dev_id], op.Type(), places_[op_dev_id])); + result->nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); + result->Get("ops").emplace_back( + new RPCOpHandle(result->nodes.back().get(), op, local_scopes_[op_dev_id], + op.Type(), places_[op_dev_id])); if (op.Type() == "send_barrier") { ConnectOp(result, result->Get("ops").back().get(), "send"); diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index d80bdcf15d7..ee9f9184da6 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -80,19 +80,21 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { void OpHandleBase::AddInput(VarHandleBase *in) { this->inputs_.emplace_back(in); - in->pending_ops_.insert(this); + node_->inputs.push_back(in->Node()); + in->AddOutput(this, this->Node()); } void OpHandleBase::AddOutput(VarHandleBase *out) { outputs_.emplace_back(out); - out->generated_op_ = this; + node_->outputs.push_back(out->Node()); + out->AddInput(this, this->Node()); } void OpHandleBase::WaitInputVarGenerated() { for (auto in_var : inputs_) { if (NeedWait(in_var)) { for (auto &pair : dev_ctxes_) { - in_var->generated_op_->RecordWaitEventOnCtx(pair.second); + in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second); } } } @@ -101,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() { void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { for (auto *in : inputs_) { if (NeedWait(in)) { - in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]); + in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]); } } } @@ -117,7 +119,7 @@ size_t OpHandleBase::NoDummyInputSize() const { } bool OpHandleBase::NeedWait(VarHandleBase *in_var) { - return in_var && in_var->generated_op_; + return in_var && in_var->GeneratedOp(); } void OpHandleBase::RunAndRecordEvent(const std::function &callback) { diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 6aec1788311..368a1537115 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -17,6 +17,7 @@ #include #include #include "paddle/fluid/framework/details/var_handle.h" +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/macros.h" @@ -28,7 +29,7 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; class OpHandleBase { public: - OpHandleBase() {} + explicit OpHandleBase(ir::Node *node) : node_(node) {} virtual ~OpHandleBase(); @@ -82,6 +83,8 @@ class OpHandleBase { size_t NoDummyInputSize() const; + ir::Node *Node() { return node_; } + protected: void RunAndRecordEvent(const std::function &callback); @@ -90,6 +93,7 @@ class OpHandleBase { virtual void RunImpl() = 0; + ir::Node *node_; std::vector inputs_; std::vector outputs_; std::map dev_ctxes_; diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 4d14334cdfe..a6289b055f9 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -37,10 +37,13 @@ struct ReduceOpHandle : public OpHandleBase { #ifdef PADDLE_WITH_CUDA const platform::NCCLContextMap *nccl_ctxs_; - ReduceOpHandle(const std::vector &local_scopes, + ReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *nccl_ctxs) - : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) { + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + nccl_ctxs_(nccl_ctxs) { if (nccl_ctxs_) { for (auto &p_ctx : nccl_ctxs_->contexts_) { dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); @@ -48,9 +51,9 @@ struct ReduceOpHandle : public OpHandleBase { } } #else - ReduceOpHandle(const std::vector &local_scopes, + ReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index ffdd7c14eb5..d029dd9e15e 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -84,6 +84,7 @@ struct TestReduceOpHandle { } void InitReduceOp(size_t out_scope_idx) { + std::vector> nodes; // init scope for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); @@ -96,19 +97,21 @@ struct TestReduceOpHandle { } param_scopes_[out_scope_idx]->Var("out"); + nodes.emplace_back(new ir::Node(ir::Node::Type::kOperation)); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, + gpu_list_, nccl_ctxs_.get())); #else PADDLE_THROW("CUDA is not support."); #endif } else { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, + gpu_list_, nccl_ctxs_.get())); #else - op_handle_.reset(new ReduceOpHandle(local_scopes_, gpu_list_)); + op_handle_.reset( + new ReduceOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); #endif } @@ -118,8 +121,10 @@ struct TestReduceOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - auto *in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]); - in_var_handle->generated_op_ = nullptr; + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *in_var_handle = + new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); + in_var_handle->ClearGeneratedOp(); vars_.emplace_back(in_var_handle); op_handle_->AddInput(in_var_handle); } @@ -128,12 +133,13 @@ struct TestReduceOpHandle { vars_.emplace_back(new DummyVarHandle()); DummyVarHandle *in_dummy_var_handle = static_cast(vars_.back().get()); - in_dummy_var_handle->generated_op_ = nullptr; + in_dummy_var_handle->ClearGeneratedOp(); op_handle_->AddInput(in_dummy_var_handle); // add output - auto *out_var_handle = - new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]); + nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx, + "out", gpu_list_[out_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index 586465f99fd..924ff4d118a 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -18,10 +18,11 @@ namespace paddle { namespace framework { namespace details { -RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, +RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, const Scope *local_scope, const std::string &name, const platform::Place &place) - : op_(framework::OpRegistry::CreateOp(op_desc)), + : OpHandleBase(node), + op_(framework::OpRegistry::CreateOp(op_desc)), local_scope_(local_scope), name_(name), place_(place) {} @@ -35,8 +36,8 @@ void RPCOpHandle::RunImpl() { if (in->DebugString() == "dummy") { // HACK continue; } - if (in->generated_op_) { - in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]); + if (in->GeneratedOp()) { + in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]); } } auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); diff --git a/paddle/fluid/framework/details/rpc_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h index ae38c7fe19e..7f99cdeacf6 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.h +++ b/paddle/fluid/framework/details/rpc_op_handle.h @@ -28,8 +28,9 @@ namespace framework { namespace details { struct RPCOpHandle : public OpHandleBase { - RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, - const std::string& name, const platform::Place& place); + RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc, + const Scope* local_scope, const std::string& name, + const platform::Place& place); std::string Name() const override; diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index d9c387e79dc..609e1858195 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -19,10 +19,14 @@ namespace paddle { namespace framework { namespace details { -ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, +ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, + Scope *scope, platform::Place place, platform::DeviceContext *dev_ctx) - : coeff_(static_cast(1.0 / num_dev)), scope_(scope), place_(place) { + : OpHandleBase(node), + coeff_(static_cast(1.0 / num_dev)), + scope_(scope), + place_(place) { dev_ctxes_[place_] = dev_ctx; } diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h index d93d599d46f..523b55724c8 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -25,7 +25,8 @@ namespace framework { namespace details { struct ScaleLossGradOpHandle : public OpHandleBase { - ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place, + ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope, + platform::Place place, platform::DeviceContext *context); ~ScaleLossGradOpHandle() final; diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 2508ed0296d..846f98ddfa3 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { auto it_old = name_pair.second.rbegin(); ++it_old; for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - auto *write_op = (*it_new)->generated_op_; - auto &read_ops = (*it_old)->pending_ops_; + OpHandleBase *write_op = (*it_new)->GeneratedOp(); + const auto &read_ops = (*it_old)->PendingOps(); for (auto *read_op : read_ops) { // Manually add a dependency var from read_op to write_op; @@ -37,7 +37,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { continue; } - auto *dep_var = new DummyVarHandle(); + graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *dep_var = new DummyVarHandle(graph->nodes.back().get()); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); graph->Get("dep_vars").emplace(dep_var); @@ -54,7 +55,9 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( auto &var_holder = var_holders[each_var_name]; VarHandle *var = nullptr; if (var_holder.empty()) { - var = new VarHandle(0, place_offset, each_var_name, place); + graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + var = new VarHandle(graph->nodes.back().get(), 0, place_offset, + each_var_name, place); var_holder.emplace_back(var); } else { var = var_holder.rbegin()->get(); @@ -68,7 +71,9 @@ void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, size_t place_offset) { auto &vars = graph->Get("vars")[place_offset][each_var_name]; size_t version = vars.size(); - auto var = new VarHandle(version, place_offset, each_var_name, place); + graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto var = new VarHandle(graph->nodes.back().get(), version, place_offset, + each_var_name, place); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -80,7 +85,8 @@ void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { if (!op->Outputs().empty()) { continue; } - auto *dummy_leaf = new DummyVarHandle(); + graph->nodes.emplace_back(new ir::Node(ir::Node::Type::kVariable)); + auto *dummy_leaf = new DummyVarHandle(graph->nodes.back().get()); graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index c01334ca06f..6a211f52bb0 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -28,7 +28,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const { auto insert_pending_var = [&](VarHandleBase *var) { pending_vars.insert(var); - if (var->generated_op_ == nullptr) { + if (var->GeneratedOp() == nullptr) { ready_vars.emplace(var); } }; @@ -71,7 +71,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const { for (auto ready_var : ready_vars) { pending_vars.erase(ready_var); - for (auto *op : ready_var->pending_ops_) { + for (auto *op : ready_var->PendingOps()) { auto &deps = --pending_ops[op]; if (deps == 0) { ready_ops.insert(op); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index ed8e38039ed..9a2413118e2 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -65,11 +65,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Step 2. Insert FetchOps std::vector> fetch_ops; + std::vector> tmp_nodes; std::unordered_set> fetch_dependencies; FeedFetchList fetch_data(fetch_tensors.size()); - InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, - &pending_vars, &ready_vars, &fetch_data); + InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies, + &pending_ops, &pending_vars, &ready_vars, &fetch_data); auto run_all_ops = [&](std::unordered_set &set) { for (auto *op : set) { @@ -126,7 +127,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Find the ready_ops after the ready_var. for (auto ready_var : cur_ready_vars) { pending_vars.erase(ready_var); - for (auto *op : ready_var->pending_ops_) { + for (auto *op : ready_var->PendingOps()) { auto &deps = pending_ops[op]; --deps; if (deps == 0) { @@ -152,6 +153,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( void ThreadedSSAGraphExecutor::InsertFetchOps( const std::vector &fetch_tensors, std::vector> *fetch_ops, + std::vector> *temp_nodes, std::unordered_set> *fetch_dependencies, std::unordered_map *pending_ops, std::unordered_set *pending_vars, @@ -170,7 +172,10 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( for (size_t i = 0; i < fetch_tensors.size(); ++i) { auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); - auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_); + + ir::Node *fetch_n = new ir::Node(ir::Node::Type::kOperation); + auto *op = new FetchOpHandle(fetch_n, fetch_data, i, &local_scopes_); + temp_nodes->emplace_back(fetch_n); fetch_ops->emplace_back(op); for (auto &p : places_) { @@ -181,9 +186,11 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - auto *fetch_dummy = new DummyVarHandle(); + ir::Node *dummy_n = new ir::Node(ir::Node::Type::kVariable); + auto *fetch_dummy = new DummyVarHandle(dummy_n); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); + temp_nodes->emplace_back(dummy_n); this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); this->InsertPendingOp(pending_ops, op); } @@ -199,7 +206,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( std::unordered_set *pending_vars, BlockingQueue *ready_vars, VarHandleBase *var) const { pending_vars->insert(var); - if (var->generated_op_ == nullptr) { + if (var->GeneratedOp() == nullptr) { ready_vars->Push(var); } } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 7d0aaf2ddc0..bf7c0a367a1 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -72,6 +72,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { void InsertFetchOps( const std::vector &fetch_tensors, std::vector> *fetch_ops, + std::vector> *temp_nodes, std::unordered_set> *fetch_dependencies, std::unordered_map *pending_ops, std::unordered_set *pending_vars, diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index c62f9a9d087..8bd3db9203b 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -13,6 +13,8 @@ // limitations under the License. #pragma once + +#include #include #include #include @@ -30,15 +32,51 @@ class OpHandleBase; // A variable can only be generated by a single operator. i.e. // This is a single assignment graph. struct VarHandleBase { + explicit VarHandleBase(ir::Node* node) : node_(node) {} + virtual ~VarHandleBase(); + virtual std::string DebugString() const = 0; + void AddInput(OpHandleBase* in, ir::Node* node) { + node_->inputs.clear(); + node_->inputs.push_back(node); + generated_op_ = in; + } + + void AddOutput(OpHandleBase* out, ir::Node* node) { + if (pending_ops_.find(out) == pending_ops_.end()) { + pending_ops_.insert(out); + node_->outputs.push_back(node); + } + } + + void RemoveOutput(OpHandleBase* out, ir::Node* node) { + pending_ops_.erase(out); + std::remove(node_->outputs.begin(), node_->outputs.end(), node); + } + + void ClearGeneratedOp() { + generated_op_ = nullptr; + node_->inputs.clear(); + } + + OpHandleBase* GeneratedOp() { return generated_op_; } + + const std::unordered_set& PendingOps() const { + return pending_ops_; + } + + ir::Node* Node() { return node_; } + + protected: // The operator who generate this variable. nullptr if the variable // is a root node. OpHandleBase* generated_op_{nullptr}; // Operators which depend on this variable ready. std::unordered_set pending_ops_; + ir::Node* node_; }; // VarHandle is actually a single version of Runtime Variable. @@ -47,11 +85,14 @@ struct VarHandleBase { // // NOTE: runtime variables have place. struct VarHandle : public VarHandleBase { + explicit VarHandle(ir::Node* node) : VarHandleBase(node) {} + std::string DebugString() const override; - VarHandle(size_t version, size_t scope_index, std::string name, - platform::Place place) - : version_(version), + VarHandle(ir::Node* node, size_t version, size_t scope_index, + std::string name, platform::Place place) + : VarHandleBase(node), + version_(version), scope_idx_(scope_index), name_(std::move(name)), place_(std::move(place)) {} @@ -71,6 +112,8 @@ struct VarHandle : public VarHandleBase { // Dummy Variable. It is used to represent dependencies between operators struct DummyVarHandle : public VarHandleBase { + explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {} + std::string DebugString() const override; }; diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 0fd80483904..94ace92953a 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -14,10 +14,12 @@ limitations under the License. */ #pragma once +#include #include #include #include #include +#include #include #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/variant.h" @@ -30,10 +32,10 @@ class Node { public: enum class Type { kNone = -1, kOperation, kVariable }; - Node(const std::string& name, Type type) : name_(name), type_(type) {} + explicit Node(Type type) : type_(type) {} virtual ~Node() { - for (auto& attr : attrs_) { + for (auto &attr : attrs_) { if (attr_dels_.find(attr.first) != attr_dels_.end()) { attr_dels_[attr.first](); } @@ -42,54 +44,32 @@ class Node { attrs_.clear(); } - int64_t ID() const { return id_; } - - std::string Name() const { return name_; } - - virtual std::string ToString() const { - return Name() + "(" + std::to_string(ID()) + ")"; - } - - virtual std::string DebugString() const = 0; - Type NodeType() const { return type_; } template - void Set(const std::string& name, AttrType attr) { + void Set(const std::string &name, AttrType attr) { attrs_[name] = attr; } template - void Set(const std::string& name, AttrType* attr, + void Set(const std::string &name, AttrType *attr, std::function attr_del) { attrs_[name] = attr; attr_dels_[name] = attr_del; } - std::vector inputs; - std::vector outputs; + std::vector inputs; + std::vector outputs; protected: std::map attrs_; std::map> attr_dels_; - int64_t id_ = 0; - std::string name_; Type type_; private: DISABLE_COPY_AND_ASSIGN(Node); }; -class Variable : public Node { - public: - explicit Variable(const std::string& name) : Node(name, Type::kVariable) {} -}; - -class Operation : public Node { - public: - explicit Operation(const std::string& name) : Node(name, Type::kOperation) {} -}; - } // namespace ir } // namespace framework } // namespace paddle -- GitLab