From 142e832d21715c0ce651e4ac04f10554945e5ad7 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Wed, 25 Jul 2018 19:59:24 +0800 Subject: [PATCH] pass registration --- .../details/multi_devices_graph_builder.cc | 31 ++--- .../details/multi_devices_graph_builder.h | 27 ++-- .../details/ssa_graph_builder_factory.cc | 33 ++--- .../framework/details/ssa_graph_checker.h | 12 +- .../framework/details/ssa_graph_printer.h | 34 ++--- paddle/fluid/framework/ir/graph_viz_pass.cc | 6 +- paddle/fluid/framework/ir/graph_viz_pass.h | 6 - paddle/fluid/framework/ir/pass.cc | 9 +- paddle/fluid/framework/ir/pass.h | 117 +++++++++++++++++- paddle/fluid/framework/parallel_executor.cc | 20 ++- 10 files changed, 191 insertions(+), 104 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 22f0cb20d0..4fad520f40 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -34,30 +34,16 @@ namespace paddle { namespace framework { namespace details { +void MultiDevSSAGraphBuilder::Init() const { + loss_var_name_ = Get("loss_var_name"); + places_ = Get>("places"); + local_scopes_ = Get>("local_scopes"); + strategy_ = Get("strategy"); #ifdef PADDLE_WITH_CUDA -MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( - const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, - platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy) - : loss_var_name_(loss_var_name), - places_(places), - local_scopes_(local_scopes), - nccl_ctxs_(nccl_ctxs), - strategy_(strategy) { -#else -MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( - const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, const BuildStrategy &strategy) - : loss_var_name_(loss_var_name), - places_(places), - local_scopes_(local_scopes), - strategy_(strategy) { + nccl_ctxs_ = &Get("nccl_ctxs"); #endif - for (auto &p : params) { + + for (auto &p : Get>("params")) { grad_names_.insert(GradVarName(p)); } balance_vars_.resize(places_.size(), 0); @@ -241,6 +227,7 @@ std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { std::unique_ptr MultiDevSSAGraphBuilder::Apply( std::unique_ptr graph) const { + Init(); // Give the topology sort order and rebuild the graph structure. std::vector sorted_ops = SortOpsAndDelayOptimizeOp(*graph); auto nodes = graph->ReleaseNodes(); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 55076f227b..c8c1b2a438 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -32,20 +32,6 @@ namespace details { class MultiDevSSAGraphBuilder : public SSAGraphBuilder { public: -#ifdef PADDLE_WITH_CUDA - MultiDevSSAGraphBuilder(const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, - platform::NCCLContextMap *nccl_ctxs, - const BuildStrategy &strategy); -#else - MultiDevSSAGraphBuilder(const std::vector &places, - const std::string &loss_var_name, - const std::unordered_set ¶ms, - const std::vector &local_scopes, - const BuildStrategy &strategy); -#endif std::unique_ptr Apply( std::unique_ptr graph) const override; int GetVarDeviceID(const std::string &varname) const override; @@ -53,15 +39,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { private: void CreateOpHandleIOs(ir::Graph *result, ir::Node *node, size_t device_id) const; + void Init() const; private: - std::string loss_var_name_; - const std::vector &places_; - const std::vector &local_scopes_; - std::unordered_set grad_names_; + mutable std::string loss_var_name_; + mutable std::vector places_; + mutable std::vector local_scopes_; + mutable std::unordered_set grad_names_; #ifdef PADDLE_WITH_CUDA - platform::NCCLContextMap *nccl_ctxs_; + mutable platform::NCCLContextMap *nccl_ctxs_; #endif bool IsScaleLossOp(ir::Node *node) const; @@ -113,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::vector &var_names) const; private: - BuildStrategy strategy_; + mutable BuildStrategy strategy_; mutable std::unordered_map all_vars_; mutable std::unordered_map var_name_on_devices_; mutable std::vector balance_vars_; diff --git a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc index b4b49d3de6..e8d83943ac 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder_factory.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder_factory.cc @@ -22,26 +22,29 @@ namespace paddle { namespace framework { namespace details { std::unique_ptr SSAGraphBuilderFactory::Create() { - std::unique_ptr res( + std::unique_ptr res(new MultiDevSSAGraphBuilder); + res->SetNotOwned>("places", &places_); + res->SetNotOwned("loss_var_name", &loss_var_name_); + res->SetNotOwned>("params", ¶m_names_); + res->SetNotOwned>("local_scopes", &local_scopes_); + res->SetNotOwned("strategy", &strategy_); #ifdef PADDLE_WITH_CUDA - new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_, - local_scopes_, nccl_ctxs_, strategy_) -#else - new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_, - local_scopes_, strategy_) + res->SetNotOwned("nccl_ctxs", nccl_ctxs_); #endif - ); // NOLINT if (!strategy_.debug_graphviz_path_.empty()) { - std::unique_ptr fout( - new std::ofstream(strategy_.debug_graphviz_path_)); - PADDLE_ENFORCE(fout->good()); - std::unique_ptr graphviz_printer( - new GraphvizSSAGraphPrinter()); - res.reset(new SSAGraghBuilderWithPrinter( - std::move(fout), std::move(graphviz_printer), std::move(res))); + SSAGraphBuilder *previous_pass = res.release(); + res.reset(new SSAGraghBuilderWithPrinter); + res->Set("previous_pass", previous_pass); + res->SetNotOwned("debug_graphviz_path", + &strategy_.debug_graphviz_path_); + res->Set("graph_printer", + new GraphvizSSAGraphPrinter); } - res.reset(new SSAGraghBuilderWithChecker(std::move(res))); + + SSAGraphBuilder *previous_pass = res.release(); + res.reset(new SSAGraghBuilderWithChecker); + res->Set("previous_pass", previous_pass); return res; } diff --git a/paddle/fluid/framework/details/ssa_graph_checker.h b/paddle/fluid/framework/details/ssa_graph_checker.h index 51ce6e5eca..ae5ad16b0c 100644 --- a/paddle/fluid/framework/details/ssa_graph_checker.h +++ b/paddle/fluid/framework/details/ssa_graph_checker.h @@ -24,25 +24,19 @@ namespace details { class SSAGraghBuilderWithChecker : public SSAGraphBuilder { public: - explicit SSAGraghBuilderWithChecker( - std::unique_ptr&& builder) - : builder_(std::move(builder)) {} - std::unique_ptr Apply( std::unique_ptr graph) const override { - auto new_graph = builder_->Apply(std::move(graph)); + auto new_graph = + Get("previous_pass").Apply(std::move(graph)); PADDLE_ENFORCE(IsValidGraph(new_graph.get())); return new_graph; } int GetVarDeviceID(const std::string& var_name) const override { - return builder_->GetVarDeviceID(var_name); + return Get("previous_pass").GetVarDeviceID(var_name); } bool IsValidGraph(const ir::Graph* graph) const; - - private: - std::unique_ptr builder_; }; } // namespace details diff --git a/paddle/fluid/framework/details/ssa_graph_printer.h b/paddle/fluid/framework/details/ssa_graph_printer.h index a77c1bad3f..2a939ef4c9 100644 --- a/paddle/fluid/framework/details/ssa_graph_printer.h +++ b/paddle/fluid/framework/details/ssa_graph_printer.h @@ -14,7 +14,9 @@ #pragma once +#include #include +#include #include #include "paddle/fluid/framework/details/ssa_graph_builder.h" @@ -35,37 +37,21 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter { class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { public: - SSAGraghBuilderWithPrinter(std::ostream& sout, - std::unique_ptr&& printer, - std::unique_ptr&& builder) - : printer_(std::move(printer)), - builder_(std::move(builder)), - stream_ref_(sout) {} - - SSAGraghBuilderWithPrinter(std::unique_ptr&& sout, - std::unique_ptr&& printer, - std::unique_ptr&& builder) - : printer_(std::move(printer)), - builder_(std::move(builder)), - stream_ptr_(std::move(sout)), - stream_ref_(*stream_ptr_) {} - std::unique_ptr Apply( std::unique_ptr graph) const override { - auto new_graph = builder_->Apply(std::move(graph)); - printer_->Print(*new_graph, stream_ref_); + auto new_graph = + Get("previous_pass").Apply(std::move(graph)); + + std::unique_ptr fout( + new std::ofstream(Get("debug_graphviz_path"))); + PADDLE_ENFORCE(fout->good()); + Get("graph_printer").Print(*new_graph, *fout); return new_graph; } int GetVarDeviceID(const std::string& var_name) const override { - return builder_->GetVarDeviceID(var_name); + return Get("previous_pass").GetVarDeviceID(var_name); } - - private: - std::unique_ptr printer_; - std::unique_ptr builder_; - std::unique_ptr stream_ptr_; - std::ostream& stream_ref_; }; } // namespace details diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc index c839ebadac..7d1cff7178 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.cc +++ b/paddle/fluid/framework/ir/graph_viz_pass.cc @@ -23,7 +23,8 @@ namespace ir { std::unique_ptr GraphVizPass::Apply( std::unique_ptr graph) const { - std::unique_ptr fout(new std::ofstream(graph_viz_path_)); + const std::string graph_viz_path = Get("graph_viz_path"); + std::unique_ptr fout(new std::ofstream(graph_viz_path)); PADDLE_ENFORCE(fout->good()); std::ostream& sout = *fout; @@ -61,6 +62,9 @@ std::unique_ptr GraphVizPass::Apply( sout << "}\n"; return graph; } + } // namespace ir } // namespace framework } // namespace paddle + +REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass); diff --git a/paddle/fluid/framework/ir/graph_viz_pass.h b/paddle/fluid/framework/ir/graph_viz_pass.h index 08c534f417..04c0c35d12 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.h +++ b/paddle/fluid/framework/ir/graph_viz_pass.h @@ -29,14 +29,8 @@ namespace ir { class GraphVizPass : public Pass { public: - explicit GraphVizPass(const std::string& graph_viz_path) - : graph_viz_path_(graph_viz_path) {} - std::unique_ptr Apply( std::unique_ptr graph) const override; - - private: - const std::string graph_viz_path_; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index c05d7d0bb5..0e68ecb56f 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -15,5 +15,12 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/pass.h" namespace paddle { -namespace framework {} // namespace framework +namespace framework { +namespace ir { +PassRegistry& PassRegistry::Instance() { + static PassRegistry g_pass_info_map; + return g_pass_info_map; +} +} // namespace ir +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index f52ba788d5..9466924262 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -14,9 +14,14 @@ limitations under the License. */ #pragma once +#include +#include +#include + #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/platform/variant.h" namespace paddle { namespace framework { @@ -25,10 +30,120 @@ namespace ir { class Pass { public: Pass() = default; - virtual ~Pass() {} + virtual ~Pass() { + for (auto &attr : attrs_) { + if (attr_dels_.find(attr.first) != attr_dels_.end()) { + attr_dels_[attr.first](); + } + } + attrs_.clear(); + attr_dels_.clear(); + } virtual std::unique_ptr Apply(std::unique_ptr graph) const = 0; + + template + AttrType &Get(const std::string &attr_name) const { + return *boost::any_cast(attrs_.at(attr_name)); + } + + template + void Set(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + attrs_[attr_name] = attr; + attr_dels_[attr_name] = [attr, attr_name]() { + VLOG(3) << "deleting " << attr_name; + delete attr; + }; + } + + template + void SetNotOwned(const std::string &attr_name, AttrType *attr) { + PADDLE_ENFORCE(attrs_.count(attr_name) == 0); + attrs_[attr_name] = attr; + } + + private: + std::map attrs_; + std::map> attr_dels_; +}; + +using PassCreator = std::function()>; + +class Registrar { + public: + // In our design, various kinds of passes, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which + // are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_PASS macros to + // call this method. So, as long as the callee code calls USE_PASS, the global + // registrar variable won't be removed by the linker. + void Touch() {} }; + +class PassRegistry { + public: + static PassRegistry &Instance(); + + bool Has(const std::string &pass_type) const { + return map_.find(pass_type) != map_.end(); + } + + void Insert(const std::string &type, const PassCreator &pass_creator) { + PADDLE_ENFORCE(!Has(type), "Pass %s has been registered", type); + map_.insert({type, pass_creator}); + } + + std::unique_ptr Get(const std::string &type) const { + PADDLE_ENFORCE(Has(type), "Pass %s has not been registered", type); + return map_.at(type)(); + } + + private: + PassRegistry() = default; + std::unordered_map map_; + + DISABLE_COPY_AND_ASSIGN(PassRegistry); +}; + +template +struct PassRegistrar : public Registrar { + explicit PassRegistrar(const char *pass_type) { + PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type), + "'%s' is registered more than once.", pass_type); + PassRegistry::Instance().Insert(pass_type, []() -> std::unique_ptr { + return std::unique_ptr(new PassType()); + }); + } +}; + +#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +#define REGISTER_PASS(pass_type, pass_class) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __reg_pass__##pass_type, \ + "REGISTER_PASS must be called in global namespace"); \ + static ::paddle::framework::ir::PassRegistrar \ + __pass_registrar_##pass_type##__(#pass_type); \ + int TouchPassRegistrar_##pass_type() { \ + __pass_registrar_##pass_type##__.Touch(); \ + return 0; \ + } + +#define USE_PASS(pass_type) \ + STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ + __use_pass_itself_##pass_type, \ + "USE_PASS must be called in global namespace"); \ + extern int TouchPassRegistrar_##pass_type(); \ + static int use_pass_itself_##pass_type##_ __attribute__((unused)) = \ + TouchPassRegistrar_##pass_type() + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index fbd5acc3e5..ff661d0013 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -132,19 +132,27 @@ ParallelExecutor::ParallelExecutor( PADDLE_THROW("Not compiled with CUDA."); #endif } - builder_ = builder_factory.Create(); + std::unique_ptr graph(new ir::Graph(main_program)); if (!build_strategy.debug_graphviz_path_.empty()) { - const std::string origin_graph_path = string::Sprintf( + auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); + const std::string graph_path = string::Sprintf( "%s%s", build_strategy.debug_graphviz_path_.c_str(), "_original_graph"); - graph = ir::GraphVizPass(origin_graph_path).Apply(std::move(graph)); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + graph = viz_pass->Apply(std::move(graph)); } + + builder_ = builder_factory.Create(); graph = builder_->Apply(std::move(graph)); + if (!build_strategy.debug_graphviz_path_.empty()) { - const std::string origin_graph_path = string::Sprintf( + auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); + const std::string graph_path = string::Sprintf( "%s%s", build_strategy.debug_graphviz_path_.c_str(), "_before_exec"); - graph = ir::GraphVizPass(origin_graph_path).Apply(std::move(graph)); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + graph = viz_pass->Apply(std::move(graph)); } + member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( @@ -297,3 +305,5 @@ ParallelExecutor::~ParallelExecutor() { } // namespace framework } // namespace paddle + +USE_PASS(graph_viz_pass); -- GitLab