From af79b192077a6485fc90be4112f80acce8fb748b Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Thu, 12 Jul 2018 09:47:37 +0800 Subject: [PATCH] add a simple program to graph --- paddle/fluid/framework/CMakeLists.txt | 2 +- .../framework/details/multi_devices_graph_builder.cc | 4 ++-- .../framework/details/multi_devices_graph_builder.h | 2 +- paddle/fluid/framework/details/ssa_graph_builder.h | 2 +- paddle/fluid/framework/details/ssa_graph_checker.h | 8 ++++---- paddle/fluid/framework/details/ssa_graph_printer.h | 8 ++++---- paddle/fluid/framework/ir/graph.cc | 9 ++++++++- paddle/fluid/framework/ir/graph.h | 8 ++++++++ paddle/fluid/framework/ir/pass.h | 6 ------ paddle/fluid/framework/parallel_executor.cc | 2 +- 10 files changed, 30 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index bae8f51bcf9..de06c860f55 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -94,7 +94,7 @@ else() endif() -cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor) +cc_library(parallel_executor SRCS parallel_executor.cc DEPS ssa_graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 9be4963c917..0a953704193 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -168,8 +168,8 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( } std::unique_ptr MultiDevSSAGraphBuilder::Build( - const ProgramDesc &program) const { - std::unique_ptr graph(new Graph); + std::unique_ptr graph) const { + const ProgramDesc &program = graph->Program(); for (auto *var : program.Block(0).AllVars()) { all_vars_.emplace(var->Name(), var); } diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index b9504665d04..248ea8ea62b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const BuildStrategy &strategy); #endif - std::unique_ptr Build(const ProgramDesc &program) const override; + std::unique_ptr Build(std::unique_ptr graph) const override; int GetVarDeviceID(const std::string &varname) const override; private: diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index 56c3077cb39..4fbf036241d 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -38,7 +38,7 @@ class SSAGraphBuilder { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; + virtual std::unique_ptr Build(std::unique_ptr graph) const = 0; virtual int GetVarDeviceID(const std::string &var_name) const = 0; DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 20fa432a8bd..7078b778bea 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -29,10 +29,10 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { std::unique_ptr&& builder) : builder_(std::move(builder)) {} - std::unique_ptr Build(const ProgramDesc& program) const override { - auto graph = builder_->Build(program); - PADDLE_ENFORCE(IsValidGraph(graph.get())); - return graph; + std::unique_ptr Build(std::unique_ptr graph) const override { + auto new_graph = builder_->Build(std::move(graph)); + PADDLE_ENFORCE(IsValidGraph(new_graph.get())); + return new_graph; } int GetVarDeviceID(const std::string& var_name) const override { diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index da98685a211..0bd2b10eda4 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -50,10 +50,10 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { stream_ptr_(std::move(sout)), stream_ref_(*stream_ptr_) {} - std::unique_ptr Build(const ProgramDesc& program) const override { - auto graph = builder_->Build(program); - printer_->Print(*graph, stream_ref_); - return graph; + std::unique_ptr Build(std::unique_ptr graph) const override { + auto new_graph = builder_->Build(std::move(graph)); + printer_->Print(*new_graph, stream_ref_); + return new_graph; } int GetVarDeviceID(const std::string& var_name) const override { diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index b5c5ba7c149..28ad4efc719 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -15,5 +15,12 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" namespace paddle { -namespace framework {} // namespace framework +namespace framework { + +std::unique_ptr ProgramToGraph(const ProgramDesc &program) { + std::unique_ptr graph(new Graph(program)); + return std::move(graph); +} + +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 72602840fc3..e83cb5a82a3 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/variant.h" @@ -28,6 +29,8 @@ namespace framework { class Graph { public: + explicit Graph(const ProgramDesc& program) : program_(program) {} + virtual ~Graph() { for (auto& attr : attrs_) { attr_dels_[attr.first](); @@ -36,6 +39,8 @@ class Graph { attr_dels_.clear(); } + const ProgramDesc& Program() const { return program_; } + template AttrType& Get(const std::string& attr_name) const { return *boost::any_cast(attrs_.at(attr_name)); @@ -63,9 +68,12 @@ class Graph { std::vector> nodes; private: + const ProgramDesc& program_; std::map attrs_; std::map> attr_dels_; }; +std::unique_ptr ProgramToGraph(const ProgramDesc& program); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 087ebb87098..2fc26c053f0 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -30,11 +30,5 @@ class Pass { } }; -std::unique_ptr ProgramToGraph(const ProgramDesc& program) { - std::unique_ptr g(new Graph); - - return std::move(g); -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 3db2d9cdc4c..42bbd2b3ff4 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -133,7 +133,7 @@ ParallelExecutor::ParallelExecutor( } builder_ = builder_factory.Create(); - std::unique_ptr graph = builder_->Build(main_program); + std::unique_ptr graph = builder_->Build(ProgramToGraph(main_program)); std::unique_ptr ssa_graph(new details::SSAGraph); ssa_graph->vars_ = std::move(graph->Get("vars")); -- GitLab