diff --git a/doc/fluid/design/ir/draft.md b/doc/fluid/design/ir/draft.md new file mode 100644 index 0000000000000000000000000000000000000000..a141dcbca584c6064c8da863410692a8be911d12 --- /dev/null +++ b/doc/fluid/design/ir/draft.md @@ -0,0 +1,89 @@ +## Motivation + +There is a ```gap``` between the ```Program``` defined by +user and the ```Executable``` that can be scheduled +efficiently on heterogeneous hardware, either locally +or distributedly. + +Usually, the ```gap``` is bridged by + +* A serious transformations with defined order. + +* These transformations usually involve +```insert, delete, clustering, split, dependency analysis```. + +* Has a simple way to verify and debug each transformation. + +* Flexible to add, remove or customize transformations to fit +the requirements of various algorithms (models) and hardware secenarios. + +Some other events also push us to a better unified pattern. + +* The deep learning framework is built around the concepts of graphs. +To leverage tools such as compilation (e.g. TVM and nGraph) or +cross-framework conversion (e.g. ONNX), we also need a intermediate +representation that can be connected to the rest of the ecosystem. + + +We need a unified pattern to naturally support the requirements +described above. The pattern should fit both training, inference +and other offline serielized model transformations. +Learned from LLVM and other deep learning framework, we draft the +design below. + + +## Design + +### Major Concepts + +#### Node + +```Node``` represents an operation that performs some computation or +a variable that is input or output of operation. + +```Node```s are connected to other ```Node```s via inputs and outputs. + +Other properties (maybe device placement information) can be added +to ```Node``` in the future if it's a +common requirement of many other ```Pass```es. Otherwise, it should live +in a ```Node``` wrapper class that is private to some ```Pass``` or be +a local member of a ```Pass```. + +#### Graph + +```Graph``` contains a list of ```Node```s, which are connected to +each other via inputs and outputs. + +TODO: Better definitions for the graph. + +```Graph``` can also contain ```Attribute```s. ```Attribute```s +can be ``any`` thing. For example, it can be a list of "wraper" +nodes. The ```wrapper``` nodes compose ```Node```s and provide +helper method for execution or transformation. ```Attribute``` +can also contain other things that describe some properties of +the ```Graph``` or ```Graph``` nodes. ```Attribute``` can be passed +across ```Pass```. However, it should be used with care. + +#### Pass + +```Pass``` represents a transformation of ```Graph```. Its input +is a ```Graph``` and its output is also a ```Graph```. For example, +a ```Pass``` can simply print out the ```Graph```. A ```Pass``` +can also fuse some ```Graph```'s ```Node```s. + +#### Optimize + +```Optimize``` contains a series of ```Pass``` with defined order. +```Optimize``` transforms a ```Graph``` that only contains raw +modeling logic to a ```Graph``` that can be run efficiently while +maintaining the original modeling logic. + + +### Optimize Process + +* Program is first converted to Graph. +* Graph goes through a series of Pass +* Graph is transformed from raw model logic to a +form that is efficient to execute. + +Program->ProgramToGraph->Graph->Pass1->Graph->Pass2->Graph->Pass3->Graph->Executor diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index ec252929d5584c211cea7fa52004ecdfdf586a85..de06c860f550641a58a32d49e85feb7278fed1dd 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(details) +add_subdirectory(ir) # ddim lib proto_library(framework_proto SRCS framework.proto) @@ -93,7 +94,7 @@ else() endif() -cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 4fb4ec38ee965a2790d11378a1ce6befa0ef5a00..e8057c35e8b957cb43e66937a5073a085c6e7708 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -5,8 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) -cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) -cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) +cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS graph) cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(ssa_graph_checker SRCS ssa_graph_checker.cc DEPS ssa_graph_builder) @@ -35,7 +34,7 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) -cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) +cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto) cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index b335d3a0d364c916e19574de8d3ed89aaec7de41..700c73c745bad72637d77385f5cd38c494501c86 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -23,10 +23,14 @@ namespace framework { namespace details { #ifdef PADDLE_WITH_CUDA -AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, +AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, + const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs) - : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) { + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + nccl_ctxs_(ctxs) { if (nccl_ctxs_) { for (auto &p : places_) { this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p); @@ -34,9 +38,10 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, } } #else -AllReduceOpHandle::AllReduceOpHandle(const std::vector &local_scopes, +AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, + const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif void AllReduceOpHandle::RunImpl() { diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h index fdd250b0d3eb166249271a95f7592b9fadee5265..f6ef3a1367b91b6abf8ce74a91f73056efd0f84e 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/all_reduce_op_handle.h @@ -30,11 +30,11 @@ namespace details { struct AllReduceOpHandle : public OpHandleBase { #ifdef PADDLE_WITH_CUDA - AllReduceOpHandle(const std::vector &local_scopes, + AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs); #else - AllReduceOpHandle(const std::vector &local_scopes, + AllReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 8036f756b6d6506684c109ab881d546f38176a10..fe4e733e43417977df324fde808f52b228a27d19 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -35,10 +35,13 @@ namespace details { struct BroadcastOpHandle : public OpHandleBase { public: #ifdef PADDLE_WITH_CUDA - BroadcastOpHandle(const std::vector &local_scopes, + BroadcastOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *nccl_ctxs) - : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) { + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + nccl_ctxs_(nccl_ctxs) { if (nccl_ctxs_) { for (auto &p_ctx : nccl_ctxs_->contexts_) { dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); @@ -46,9 +49,9 @@ struct BroadcastOpHandle : public OpHandleBase { } } #else - BroadcastOpHandle(const std::vector &local_scopes, + BroadcastOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc index c6e923ef77ff03413eefe4f26457a5322747618e..1413f7bd9ac515ae7dceee62de8f3bc74e3a2efc 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc @@ -96,48 +96,61 @@ struct TestBroadcastOpHandle { } param_scopes_[input_scope_idx]->Var("input"); + std::unique_ptr n( + new ir::Node("node0", ir::Node::Type::kOperation)); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, + nccl_ctxs_.get())); #else PADDLE_THROW("CUDA is not support."); #endif } else { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new BroadcastOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_, + nccl_ctxs_.get())); #else - op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); + op_handle_.reset( + new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_)); #endif } - auto* in_var_handle = - new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]); + std::unique_ptr v( + new ir::Node("node1", ir::Node::Type::kVariable)); + auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input", + gpu_list_[input_scope_idx]); vars_.emplace_back(in_var_handle); op_handle_->AddInput(in_var_handle); // add dummy var - vars_.emplace_back(new DummyVarHandle()); + + std::unique_ptr v2( + new ir::Node("node2", ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(v2.get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); - dummy_var_handle->generated_op_ = nullptr; + dummy_var_handle->ClearGeneratedOp(); op_handle_->AddInput(dummy_var_handle); for (size_t j = 0; j < gpu_list_.size(); ++j) { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]); + std::unique_ptr v3( + new ir::Node("node3", ir::Node::Type::kVariable)); + VarHandle* out_var_handle = + new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); } // add dummy var - vars_.emplace_back(new DummyVarHandle()); + std::unique_ptr v4( + new ir::Node("node4", ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(v4.get())); DummyVarHandle* out_dummy_var_handle = static_cast(vars_.back().get()); - out_dummy_var_handle->generated_op_ = nullptr; + out_dummy_var_handle->ClearGeneratedOp(); op_handle_->AddOutput(out_dummy_var_handle); } diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index df05bb06333d6b964f2f5434c3d43214e5d2cb7a..b6282debdb4eb6b1f29c39e54ac4f3e2296838da 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -19,9 +19,10 @@ namespace paddle { namespace framework { namespace details { -ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope, +ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place) - : op_(framework::OpRegistry::CreateOp(op_desc)), + : OpHandleBase(node), + op_(framework::OpRegistry::CreateOp(*node->Op())), scope_(scope), place_(place) {} @@ -35,8 +36,8 @@ void ComputationOpHandle::RunImpl() { bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { bool need_wait = - in_var && in_var->generated_op_ && - in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_]; + in_var && in_var->GeneratedOp() && + in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_]; return need_wait; } diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index f048f973fdeb6cf7d1485cda8cea7d530d9ba465..d9fcd92427ef38b131b4ce782c0ada37765682db 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -28,8 +28,7 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { public: - ComputationOpHandle(const OpDesc &op_desc, Scope *scope, - platform::Place place); + ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place); std::string Name() const override; diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc index 68896c8ac1bae7d4bfcfa79cc8ec5c26bf2d93ee..525d24322442ef4dd6e8c24212af61c908959b87 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.cc +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -22,10 +22,10 @@ namespace details { #ifdef PADDLE_WITH_CUDA DataBalanceOpHandle::DataBalanceOpHandle( - const std::vector &local_scopes, + ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs) - : local_scopes_(local_scopes), places_(places) { + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) { if (ctxs) { for (auto &p : places_) { this->dev_ctxes_[p] = ctxs->DevCtx(p); @@ -34,9 +34,9 @@ DataBalanceOpHandle::DataBalanceOpHandle( } #else DataBalanceOpHandle::DataBalanceOpHandle( - const std::vector &local_scopes, + ir::Node *node, const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif std::string DataBalanceOpHandle::Name() const { return "data balance"; } diff --git a/paddle/fluid/framework/details/data_balance_op_handle.h b/paddle/fluid/framework/details/data_balance_op_handle.h index 76a407e3610e8bb48facf1f814779f4c23f92d98..0462fb6ec713eb977f420a9cb485c0273e782496 100644 --- a/paddle/fluid/framework/details/data_balance_op_handle.h +++ b/paddle/fluid/framework/details/data_balance_op_handle.h @@ -30,11 +30,11 @@ namespace details { struct DataBalanceOpHandle : public OpHandleBase { public: #ifdef PADDLE_WITH_CUDA - DataBalanceOpHandle(const std::vector &local_scopes, + DataBalanceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *ctxs); #else - DataBalanceOpHandle(const std::vector &local_scopes, + DataBalanceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); #endif diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index d646c944601e81477787740189d7ac60ae97fa80..fe18b2060c5cd7e157374da53c5a985f70545ab7 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -21,13 +21,16 @@ namespace paddle { namespace framework { namespace details { -FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset, +FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, std::vector *local_scopes) - : data_(data), offset_(offset), local_scopes_(local_scopes) {} + : OpHandleBase(node), + data_(data), + offset_(offset), + local_scopes_(local_scopes) {} FetchOpHandle::~FetchOpHandle() { for (auto *input_var : inputs_) { - input_var->pending_ops_.erase(this); + input_var->RemoveOutput(this, this->Node()); } } @@ -77,8 +80,8 @@ void FetchOpHandle::RunImpl() { void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) { auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place); for (auto *input : inputs_) { - if (input->generated_op_) { - input->generated_op_->RecordWaitEventOnCtx(cpu_ctx); + if (input->GeneratedOp()) { + input->GeneratedOp()->RecordWaitEventOnCtx(cpu_ctx); } } } diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h index e09bdd1d3338bb175c1ddae35b53f98197b68e9a..6ce42f92d7f1e81eeafd1eb5c28ce3564a5ffebc 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.h +++ b/paddle/fluid/framework/details/fetch_op_handle.h @@ -28,7 +28,7 @@ namespace details { struct FetchOpHandle : public OpHandleBase { public: - FetchOpHandle(FeedFetchList *data, size_t offset, + FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, std::vector *local_scopes); ~FetchOpHandle(); diff --git a/paddle/fluid/framework/details/fuse_vars_op_handle.h b/paddle/fluid/framework/details/fuse_vars_op_handle.h index 140fb5bb49a33146de974b6d79559b4cf15bdd7b..3f360c510a4fdc0caaeb15d862b217ef41b8ea6e 100644 --- a/paddle/fluid/framework/details/fuse_vars_op_handle.h +++ b/paddle/fluid/framework/details/fuse_vars_op_handle.h @@ -30,10 +30,12 @@ namespace details { struct FuseVarsOpHandle : public OpHandleBase { public: - FuseVarsOpHandle(Scope *local_scope, const platform::Place &place, + FuseVarsOpHandle(ir::Node *node, Scope *local_scope, + const platform::Place &place, const std::unordered_map &inputs_numel, const std::type_index &var_type) - : local_scope_(local_scope), + : OpHandleBase(node), + local_scope_(local_scope), place_(place), inputs_numel_(inputs_numel), type_(var_type) { diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 2be02304566cf5dbe348fa01fc4171990eafd158..9aae19fc73de4387186da47c55710c94d53f1b88 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -20,9 +20,10 @@ namespace paddle { namespace framework { namespace details { -GatherOpHandle::GatherOpHandle(const std::vector &local_scopes, +GatherOpHandle::GatherOpHandle(ir::Node *node, + const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} void GatherOpHandle::RunImpl() { if (places_.size() == 1) return; diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h index d11ef8556aa8840949ca8dc7aa176413f70b9f22..d9afbc6547e18e8886c414ff150e332cfaf9b0c3 100644 --- a/paddle/fluid/framework/details/gather_op_handle.h +++ b/paddle/fluid/framework/details/gather_op_handle.h @@ -30,7 +30,7 @@ namespace details { struct GatherOpHandle : public OpHandleBase { public: - GatherOpHandle(const std::vector &local_scopes, + GatherOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places); std::string Name() const override; diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index 3cce2cc1640b3866130126424ff8fef18b8befc6..c9b94d1e1039df6ff27f9ffe225b2a50c35a5c50 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -70,6 +70,7 @@ struct TestGatherOpHandle { } void InitGatherOp(size_t input_scope_idx) { + std::vector> nodes; for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); Scope& local_scope = local_scopes_.back()->NewScope(); @@ -81,30 +82,37 @@ struct TestGatherOpHandle { } param_scopes_[input_scope_idx]->Var("out"); - op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_)); + nodes.emplace_back(new ir::Node("node", ir::Node::Type::kOperation)); + op_handle_.reset( + new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); - auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]); + nodes.emplace_back(new ir::Node("node1", ir::Node::Type::kVariable)); + auto* in_var_handle = + new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); vars_.emplace_back(in_var_handle); op_handle_->AddInput(in_var_handle); } // add dummy var - vars_.emplace_back(new DummyVarHandle()); + nodes.emplace_back(new ir::Node("node2", ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* in_dummy_var_handle = static_cast(vars_.back().get()); - in_dummy_var_handle->generated_op_ = nullptr; + in_dummy_var_handle->ClearGeneratedOp(); op_handle_->AddInput(in_dummy_var_handle); // add output - auto* out_var_handle = - new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]); + nodes.emplace_back(new ir::Node("node3", ir::Node::Type::kVariable)); + auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, + "out", gpu_list_[input_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); // add dummy var - vars_.emplace_back(new DummyVarHandle()); + nodes.emplace_back(new ir::Node("node4", ir::Node::Type::kVariable)); + vars_.emplace_back(new DummyVarHandle(nodes.back().get())); DummyVarHandle* dummy_var_handle = static_cast(vars_.back().get()); op_handle_->AddOutput(dummy_var_handle); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 6f5d4471a97cc4efc73b9df68040ab9eccde0b1c..c52980472de8d48e8c21e7c1e53813aa4847cece 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -25,6 +25,7 @@ #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" @@ -66,31 +67,38 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( } } -void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, - const OpDesc &op, +void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, ir::Node *node, size_t place_id) const { auto p = places_[place_id]; - auto *op_handle = result->ops_.back().get(); + auto *op_handle = result->Get("ops").back().get(); op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); - for (auto &each_var_name : op.InputArgumentNames()) { - VarHandle *var = - CreateOrGetLatestVarHandle(result, each_var_name, p, place_id); + for (ir::Node *input : node->inputs) { + VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id); op_handle->AddInput(var); } - for (auto &each_var_name : op.OutputArgumentNames()) { - CreateOpOutput(result, op_handle, each_var_name, p, place_id); + for (ir::Node *output : node->outputs) { + ir::Node *new_node = nullptr; + if (output->Var()) { + new_node = result->CreateVarNode(output->Var()); + } else { + new_node = + result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable); + } + CreateOpOutput(result, op_handle, new_node, p, place_id); } } std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( - const ProgramDesc &program) const { + const std::vector> &nodes) const { std::vector send_vars; // since parameters are all in block 0, // it's enough to only scan send ops in block 0 - for (auto *op : program.Block(0).AllOps()) { + for (auto &node : nodes) { + if (node->NodeType() != ir::Node::Type::kOperation) continue; + OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find send op, // instead of the the hard code string if (op->Type() == "send") { @@ -104,9 +112,11 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( } std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( - const ProgramDesc &program) const { + const std::vector> &nodes) const { std::vector recv_vars; - for (auto *op : program.Block(0).AllOps()) { + for (auto &node : nodes) { + if (node->NodeType() != ir::Node::Type::kOperation) continue; + OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find recv op, // instead of the hard code string if (op->Type() == "recv") { @@ -120,7 +130,7 @@ std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( } bool MultiDevSSAGraphBuilder::IsDistTrainOp( - const OpDesc &op, const std::vector &send_vars, + ir::Node *node, const std::vector &send_vars, const std::vector &recv_vars) const { if (send_vars.size() == 0 || recv_vars.size() == 0) { return false; @@ -143,8 +153,17 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( return false; }; - return checker(op.OutputArgumentNames(), send_vars) || - checker(op.InputArgumentNames(), recv_vars); + std::vector input_var_names; + std::vector output_var_names; + for (ir::Node *input : node->inputs) { + input_var_names.push_back(input->Name()); + } + for (ir::Node *output : node->outputs) { + output_var_names.push_back(output->Name()); + } + + return checker(output_var_names, send_vars) || + checker(input_var_names, recv_vars); } size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( @@ -167,25 +186,30 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( return dev_id; } -std::unique_ptr MultiDevSSAGraphBuilder::Build( - const ProgramDesc &program) const { - for (auto *var : program.Block(0).AllVars()) { - all_vars_.emplace(var->Name(), var); +std::unique_ptr MultiDevSSAGraphBuilder::Apply( + std::unique_ptr graph) const { + // Rebuild the graph structure. + auto nodes = std::move(graph->nodes); + graph->nodes.clear(); + + for (auto &node : nodes) { + if (node->NodeType() == ir::Node::Type::kVariable) { + all_vars_.emplace(node->Name(), node->Var()); + } } - auto graph = new SSAGraph(); - SSAGraph &result = *graph; + Graph &result = *graph; std::unordered_set og_has_been_broadcast; // We cannot invoke resize. It is a bug of GCC 4.8 - result.vars_ = std::vector< - std::unordered_map>>>( - places_.size()); + result.Set("vars", new GraphVars(places_.size())); + result.Set("dep_vars", new GraphDepVars); + result.Set("ops", new GraphOps); // find send/recv vars so that we can place the distributed training // realted op in the place 0 - auto send_vars = FindDistTrainSendVars(program); - auto recv_vars = FindDistTrainRecvVars(program); + auto send_vars = FindDistTrainSendVars(nodes); + auto recv_vars = FindDistTrainRecvVars(nodes); std::vector> bcast_var_name_set; bcast_var_name_set.resize(places_.size()); @@ -193,14 +217,19 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( size_t cur_device_id = 0; bool is_forwarding = true; - for (auto *op : program.Block(0).AllOps()) { + // NOTE: Currently, passes before SSAGraphBuilder cannot reorder + // forward, backward nodes. E.g. you can't append an forward node + // at the end of the node list. + // TODO(panyx0718): FIXME: Needs to sort by forward->backward order. + for (auto &node : nodes) { + if (node->NodeType() != ir::Node::Type::kOperation) continue; if (boost::get( - op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kRPC)) { - CreateRPCOp(&result, *op); - } else if (IsDistTrainOp(*op, send_vars, recv_vars)) { - CreateDistTrainOp(&result, *op); - } else if (IsScaleLossOp(*op)) { + CreateRPCOp(&result, node.get()); + } else if (IsDistTrainOp(node.get(), send_vars, recv_vars)) { + CreateDistTrainOp(&result, node.get()); + } else if (IsScaleLossOp(node.get())) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != BuildStrategy::GradientScaleStrategy::kCustomized) { @@ -212,33 +241,35 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( // the block. is_forwarding = false; } else { - int op_dev_id = GetOpDeviceID(*op); + int op_dev_id = GetOpDeviceID(node.get()); if (op_dev_id != -1) { // This op only runs on one specific device. - CreateComputationalOp(&result, *op, op_dev_id); - for (auto &var_name : op->OutputArgumentNames()) { - var_name_on_devices_.emplace(var_name, op_dev_id); + CreateComputationalOp(&result, node.get(), op_dev_id); + for (ir::Node *n : node->outputs) { + var_name_on_devices_.emplace(n->Name(), op_dev_id); } } else { // This op runs on all devices, and its output may have parameter's // gradients. - if (op->Type() == "read" && strategy_.enable_data_balance_) { - op->SetAttr("throw_eof_exp", false); - CreateComputationalOps(&result, *op, places_.size()); - const auto &data_var_names = op->Output("Out"); + if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { + node->Op()->SetAttr("throw_eof_exp", false); + CreateComputationalOps(&result, node.get(), places_.size()); + // TODO(paddle-dev): builder shouldn't depend on the out logic of + // a specific op. + const auto &data_var_names = node->Op()->Output("Out"); InsertDataBalanceOp(&result, data_var_names); } else { - CreateComputationalOps(&result, *op, places_.size()); + CreateComputationalOps(&result, node.get(), places_.size()); } if (!is_forwarding && places_.size() > 1) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. - if (static_cast(boost::get(op->GetAttr( + if (static_cast(boost::get(node->Op()->GetAttr( OpProtoAndCheckerMaker::OpRoleAttrName())) & static_cast(OpRole::kBackward))) { try { - auto backward_vars = - boost::get>(op->GetNullableAttr( + auto backward_vars = boost::get>( + node->Op()->GetNullableAttr( OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); @@ -302,8 +333,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); - - return std::unique_ptr(graph); + return std::move(graph); } bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { @@ -327,78 +357,96 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext( #endif } -void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const { #ifdef PADDLE_WITH_CUDA - auto *op_handle = new BroadcastOpHandle(local_scopes_, places_, nccl_ctxs_); + auto *op_handle = new BroadcastOpHandle( + result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_); #else - auto *op_handle = new BroadcastOpHandle(local_scopes_, places_); + auto *op_handle = new BroadcastOpHandle( + result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), + local_scopes_, places_); #endif + result->Get("ops").emplace_back(op_handle); - result->ops_.emplace_back(op_handle); - auto *in = result->vars_.at(src_dev_id).at(p_name).back().get(); + auto *in = + result->Get("vars").at(src_dev_id).at(p_name).back().get(); op_handle->AddInput(in); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->vars_.at(i).at(p_name); - auto *out_var = new VarHandle(vars.size(), i, p_name, p); + auto &vars = result->Get("vars").at(i).at(p_name); + auto *out_var = new VarHandle( + result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(), + i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } } -void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, - const OpDesc &op, +void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result, + ir::Node *node, int dev_id) const { - result->ops_.emplace_back( - new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id])); - CreateOpHandleIOs(result, op, dev_id); + result->Get("ops").emplace_back( + new ComputationOpHandle(result->CreateOpNode(node->Op()), + local_scopes_[dev_id], places_[dev_id])); + CreateOpHandleIOs(result, node, dev_id); } -void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, +void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA - result->ops_.emplace_back( - new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new AllReduceOpHandle( + result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back(new AllReduceOpHandle( + result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), + local_scopes_, places_)); #endif - auto *op_handle = result->ops_.back().get(); + auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->vars_[i][og]; + auto &vars = result->Get("vars")[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); - auto var = new VarHandle(vars.size(), i, og, p); + auto var = + new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), + vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } } void MultiDevSSAGraphBuilder::InsertDataBalanceOp( - SSAGraph *result, const std::vector &datas) const { + Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA - result->ops_.emplace_back( - new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back(new DataBalanceOpHandle( + result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), + local_scopes_, places_)); #endif - auto *op_handle = result->ops_.back().get(); + auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); for (const std::string &d_name : datas) { - auto &vars = result->vars_[i][d_name]; + auto &vars = result->Get("vars")[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); - auto var = new VarHandle(vars.size(), i, d_name, p); + auto var = new VarHandle( + result->CreateEmptyNode(d_name, ir::Node::Type::kVariable), + vars.size(), i, d_name, p); vars.emplace_back(var); op_handle->AddOutput(var); } @@ -417,22 +465,22 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( return is_pg_once; } -int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const { +int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { return -1; } int op_role = boost::get( - op.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); + node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); if (op_role != static_cast(framework::OpRole::kOptimize)) { return -1; } auto param_grad = boost::get>( - op.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(param_grad.size(), 2U); int dev_id = GetVarDeviceID(param_grad[1]); - PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(), - param_grad[0]); + PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", + node->Op()->Type(), param_grad[0]); return dev_id; } @@ -441,7 +489,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const { return got == var_name_on_devices_.end() ? -1 : got->second; } -void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { +void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle #ifdef PADDLE_WITH_CUDA @@ -452,11 +500,11 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { auto *communication_dev_ctx = platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif - - auto *op_handle = - new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i], - places_[i], communication_dev_ctx); - result->ops_.emplace_back(op_handle); + auto *op_handle = new ScaleLossGradOpHandle( + result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), + local_scopes_.size(), local_scopes_[i], places_[i], + communication_dev_ctx); + result->Get("ops").emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale // factor. So it does not depend on any other operators. @@ -464,43 +512,51 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); - CreateOpOutput(result, op_handle, GradVarName(loss_var_name_), places_[i], - i); + CreateOpOutput(result, op_handle, + result->CreateEmptyNode(GradVarName(loss_var_name_), + ir::Node::Type::kVariable), + places_[i], i); } } -void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, - const OpDesc &op, +void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result, + ir::Node *node, size_t num_places) const { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->ops_.emplace_back(new ComputationOpHandle(op, s, p)); - CreateOpHandleIOs(result, op, scope_idx); + result->Get("ops").emplace_back( + new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); + CreateOpHandleIOs(result, node, scope_idx); } } -VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, +VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA - result->ops_.emplace_back( - new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); + result->Get("ops").emplace_back(new ReduceOpHandle( + result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_)); #else - result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_)); + result->Get("ops").emplace_back(new ReduceOpHandle( + result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), + local_scopes_, places_)); #endif - auto *op_handle = result->ops_.back().get(); + auto *op_handle = result->Get("ops").back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); - auto &vars = result->vars_[i][og]; + auto &vars = result->Get("vars")[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); } - auto &vars = result->vars_[dst_dev_id][og]; - auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]); + auto &vars = result->Get("vars")[dst_dev_id][og]; + auto var = + new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), + vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); return var; @@ -508,35 +564,46 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, // Find the first occurence of `prev_op_name` and make current `op` depend // on it. -void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op, +void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { - for (auto &prev_op : result->ops_) { + for (auto &prev_op : result->Get("ops")) { if (prev_op->Name() == prev_op_name) { - auto *dep_var = new DummyVarHandle(); + auto *dep_var = new DummyVarHandle( + result->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); prev_op->AddOutput(dep_var); - result->dep_vars_.emplace(dep_var); + result->Get("dep_vars").emplace(dep_var); op->AddInput(dep_var); } } } -void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, - const OpDesc &op) const { +void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result, + ir::Node *node) const { int op_dev_id = -1; - if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") { - op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + std::vector input_var_names; + std::vector output_var_names; + for (ir::Node *input : node->inputs) { + input_var_names.push_back(input->Name()); + } + for (ir::Node *output : node->outputs) { + output_var_names.push_back(output->Name()); + } + + if (node->Op()->Type() == "split_byref" || + node->Op()->Type() == "split_selected_rows") { + op_dev_id = GetVarDeviceID(input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { - op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); - for (auto &varname : op.InputArgumentNames()) { + op_dev_id = GetAppropriateDeviceID(input_var_names); + for (auto &varname : input_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } - for (auto &varname : op.OutputArgumentNames()) { + for (auto &varname : output_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } - } else if (op.Type() == "concat") { - op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); - for (auto &varname : op.OutputArgumentNames()) { + } else if (node->Op()->Type() == "concat") { + op_dev_id = GetVarDeviceID(input_var_names[0]); + for (auto &varname : output_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } else { @@ -546,34 +613,43 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result, } PADDLE_ENFORCE(op_dev_id != -1, - "can not find right place for distributed op: %s", op.Type()); + "can not find right place for distributed op: %s", + node->Op()->Type()); - CreateComputationalOp(result, op, op_dev_id); - if (op.Type() == "concat") { - ConnectOp(result, result->ops_.back().get(), "fetch_barrier"); + CreateComputationalOp(result, node, op_dev_id); + if (node->Op()->Type() == "concat") { + ConnectOp(result, result->Get("ops").back().get(), + "fetch_barrier"); } } // Create RPC related op handles that connects its in ops and out ops. -void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, - const OpDesc &op) const { +void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result, ir::Node *node) const { int op_dev_id = -1; - if (op.Type() == "send") { - op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); + if (node->Op()->Type() == "send") { + op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); // the variable name which contains .block means it was splited by // split_byref op // so that we can balance the variable blocks to all the pserver // instances. if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && - op.InputArgumentNames()[0].find(".block") == std::string::npos) { - op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); - for (auto &varname : op.InputArgumentNames()) { + node->inputs[0]->Name().find(".block") == std::string::npos) { + std::vector input_var_names; + for (ir::Node *n : node->inputs) { + input_var_names.push_back(n->Name()); + } + op_dev_id = GetAppropriateDeviceID(input_var_names); + for (auto &varname : input_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } - } else if (op.Type() == "recv") { - op_dev_id = GetAppropriateDeviceID(op.OutputArgumentNames()); - for (auto &varname : op.OutputArgumentNames()) { + } else if (node->Op()->Type() == "recv") { + std::vector output_var_names; + for (ir::Node *n : node->outputs) { + output_var_names.push_back(n->Name()); + } + op_dev_id = GetAppropriateDeviceID(output_var_names); + for (auto &varname : output_var_names) { var_name_on_devices_.emplace(varname, op_dev_id); } } else { @@ -582,18 +658,20 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, } PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", - op.Type()); - - result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], - op.Type(), places_[op_dev_id])); - - if (op.Type() == "send_barrier") { - ConnectOp(result, result->ops_.back().get(), "send"); - } else if (op.Type() == "recv") { - ConnectOp(result, result->ops_.back().get(), "send_barrier"); - } else if (op.Type() == "fetch_barrier") { - ConnectOp(result, result->ops_.back().get(), "recv"); - } else if (op.Type() == "send") { + node->Op()->Type()); + + result->Get("ops").emplace_back(new RPCOpHandle( + result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], + node->Op()->Type(), places_[op_dev_id])); + + if (node->Op()->Type() == "send_barrier") { + ConnectOp(result, result->Get("ops").back().get(), "send"); + } else if (node->Op()->Type() == "recv") { + ConnectOp(result, result->Get("ops").back().get(), + "send_barrier"); + } else if (node->Op()->Type() == "fetch_barrier") { + ConnectOp(result, result->Get("ops").back().get(), "recv"); + } else if (node->Op()->Type() == "send") { // do nothing } else { PADDLE_THROW( @@ -601,12 +679,12 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, "send, send_barrier. recv, fetch_barrier]"); } - CreateOpHandleIOs(result, op, op_dev_id); + CreateOpHandleIOs(result, node, op_dev_id); } -bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { +bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { return boost::get( - op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss)) && !loss_var_name_.empty(); // If loss_var is empty. This is test mode diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index a964e024885e56693224a6199e00ff30beaa1df4..2b7f4f586b4e750fde9245286c977258a9db6086 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/ssa_graph_builder.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace platform { @@ -45,13 +46,11 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::vector &local_scopes, const BuildStrategy &strategy); #endif - - std::unique_ptr Build(const ProgramDesc &program) const override; + std::unique_ptr Apply(std::unique_ptr graph) const override; int GetVarDeviceID(const std::string &varname) const override; private: - void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, - size_t device_id) const; + void CreateOpHandleIOs(Graph *result, ir::Node *node, size_t device_id) const; private: std::string loss_var_name_; @@ -63,48 +62,46 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { platform::NCCLContextMap *nccl_ctxs_; #endif - bool IsScaleLossOp(const OpDesc &op) const; + bool IsScaleLossOp(ir::Node *node) const; - void CreateRPCOp(SSAGraph *result, const OpDesc &op) const; - void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const; + void CreateRPCOp(Graph *result, ir::Node *node) const; + void CreateDistTrainOp(Graph *result, ir::Node *node) const; /** * Is this operator as the end-point operator before/after send operator. */ - bool IsDistTrainOp(const OpDesc &op, - const std::vector &send_vars, + bool IsDistTrainOp(ir::Node *node, const std::vector &send_vars, const std::vector &recv_vars) const; std::vector FindDistTrainSendVars( - const ProgramDesc &program) const; + const std::vector> &nodes) const; std::vector FindDistTrainRecvVars( - const ProgramDesc &program) const; + const std::vector> &nodes) const; - void ConnectOp(SSAGraph *result, OpHandleBase *op, + void ConnectOp(Graph *result, OpHandleBase *op, const std::string &prev_op_name) const; - void CreateComputationalOps(SSAGraph *result, const OpDesc &op, + void CreateComputationalOps(Graph *result, ir::Node *node, size_t num_places) const; - void CreateScaleLossGradOp(SSAGraph *result) const; - VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og, + void CreateScaleLossGradOp(Graph *result) const; + VarHandle *CreateReduceOp(Graph *result, const std::string &og, int dst_dev_id) const; - void CreateComputationalOp(SSAGraph *result, const OpDesc &op, - int dev_id) const; + void CreateComputationalOp(Graph *result, ir::Node *node, int dev_id) const; bool IsParameterGradientOnce( const std::string &og, std::unordered_set *og_has_been_broadcast) const; - int GetOpDeviceID(const OpDesc &op) const; + int GetOpDeviceID(ir::Node *node) const; - void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; + void InsertAllReduceOp(Graph *result, const std::string &og) const; - void InsertDataBalanceOp(SSAGraph *result, + void InsertDataBalanceOp(Graph *result, const std::vector &datas) const; - void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, + void CreateBroadcastOp(Graph *result, const std::string &p_name, size_t src_dev_id) const; bool IsSparseGradient(const std::string &og) const; diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index d80bdcf15d798925c137460125964d3d7e65f67e..ee9f9184da65467b82794c99fe3e95b108373753 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -80,19 +80,21 @@ void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { void OpHandleBase::AddInput(VarHandleBase *in) { this->inputs_.emplace_back(in); - in->pending_ops_.insert(this); + node_->inputs.push_back(in->Node()); + in->AddOutput(this, this->Node()); } void OpHandleBase::AddOutput(VarHandleBase *out) { outputs_.emplace_back(out); - out->generated_op_ = this; + node_->outputs.push_back(out->Node()); + out->AddInput(this, this->Node()); } void OpHandleBase::WaitInputVarGenerated() { for (auto in_var : inputs_) { if (NeedWait(in_var)) { for (auto &pair : dev_ctxes_) { - in_var->generated_op_->RecordWaitEventOnCtx(pair.second); + in_var->GeneratedOp()->RecordWaitEventOnCtx(pair.second); } } } @@ -101,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() { void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { for (auto *in : inputs_) { if (NeedWait(in)) { - in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]); + in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]); } } } @@ -117,7 +119,7 @@ size_t OpHandleBase::NoDummyInputSize() const { } bool OpHandleBase::NeedWait(VarHandleBase *in_var) { - return in_var && in_var->generated_op_; + return in_var && in_var->GeneratedOp(); } void OpHandleBase::RunAndRecordEvent(const std::function &callback) { diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 6aec178831161f8ac1306fc3ed72e3267ca3c7e5..2d7f18942890245249dd0619a40bb43833c9a2ee 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -17,6 +17,7 @@ #include #include #include "paddle/fluid/framework/details/var_handle.h" +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/macros.h" @@ -26,9 +27,11 @@ namespace details { constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; +// Wraps ir::Node and provide helper utilities. +// It's responsible for populating necessary fields of ir::Node. class OpHandleBase { public: - OpHandleBase() {} + explicit OpHandleBase(ir::Node *node) : node_(node) {} virtual ~OpHandleBase(); @@ -82,6 +85,8 @@ class OpHandleBase { size_t NoDummyInputSize() const; + ir::Node *Node() { return node_; } + protected: void RunAndRecordEvent(const std::function &callback); @@ -90,6 +95,7 @@ class OpHandleBase { virtual void RunImpl() = 0; + ir::Node *node_; std::vector inputs_; std::vector outputs_; std::map dev_ctxes_; diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 4d14334cdfe06e2e805c2577458d6689e6324cc7..a6289b055f97b7b0e57928358d84117b33cf2df8 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -37,10 +37,13 @@ struct ReduceOpHandle : public OpHandleBase { #ifdef PADDLE_WITH_CUDA const platform::NCCLContextMap *nccl_ctxs_; - ReduceOpHandle(const std::vector &local_scopes, + ReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places, const platform::NCCLContextMap *nccl_ctxs) - : local_scopes_(local_scopes), places_(places), nccl_ctxs_(nccl_ctxs) { + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + nccl_ctxs_(nccl_ctxs) { if (nccl_ctxs_) { for (auto &p_ctx : nccl_ctxs_->contexts_) { dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); @@ -48,9 +51,9 @@ struct ReduceOpHandle : public OpHandleBase { } } #else - ReduceOpHandle(const std::vector &local_scopes, + ReduceOpHandle(ir::Node *node, const std::vector &local_scopes, const std::vector &places) - : local_scopes_(local_scopes), places_(places) {} + : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {} #endif std::string Name() const override; diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index ffdd7c14eb5097cc8285da090e4a72e1e3f43d86..3a9a58412391b188c5e804b41fa47b3607a36bd1 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -84,6 +84,7 @@ struct TestReduceOpHandle { } void InitReduceOp(size_t out_scope_idx) { + std::vector> nodes; // init scope for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); @@ -96,19 +97,21 @@ struct TestReduceOpHandle { } param_scopes_[out_scope_idx]->Var("out"); + nodes.emplace_back(new ir::Node("node")); if (use_gpu_) { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, + gpu_list_, nccl_ctxs_.get())); #else PADDLE_THROW("CUDA is not support."); #endif } else { #ifdef PADDLE_WITH_CUDA - op_handle_.reset( - new ReduceOpHandle(local_scopes_, gpu_list_, nccl_ctxs_.get())); + op_handle_.reset(new ReduceOpHandle(nodes.back().get(), local_scopes_, + gpu_list_, nccl_ctxs_.get())); #else - op_handle_.reset(new ReduceOpHandle(local_scopes_, gpu_list_)); + op_handle_.reset( + new ReduceOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); #endif } @@ -118,8 +121,10 @@ struct TestReduceOpHandle { if (!use_gpu_) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); } - auto *in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]); - in_var_handle->generated_op_ = nullptr; + nodes.emplace_back(new ir::Node("node1")); + auto *in_var_handle = + new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); + in_var_handle->ClearGeneratedOp(); vars_.emplace_back(in_var_handle); op_handle_->AddInput(in_var_handle); } @@ -128,12 +133,13 @@ struct TestReduceOpHandle { vars_.emplace_back(new DummyVarHandle()); DummyVarHandle *in_dummy_var_handle = static_cast(vars_.back().get()); - in_dummy_var_handle->generated_op_ = nullptr; + in_dummy_var_handle->ClearGeneratedOp(); op_handle_->AddInput(in_dummy_var_handle); // add output - auto *out_var_handle = - new VarHandle(2, out_scope_idx, "out", gpu_list_[out_scope_idx]); + nodes.emplace_back(new ir::Node("node2")); + auto *out_var_handle = new VarHandle(nodes.back().get(), 2, out_scope_idx, + "out", gpu_list_[out_scope_idx]); vars_.emplace_back(out_var_handle); op_handle_->AddOutput(out_var_handle); diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index 586465f99fd94117c821be2952bffda385fbcf75..924ff4d118a192a43e5828a38fd1abbaac1a8526 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -18,10 +18,11 @@ namespace paddle { namespace framework { namespace details { -RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc, +RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, const Scope *local_scope, const std::string &name, const platform::Place &place) - : op_(framework::OpRegistry::CreateOp(op_desc)), + : OpHandleBase(node), + op_(framework::OpRegistry::CreateOp(op_desc)), local_scope_(local_scope), name_(name), place_(place) {} @@ -35,8 +36,8 @@ void RPCOpHandle::RunImpl() { if (in->DebugString() == "dummy") { // HACK continue; } - if (in->generated_op_) { - in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]); + if (in->GeneratedOp()) { + in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]); } } auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get(); diff --git a/paddle/fluid/framework/details/rpc_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h index ae38c7fe19e102a330455d89a1068414a7835fab..7f99cdeacf618a9496eaef98520685d6d1621ae1 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.h +++ b/paddle/fluid/framework/details/rpc_op_handle.h @@ -28,8 +28,9 @@ namespace framework { namespace details { struct RPCOpHandle : public OpHandleBase { - RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, - const std::string& name, const platform::Place& place); + RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc, + const Scope* local_scope, const std::string& name, + const platform::Place& place); std::string Name() const override; diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index d9c387e79dc71288e7330597fed57171d447f31b..609e18581957f62b040e04e937873b7a8fa5785a 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -19,10 +19,14 @@ namespace paddle { namespace framework { namespace details { -ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, +ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, + Scope *scope, platform::Place place, platform::DeviceContext *dev_ctx) - : coeff_(static_cast(1.0 / num_dev)), scope_(scope), place_(place) { + : OpHandleBase(node), + coeff_(static_cast(1.0 / num_dev)), + scope_(scope), + place_(place) { dev_ctxes_[place_] = dev_ctx; } diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h index d93d599d46f130cf98f39f15697ce994a31e20c3..523b55724c82d4e2bef0520c10e5708c952a3ecc 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -25,7 +25,8 @@ namespace framework { namespace details { struct ScaleLossGradOpHandle : public OpHandleBase { - ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place, + ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope, + platform::Place place, platform::DeviceContext *context); ~ScaleLossGradOpHandle() final; diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index 20df7a4722d589ffd168f842e927cff8411096bb..cbfbcb1c0cd24f16773f9633310166371600790c 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -17,6 +17,9 @@ #include #include #include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" + #include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" #include "paddle/fluid/framework/scope.h" diff --git a/paddle/fluid/framework/details/ssa_graph.cc b/paddle/fluid/framework/details/ssa_graph.cc deleted file mode 100644 index 1b8c889449059c563ea39f86250075ac2537cdbe..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph.cc +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/details/ssa_graph.h" diff --git a/paddle/fluid/framework/details/ssa_graph.h b/paddle/fluid/framework/details/ssa_graph.h deleted file mode 100644 index e996a00c162186e47e77d007503ac67caa9f8024..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/ssa_graph.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/framework/details/op_handle_base.h" -#include "paddle/fluid/framework/details/var_handle.h" - -namespace paddle { -namespace framework { -namespace details { - -// A SSA graph used by parallel executor. -struct SSAGraph { - // all variable in each devices. - // The outside vector is the device vector. Each element of this vector is a - // map from variable name to variables. The variables, who have the same name, - // will have a different version. The offset in the - // `std::vector>` is the version of varaibles. - std::vector< - std::unordered_map>>> - vars_; - - // aux variables to represent dependency. Useful to resolve data hazard. - std::unordered_set> dep_vars_; - - // all operators. NOTE that even we use a vector here, the operators is - // unordered. - std::vector> ops_; -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 88a21f48879a15450051ad94ed76e1c48bf23014..7bc130ef6e8d2e0caf6e445d12950b87e6dd4dbd 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -17,8 +17,8 @@ namespace paddle { namespace framework { namespace details { -void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { - for (auto &var_map : graph->vars_) { +void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) { + for (auto &var_map : graph->Get("vars")) { for (auto &name_pair : var_map) { if (name_pair.second.size() <= 1) { continue; @@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { auto it_old = name_pair.second.rbegin(); ++it_old; for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - auto *write_op = (*it_new)->generated_op_; - auto &read_ops = (*it_old)->pending_ops_; + OpHandleBase *write_op = (*it_new)->GeneratedOp(); + const auto &read_ops = (*it_old)->PendingOps(); for (auto *read_op : read_ops) { // Manually add a dependency var from read_op to write_op; @@ -37,10 +37,11 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { continue; } - auto *dep_var = new DummyVarHandle(); + auto *dep_var = new DummyVarHandle( + graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); read_op->AddOutput(dep_var); write_op->AddInput(dep_var); - graph->dep_vars_.emplace(dep_var); + graph->Get("dep_vars").emplace(dep_var); } } } @@ -48,13 +49,20 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) { } VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( - SSAGraph *graph, const std::string &each_var_name, - const platform::Place &place, size_t place_offset) { - auto &var_holders = graph->vars_[place_offset]; - auto &var_holder = var_holders[each_var_name]; + Graph *graph, ir::Node *node, const platform::Place &place, + size_t place_offset) { + auto &var_holders = graph->Get("vars")[place_offset]; + auto &var_holder = var_holders[node->Name()]; VarHandle *var = nullptr; if (var_holder.empty()) { - var = new VarHandle(0, place_offset, each_var_name, place); + if (node->Var()) { + var = new VarHandle(graph->CreateVarNode(node->Var()), 0, place_offset, + node->Name(), place); + } else { + var = new VarHandle( + graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable), 0, + place_offset, node->Name(), place); + } var_holder.emplace_back(var); } else { var = var_holder.rbegin()->get(); @@ -62,24 +70,26 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( return var; } -void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, - const std::string &each_var_name, +void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle, + ir::Node *new_node, const platform::Place &place, size_t place_offset) { - auto &vars = graph->vars_[place_offset][each_var_name]; + auto &vars = graph->Get("vars")[place_offset][new_node->Name()]; size_t version = vars.size(); - auto var = new VarHandle(version, place_offset, each_var_name, place); + auto var = + new VarHandle(new_node, version, place_offset, new_node->Name(), place); vars.emplace_back(var); op_handle->AddOutput(var); } -void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) { - for (auto &op : graph->ops_) { +void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) { + for (auto &op : graph->Get("ops")) { if (!op->Outputs().empty()) { continue; } - auto *dummy_leaf = new DummyVarHandle(); - graph->dep_vars_.emplace(dummy_leaf); + auto *dummy_leaf = new DummyVarHandle( + graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); + graph->Get("dep_vars").emplace(dummy_leaf); op->AddOutput(dummy_leaf); } } diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 18612c3c1b62cf4c2ebdc221c301c59ec81c2da7..e8e8acdb38f893302fb92c47d6f1cb2d38453e0f 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -16,20 +16,42 @@ #include #include +#include + +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" -#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + namespace paddle { namespace framework { namespace details { -class SSAGraphBuilder { +// all variable in each devices. +// The outside vector is the device vector. Each element of this vector is a +// map from variable name to variables. The variables, who have the same name, +// will have a differsent version. The offset in the +// `std::vector>` is the version of varaibles. +typedef std::vector< + std::unordered_map>>> + GraphVars; + +// aux variables to represent dependency. Useful to resolve data hazard. +typedef std::unordered_set> GraphDepVars; + +// all operators. NOTE that even we use a vector here, the operators is +// unordered. +typedef std::vector> GraphOps; + +class SSAGraphBuilder : public ir::Pass { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; + virtual int GetVarDeviceID(const std::string &var_name) const = 0; DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); @@ -42,20 +64,19 @@ class SSAGraphBuilder { * * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) */ - static void PolishGraphToSupportDataHazards(SSAGraph *graph); + static void PolishGraphToSupportDataHazards(Graph *graph); - static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph, - const std::string &each_var_name, + static VarHandle *CreateOrGetLatestVarHandle(Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset); // Add an output variable (each_var_name, place, place_offset) to op_handle, // which belongs to graph - static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, - const std::string &each_var_name, - const platform::Place &place, size_t place_offset); + static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle, + ir::Node *new_node, const platform::Place &place, + size_t place_offset); - static void AddOutputToLeafOps(SSAGraph *graph); + static void AddOutputToLeafOps(Graph *graph); }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index da5428946ee588e8eac1f78929dc0432df532975..7c79d7f1e881c67514634d56caa715c41927dbce 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/ssa_graph.h" -#include #include "paddle/fluid/framework/details/ssa_graph_checker.h" +#include +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { namespace details { -bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { +bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const { std::unordered_map pending_ops; std::unordered_set pending_vars; std::unordered_set ready_vars; @@ -28,12 +28,12 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { auto insert_pending_var = [&](VarHandleBase *var) { pending_vars.insert(var); - if (var->generated_op_ == nullptr) { + if (var->GeneratedOp() == nullptr) { ready_vars.emplace(var); } }; - for (auto &var_map : graph->vars_) { + for (auto &var_map : graph->Get("vars")) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { insert_pending_var(version_pair.get()); @@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { } } - for (auto &var : graph->dep_vars_) { + for (auto &var : graph->Get("dep_vars")) { insert_pending_var(var.get()); } - for (auto &op : graph->ops_) { + for (auto &op : graph->Get("ops")) { if (op->Inputs().empty()) { ready_ops.insert(op.get()); } else { @@ -71,7 +71,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { for (auto ready_var : ready_vars) { pending_vars.erase(ready_var); - for (auto *op : ready_var->pending_ops_) { + for (auto *op : ready_var->PendingOps()) { auto &deps = --pending_ops[op]; if (deps == 0) { ready_ops.insert(op); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 331aa9d2b5864c470dbd5e29ef6faccffdcf781c..f1080610381128325ea0affba760ac66798fd948 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -21,7 +21,6 @@ namespace paddle { namespace framework { namespace details { -struct SSAGraph; class SSAGraghBuilderWithChecker : public SSAGraphBuilder { public: @@ -29,17 +28,17 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { std::unique_ptr&& builder) : builder_(std::move(builder)) {} - std::unique_ptr Build(const ProgramDesc& program) const override { - auto graph = builder_->Build(program); - PADDLE_ENFORCE(IsValidGraph(graph.get())); - return graph; + std::unique_ptr Apply(std::unique_ptr graph) const override { + auto new_graph = builder_->Apply(std::move(graph)); + PADDLE_ENFORCE(IsValidGraph(new_graph.get())); + return std::move(new_graph); } int GetVarDeviceID(const std::string& var_name) const override { return builder_->GetVarDeviceID(var_name); } - bool IsValidGraph(const SSAGraph* graph) const; + bool IsValidGraph(const Graph* graph) const; private: std::unique_ptr builder_; diff --git a/paddle/fluid/framework/details/ssa_graph_executor.h b/paddle/fluid/framework/details/ssa_graph_executor.h index 958086033607a4ed8fb840f5b14fe5779625bd82..8815ec89b23bc874471eefde5fa855cd2a4bde1f 100644 --- a/paddle/fluid/framework/details/ssa_graph_executor.h +++ b/paddle/fluid/framework/details/ssa_graph_executor.h @@ -18,8 +18,8 @@ #include #include -#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc index 22a40ca4b25cdd8ed9856b6c71bffc79561edcac..6dd6fd262e35a192ba85eb3aa16660526d2ebca2 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/ssa_graph_printer.cc @@ -14,15 +14,15 @@ #include "paddle/fluid/framework/details/ssa_graph_printer.h" #include -#include "paddle/fluid/framework/details/ssa_graph.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { namespace details { template -static inline void IterAllVar(const SSAGraph &graph, Callback callback) { - for (auto &each : graph.vars_) { +static inline void IterAllVar(const Graph &graph, Callback callback) { + for (auto &each : graph.Get("vars")) { for (auto &pair1 : each) { for (auto &pair2 : pair1.second) { callback(*pair2); @@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) { } } - for (auto &var : graph.dep_vars_) { + for (auto &var : graph.Get("dep_vars")) { callback(*var); } } -void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, +void GraphvizSSAGraphPrinter::Print(const Graph &graph, std::ostream &sout) const { size_t var_id = 0; std::unordered_map vars; @@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, }); size_t op_id = 0; - for (auto &op : graph.ops_) { + for (auto &op : graph.Get("ops")) { std::string op_name = "op_" + std::to_string(op_id++); sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" << std::endl; diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index 09b0333ef2cb43a306133aa5af98d37c11454d4d..411be02988a82b3e35d56833f92fc6fe405a2c3d 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -21,16 +21,16 @@ namespace paddle { namespace framework { namespace details { -struct SSAGraph; + class SSAGraphPrinter { public: virtual ~SSAGraphPrinter() {} - virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0; + virtual void Print(const Graph& graph, std::ostream& sout) const = 0; }; class GraphvizSSAGraphPrinter : public SSAGraphPrinter { public: - void Print(const SSAGraph& graph, std::ostream& sout) const override; + void Print(const Graph& graph, std::ostream& sout) const override; }; class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { @@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { stream_ptr_(std::move(sout)), stream_ref_(*stream_ptr_) {} - std::unique_ptr Build(const ProgramDesc& program) const override { - auto graph = builder_->Build(program); - printer_->Print(*graph, stream_ref_); - return graph; + std::unique_ptr Apply(std::unique_ptr graph) const override { + auto new_graph = builder_->Apply(std::move(graph)); + printer_->Print(*new_graph, stream_ref_); + return std::move(new_graph); } int GetVarDeviceID(const std::string& var_name) const override { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 07097c7e75c6ce638549716cd6523f387cdefd92..38cde13fe279d264c51baff71cffcab7b6ebb227 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -14,13 +14,14 @@ #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/ssa_graph_builder.h" + namespace paddle { namespace framework { namespace details { ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, - const std::vector &places, - std::unique_ptr &&graph) + const std::vector &places, std::unique_ptr &&graph) : graph_(std::move(graph)), pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) : nullptr), @@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( std::unordered_set delayed_ops; // Transform SSAGraph to pending_ops & pending_vars - for (auto &var_map : graph_->vars_) { + for (auto &var_map : graph_->Get("vars")) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); } } } - for (auto &var : graph_->dep_vars_) { + for (auto &var : graph_->Get("dep_vars")) { InsertPendingVar(&pending_vars, &ready_vars, var.get()); } - for (auto &op : graph_->ops_) { + for (auto &op : graph_->Get("ops")) { if (op->Inputs().empty()) { // Special case, Op has no input. ready_ops.insert(op.get()); } else { @@ -64,11 +65,12 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Step 2. Insert FetchOps std::vector> fetch_ops; + std::vector> tmp_nodes; std::unordered_set> fetch_dependencies; FeedFetchList fetch_data(fetch_tensors.size()); - InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, - &pending_vars, &ready_vars, &fetch_data); + InsertFetchOps(fetch_tensors, &fetch_ops, &tmp_nodes, &fetch_dependencies, + &pending_ops, &pending_vars, &ready_vars, &fetch_data); auto run_all_ops = [&](std::unordered_set &set) { for (auto *op : set) { @@ -125,7 +127,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // Find the ready_ops after the ready_var. for (auto ready_var : cur_ready_vars) { pending_vars.erase(ready_var); - for (auto *op : ready_var->pending_ops_) { + for (auto *op : ready_var->PendingOps()) { auto &deps = pending_ops[op]; --deps; if (deps == 0) { @@ -151,6 +153,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( void ThreadedSSAGraphExecutor::InsertFetchOps( const std::vector &fetch_tensors, std::vector> *fetch_ops, + std::vector> *temp_nodes, std::unordered_set> *fetch_dependencies, std::unordered_map *pending_ops, std::unordered_set *pending_vars, @@ -158,7 +161,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( std::unordered_map> fetched_vars; for (auto &fetch_var_name : fetch_tensors) { - for (auto &var_map : graph_->vars_) { + for (auto &var_map : graph_->Get("vars")) { auto it = var_map.find(fetch_var_name); if (it != var_map.end()) { fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); @@ -169,7 +172,10 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( for (size_t i = 0; i < fetch_tensors.size(); ++i) { auto &var_name = fetch_tensors[i]; auto &vars = fetched_vars.at(var_name); - auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_); + + temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); + auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i, + &local_scopes_); fetch_ops->emplace_back(op); for (auto &p : places_) { @@ -180,7 +186,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( op->AddInput(var); } - auto *fetch_dummy = new DummyVarHandle(); + temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation)); + auto *fetch_dummy = new DummyVarHandle(temp_nodes->back().get()); op->AddOutput(fetch_dummy); fetch_dependencies->emplace(fetch_dummy); this->InsertPendingVar(pending_vars, ready_vars, fetch_dummy); @@ -198,7 +205,7 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( std::unordered_set *pending_vars, BlockingQueue *ready_vars, VarHandleBase *var) const { pending_vars->insert(var); - if (var->generated_op_ == nullptr) { + if (var->GeneratedOp() == nullptr) { ready_vars->Push(var); } } diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 09973b7a72881464ad9e7776d4aad3d2261a118d..bf7c0a367a19ff4ac9462334516f1577672faa68 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -27,6 +27,7 @@ #include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/ssa_graph_executor.h" +#include "paddle/fluid/framework/ir/graph.h" namespace paddle { namespace framework { @@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph); + std::unique_ptr &&graph); // Run a SSAGraph by a thread pool // Use topological sort algorithm @@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { details::OpHandleBase *op); private: - std::unique_ptr graph_; + std::unique_ptr graph_; std::unique_ptr<::ThreadPool> pool_; std::vector local_scopes_; std::vector places_; @@ -71,6 +72,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { void InsertFetchOps( const std::vector &fetch_tensors, std::vector> *fetch_ops, + std::vector> *temp_nodes, std::unordered_set> *fetch_dependencies, std::unordered_map *pending_ops, std::unordered_set *pending_vars, diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h index cae9af7217660fb7e4b8535ee8e022fb3a127668..d8c2bc40b9458a1d5a7dd8a32277d04f69295f09 100644 --- a/paddle/fluid/framework/details/var_handle.h +++ b/paddle/fluid/framework/details/var_handle.h @@ -13,11 +13,14 @@ // limitations under the License. #pragma once + +#include #include #include #include #include +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/platform/place.h" namespace paddle { @@ -25,19 +28,60 @@ namespace framework { namespace details { class OpHandleBase; +// Wraps ir::Node and provide helper utilities. +// It's responsible for populating necessary fields of ir::Node. +// // VarHandleBase is the var node in the dependency graph. // A variable can only be generated by a single operator. i.e. // This is a single assignment graph. struct VarHandleBase { + explicit VarHandleBase(ir::Node* node) : node_(node) {} + virtual ~VarHandleBase(); + virtual std::string DebugString() const = 0; + void AddInput(OpHandleBase* in, ir::Node* node) { + node_->inputs.clear(); + node_->inputs.push_back(node); + generated_op_ = in; + } + + void AddOutput(OpHandleBase* out, ir::Node* node) { + if (pending_ops_.find(out) == pending_ops_.end()) { + pending_ops_.insert(out); + node_->outputs.push_back(node); + } + } + + void RemoveOutput(OpHandleBase* out, ir::Node* node) { + pending_ops_.erase(out); + node_->outputs.erase( + std::remove(node_->outputs.begin(), node_->outputs.end(), node), + node_->outputs.end()); + } + + void ClearGeneratedOp() { + generated_op_ = nullptr; + node_->inputs.clear(); + } + + OpHandleBase* GeneratedOp() { return generated_op_; } + + const std::unordered_set& PendingOps() const { + return pending_ops_; + } + + ir::Node* Node() { return node_; } + + protected: // The operator who generate this variable. nullptr if the variable // is a root node. OpHandleBase* generated_op_{nullptr}; // Operators which depend on this variable ready. std::unordered_set pending_ops_; + ir::Node* node_; }; // VarHandle is actually a single version of Runtime Variable. @@ -46,11 +90,14 @@ struct VarHandleBase { // // NOTE: runtime variables have place. struct VarHandle : public VarHandleBase { + explicit VarHandle(ir::Node* node) : VarHandleBase(node) {} + std::string DebugString() const override; - VarHandle(size_t version, size_t scope_index, std::string name, - platform::Place place) - : version_(version), + VarHandle(ir::Node* node, size_t version, size_t scope_index, + std::string name, platform::Place place) + : VarHandleBase(node), + version_(version), scope_idx_(scope_index), name_(std::move(name)), place_(std::move(place)) {} @@ -70,6 +117,8 @@ struct VarHandle : public VarHandleBase { // Dummy Variable. It is used to represent dependencies between operators struct DummyVarHandle : public VarHandleBase { + explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {} + std::string DebugString() const override; }; diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8ed06aa69c7f08b0600aa87cd482469ec78dfa3 --- /dev/null +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -0,0 +1,5 @@ +cc_library(graph SRCS graph.cc node) +cc_library(node SRCS node.cc) +cc_library(pass SRCS pass.cc graph node) + +cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc new file mode 100644 index 0000000000000000000000000000000000000000..688f7ba5825bf1a1ab65a0912663481913223e80 --- /dev/null +++ b/paddle/fluid/framework/ir/graph.cc @@ -0,0 +1,65 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_desc.h" + +namespace paddle { +namespace framework { + +// NOTE(paddle-dev): This graph contains circle. +Graph::Graph(const ProgramDesc &program) : program_(program) { + std::unordered_map all_vars; + for (auto *var : program.Block(0).AllVars()) { + all_vars.emplace(var->Name(), var); + } + + std::map var_nodes; + for (auto *op : program.Block(0).AllOps()) { + ir::Node *node = CreateOpNode(op); + + for (auto &each_var_name : op->InputArgumentNames()) { + ir::Node *var = nullptr; + if (var_nodes.find(each_var_name) != var_nodes.end()) { + var = var_nodes.at(each_var_name); + } else if (all_vars.count(each_var_name) != 0) { + var = CreateVarNode(all_vars.at(each_var_name)); + var_nodes[each_var_name] = var; + } else { + // TODO(paddle-dev): Seems some assumption doesn't hold? + LOG(ERROR) << op->Type() + << " input var not in all_var list: " << each_var_name; + var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable); + var_nodes[each_var_name] = var; + } + node->inputs.push_back(var); + var->outputs.push_back(node); + } + + for (auto &each_var_name : op->OutputArgumentNames()) { + ir::Node *var = nullptr; + if (var_nodes.find(each_var_name) != var_nodes.end()) { + var = var_nodes.at(each_var_name); + } else { + var = CreateVarNode(all_vars.at(each_var_name)); + var_nodes[each_var_name] = var; + } + node->outputs.push_back(var); + var->inputs.push_back(node); + } + } +} +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h new file mode 100644 index 0000000000000000000000000000000000000000..b4ac135b029005b723abca2cb9b9a9aa175eda40 --- /dev/null +++ b/paddle/fluid/framework/ir/graph.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/variant.h" + +namespace paddle { +namespace framework { + +class Graph { + public: + explicit Graph(const ProgramDesc& program); + + virtual ~Graph() { + for (auto& attr : attrs_) { + attr_dels_[attr.first](); + } + attrs_.clear(); + attr_dels_.clear(); + } + + template + AttrType& Get(const std::string& attr_name) const { + return *boost::any_cast(attrs_.at(attr_name)); + } + + template + void Set(const std::string& attr_name, AttrType* attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = [attr, attr_name]() { + VLOG(3) << "deleting " << attr_name; + delete attr; + }; + } + + ir::Node* CreateVarNode(VarDesc* var_desc) { + nodes.emplace_back(new ir::Node(var_desc)); + return nodes.back().get(); + } + + ir::Node* CreateOpNode(OpDesc* op_desc) { + nodes.emplace_back(new ir::Node(op_desc)); + return nodes.back().get(); + } + + ir::Node* CreateEmptyNode(const std::string& name, ir::Node::Type type) { + nodes.emplace_back(new ir::Node(name, type)); + return nodes.back().get(); + } + + std::vector> nodes; + + private: + // NOTE: program_ shouldn't be exposed to user. + const ProgramDesc& program_; + std::map attrs_; + std::map> attr_dels_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e23bf124f8822e25be0f6b1c7c8c5de4e4f600a --- /dev/null +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/graph.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { + +class NOP : public OperatorBase { + public: + NOP(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const Scope &scope, + const platform::Place &place) const override {} +}; + +class SumOpMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "").AsDuplicable(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class SumOpVarTypeInference : public VarTypeInference { + public: + void operator()(const OpDesc &op_desc, BlockDesc *block) const override { + auto &inputs = op_desc.Input("X"); + auto default_var_type = proto::VarType::SELECTED_ROWS; + + bool any_input_is_lod_tensor = std::any_of( + inputs.begin(), inputs.end(), [block](const std::string &name) { + return block->Var(name)->GetType() == proto::VarType::LOD_TENSOR; + }); + if (any_input_is_lod_tensor) { + default_var_type = proto::VarType::LOD_TENSOR; + } + + auto out_var_name = op_desc.Output("Out").front(); + block->Var(out_var_name)->SetType(default_var_type); + } +}; +} // namespace framework +} // namespace paddle + +REGISTER_OPERATOR(sum, paddle::framework::NOP, paddle::framework::SumOpMaker, + paddle::framework::SumOpVarTypeInference); +REGISTER_OPERATOR(sum_without_infer_var_type, paddle::framework::NOP, + paddle::framework::SumOpMaker); + +namespace paddle { +namespace framework { + +TEST(GraphTest, Basic) { + ProgramDesc prog; + auto *op = prog.MutableBlock(0)->AppendOp(); + op->SetType("sum"); + op->SetInput("X", {"test_a", "test_b", "test_c"}); + op->SetOutput("Out", {"test_out"}); + + prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_c")->SetType(proto::VarType::SELECTED_ROWS); + prog.MutableBlock(0)->Var("test_out"); + + op->InferVarType(prog.MutableBlock(0)); + + ASSERT_EQ(proto::VarType::SELECTED_ROWS, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::LOD_TENSOR); + op->InferVarType(prog.MutableBlock(0)); + ASSERT_EQ(proto::VarType::LOD_TENSOR, + prog.MutableBlock(0)->Var("test_out")->GetType()); + + std::unique_ptr g(new Graph(prog)); + ASSERT_EQ(g->nodes[0]->Name(), "sum"); + ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a"); + ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b"); + ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c"); + ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out"); + ASSERT_EQ(g->nodes[1]->Name(), "test_a"); + ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum"); + ASSERT_EQ(g->nodes[2]->Name(), "test_b"); + ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum"); + ASSERT_EQ(g->nodes[3]->Name(), "test_c"); + ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum"); + ASSERT_EQ(g->nodes[4]->Name(), "test_out"); + ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum"); + ASSERT_EQ(g->nodes.size(), 5); +} +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc new file mode 100644 index 0000000000000000000000000000000000000000..86376e7e8bc8bee2ddbc18f7f24bcdd849a06cbf --- /dev/null +++ b/paddle/fluid/framework/ir/node.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h new file mode 100644 index 0000000000000000000000000000000000000000..b98c29b81ddc2f57553b8fe76fcfeb0936ddd837 --- /dev/null +++ b/paddle/fluid/framework/ir/node.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/platform/macros.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Node { + public: + enum class Type { kOperation, kVariable }; + explicit Node(const std::string& name, Type type) + : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} + + explicit Node(VarDesc* var_desc) + : name_(var_desc->Name()), + var_desc_(var_desc), + op_desc_(nullptr), + type_(Type::kVariable) {} + + explicit Node(OpDesc* op_desc) + : name_(op_desc->Type()), + var_desc_(nullptr), + op_desc_(op_desc), + type_(Type::kOperation) {} + + Type NodeType() const { return type_; } + + std::string Name() const { return name_; } + + VarDesc* Var() { + PADDLE_ENFORCE(type_ == Type::kVariable); + return var_desc_; + } + OpDesc* Op() { + PADDLE_ENFORCE(type_ == Type::kOperation); + return op_desc_; + } + + std::vector inputs; + std::vector outputs; + + protected: + const std::string name_; + VarDesc* var_desc_; + OpDesc* op_desc_; + Type type_; + + private: + DISABLE_COPY_AND_ASSIGN(Node); +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..c05d7d0bb54c8ba5938e08f7e8dace8f607d7b89 --- /dev/null +++ b/paddle/fluid/framework/ir/pass.cc @@ -0,0 +1,19 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h new file mode 100644 index 0000000000000000000000000000000000000000..f52ba788d55ddb9ed27baa3f6ff0a97e52370fe0 --- /dev/null +++ b/paddle/fluid/framework/ir/pass.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Pass { + public: + Pass() = default; + virtual ~Pass() {} + + virtual std::unique_ptr Apply(std::unique_ptr graph) const = 0; +}; +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 9a72e1baa34274201c40bd83a7aace549a7fc6ae..1e5bba62b53025dacdbf2d74b35f266cf4e422c2 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -18,6 +18,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/ir/graph.h" + #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" #endif @@ -129,12 +131,11 @@ ParallelExecutor::ParallelExecutor( PADDLE_THROW("Not compiled with CUDA."); #endif } - builder_ = builder_factory.Create(); + std::unique_ptr graph(new Graph(main_program)); + graph = builder_->Apply(std::move(graph)); member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, - builder_->Build(main_program))); - + exec_strategy, member_->local_scopes_, places, std::move(graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), member_->places_, std::move(member_->executor_))); diff --git a/paddle/fluid/platform/variant.h b/paddle/fluid/platform/variant.h index 45f60fc9d76560b133fa06198a24c7eaccc24088..dc9fad29f281a1c6ac300b48f9e600ff802a5752 100644 --- a/paddle/fluid/platform/variant.h +++ b/paddle/fluid/platform/variant.h @@ -38,6 +38,7 @@ limitations under the License. */ #endif #endif +#include #include #include #include