From 68aa5004512cd12e8d81b08d2fe40ddcdfb59f2f Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 11 Jul 2018 20:18:31 +0800 Subject: [PATCH] polish attrs --- .../details/multi_devices_graph_builder.cc | 9 +-- .../details/multi_devices_graph_builder.h | 2 +- .../framework/details/ssa_graph_builder.h | 2 +- .../framework/details/ssa_graph_checker.cc | 8 +-- .../framework/details/ssa_graph_checker.h | 4 +- .../framework/details/ssa_graph_printer.cc | 10 ++-- .../framework/details/ssa_graph_printer.h | 6 +- paddle/fluid/framework/ir/graph.h | 60 ++++++------------- paddle/fluid/framework/parallel_executor.cc | 13 +++- 9 files changed, 46 insertions(+), 68 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 9ac961f1b15..9be4963c917 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -167,7 +167,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( return dev_id; } -std::unique_ptr MultiDevSSAGraphBuilder::Build( +std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { std::unique_ptr graph(new Graph); for (auto *var : program.Block(0).AllVars()) { @@ -301,12 +301,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); - - std::unique_ptr ssa_graph(new SSAGraph); - ssa_graph->vars_ = std::move(*graph->Erase("vars")); - ssa_graph->ops_ = std::move(*graph->Erase("ops")); - ssa_graph->dep_vars_ = std::move(*graph->Erase("dep_vars")); - return std::move(ssa_graph); + return std::move(graph); } bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 3d7642f522e..b9504665d04 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -47,7 +47,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const BuildStrategy &strategy); #endif - std::unique_ptr Build(const ProgramDesc &program) const override; + std::unique_ptr Build(const ProgramDesc &program) const override; int GetVarDeviceID(const std::string &varname) const override; private: diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index d5aabb9fd1a..56c3077cb39 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -38,7 +38,7 @@ class SSAGraphBuilder { public: SSAGraphBuilder() {} virtual ~SSAGraphBuilder() {} - virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; + virtual std::unique_ptr Build(const ProgramDesc &program) const = 0; virtual int GetVarDeviceID(const std::string &var_name) const = 0; DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); diff --git a/paddle/fluid/framework/details/ssa_graph_checker.cc b/paddle/fluid/framework/details/ssa_graph_checker.cc index da5428946ee..c01334ca06f 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.cc +++ b/paddle/fluid/framework/details/ssa_graph_checker.cc @@ -20,7 +20,7 @@ namespace paddle { namespace framework { namespace details { -bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { +bool SSAGraghBuilderWithChecker::IsValidGraph(const Graph *graph) const { std::unordered_map pending_ops; std::unordered_set pending_vars; std::unordered_set ready_vars; @@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { } }; - for (auto &var_map : graph->vars_) { + for (auto &var_map : graph->Get("vars")) { for (auto &name_pair : var_map) { for (auto &version_pair : name_pair.second) { insert_pending_var(version_pair.get()); @@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const SSAGraph *graph) const { } } - for (auto &var : graph->dep_vars_) { + for (auto &var : graph->Get("dep_vars")) { insert_pending_var(var.get()); } - for (auto &op : graph->ops_) { + for (auto &op : graph->Get("ops")) { if (op->Inputs().empty()) { ready_ops.insert(op.get()); } else { diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 331aa9d2b58..20fa432a8bd 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -29,7 +29,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { std::unique_ptr&& builder) : builder_(std::move(builder)) {} - std::unique_ptr Build(const ProgramDesc& program) const override { + std::unique_ptr Build(const ProgramDesc& program) const override { auto graph = builder_->Build(program); PADDLE_ENFORCE(IsValidGraph(graph.get())); return graph; @@ -39,7 +39,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder { return builder_->GetVarDeviceID(var_name); } - bool IsValidGraph(const SSAGraph* graph) const; + bool IsValidGraph(const Graph* graph) const; private: std::unique_ptr builder_; diff --git a/paddle/fluid/framework/details/ssa_graph_printer.cc b/paddle/fluid/framework/details/ssa_graph_printer.cc index 22a40ca4b25..412b0a6ff2f 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.cc +++ b/paddle/fluid/framework/details/ssa_graph_printer.cc @@ -21,8 +21,8 @@ namespace framework { namespace details { template -static inline void IterAllVar(const SSAGraph &graph, Callback callback) { - for (auto &each : graph.vars_) { +static inline void IterAllVar(const Graph &graph, Callback callback) { + for (auto &each : graph.Get("vars")) { for (auto &pair1 : each) { for (auto &pair2 : pair1.second) { callback(*pair2); @@ -30,12 +30,12 @@ static inline void IterAllVar(const SSAGraph &graph, Callback callback) { } } - for (auto &var : graph.dep_vars_) { + for (auto &var : graph.Get("dep_vars")) { callback(*var); } } -void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, +void GraphvizSSAGraphPrinter::Print(const Graph &graph, std::ostream &sout) const { size_t var_id = 0; std::unordered_map vars; @@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph, }); size_t op_id = 0; - for (auto &op : graph.ops_) { + for (auto &op : graph.Get("ops")) { std::string op_name = "op_" + std::to_string(op_id++); sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" << std::endl; diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index 09b0333ef2c..da98685a211 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -25,12 +25,12 @@ struct SSAGraph; class SSAGraphPrinter { public: virtual ~SSAGraphPrinter() {} - virtual void Print(const SSAGraph& graph, std::ostream& sout) const = 0; + virtual void Print(const Graph& graph, std::ostream& sout) const = 0; }; class GraphvizSSAGraphPrinter : public SSAGraphPrinter { public: - void Print(const SSAGraph& graph, std::ostream& sout) const override; + void Print(const Graph& graph, std::ostream& sout) const override; }; class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { @@ -50,7 +50,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { stream_ptr_(std::move(sout)), stream_ref_(*stream_ptr_) {} - std::unique_ptr Build(const ProgramDesc& program) const override { + std::unique_ptr Build(const ProgramDesc& program) const override { auto graph = builder_->Build(program); printer_->Print(*graph, stream_ref_); return graph; diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index f1de4d740d3..21b9fa943e1 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -26,71 +26,45 @@ limitations under the License. */ namespace paddle { namespace framework { -class Graph; - -template -struct AnyAttr { - public: - explicit AnyAttr(AttrType* attr) : attr_(attr) {} - - AttrType& Get() { return *boost::any_cast(attr_); } - - private: - friend Graph; - - AttrType* Release() { - released_ = true; - return boost::any_cast(attr_); - } - - void Delete() { - if (!released_) { - delete boost::any_cast(attr_); - } - } - - bool released_ = false; - boost::any attr_; -}; - class Graph { public: virtual ~Graph() { - for (auto& attr : attrs) { - attr_dels[attr.first](); + for (auto& attr : attrs_) { + attr_dels_[attr.first](); } - attrs.clear(); - attr_dels.clear(); + attrs_.clear(); + attr_dels_.clear(); } template - AttrType& Get(const std::string& attr_name) { - return boost::any_cast>(attrs[attr_name]).Get(); + AttrType& Get(const std::string& attr_name) const { + return *boost::any_cast(attrs_.at(attr_name)); } template void Set(const std::string& attr_name, AttrType* attr) { - AnyAttr any_attr = AnyAttr(attr); - attrs[attr_name] = any_attr; - attr_dels[attr_name] = [&any_attr]() { any_attr.Delete(); }; + attrs_[attr_name] = attr; + attr_dels_[attr_name] = [attr, attr_name]() { + VLOG(3) << "deleting " << attr_name; + delete attr; + }; } template AttrType* Erase(const std::string& attr_name) { - AnyAttr attr_type = - boost::any_cast>(attrs[attr_name]); - attrs.erase(attr_name); - attr_dels.erase(attr_name); - return attr_type.Release(); + AttrType* attr = boost::any_cast(attrs_[attr_name]); + attrs_.erase(attr_name); + attr_dels_.erase(attr_name); + return attr; } std::vector inputs; std::vector outputs; std::vector> nodes; - std::map attrs; - std::map> attr_dels; private: + std::map attrs_; + std::map> attr_dels_; }; } // namespace framework diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 9a72e1baa34..3db2d9cdc4c 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -18,6 +18,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/details/ssa_graph.h" + #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/nccl_helper.h" #endif @@ -131,9 +133,16 @@ ParallelExecutor::ParallelExecutor( } builder_ = builder_factory.Create(); + std::unique_ptr graph = builder_->Build(main_program); + + std::unique_ptr ssa_graph(new details::SSAGraph); + ssa_graph->vars_ = std::move(graph->Get("vars")); + ssa_graph->ops_ = std::move(graph->Get("ops")); + ssa_graph->dep_vars_ = + std::move(graph->Get("dep_vars")); + member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, places, - builder_->Build(main_program))); + exec_strategy, member_->local_scopes_, places, std::move(ssa_graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, std::move(var_infos), -- GitLab